Logo ROOT   6.14/05
Reference Guide
Regularization.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 10/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12  //////////////////////////////////////////////////////////////////////
13  // Implementation of the regularization functions for the reference //
14  // implementation. //
15  //////////////////////////////////////////////////////////////////////
16 
18 
19 namespace TMVA
20 {
21 namespace DNN
22 {
23 
24 //______________________________________________________________________________
// L1 regularization term: returns the sum of |W(i,j)| over every element of W
// (rows from W.GetNrows(), columns from W.GetNcols()). No weight-decay factor
// is applied here; presumably the caller scales the result — confirm.
// NOTE(review): listing line 26 (the signature) is missing from this rendering;
// the index entry suggests `static AReal L1Regularization(const TMatrixT<AReal> &W)`
// — confirm the exact qualified name against the original source.
25 template<typename Real_t>
27 {
28  size_t m,n;
29  m = W.GetNrows();
30  n = W.GetNcols();
31 
// Accumulator starts at zero, so an empty matrix yields 0.
32  Real_t result = 0.0;
33 
34  for (size_t i = 0; i < m; i++) {
35  for (size_t j = 0; j < n; j++) {
36  result += std::abs(W(i,j));
37  }
38  }
39  return result;
40 }
41 
42 //______________________________________________________________________________
// Adds the L1 regularization gradient to A in place:
//    A(i,j) += weightDecay * s(i,j),  with s(i,j) = +1 if W(i,j) > 0, else -1.
// A is accumulated into (+=), not overwritten.
// NOTE(review): listing lines 44 and 46 (start and end of the signature) are
// missing here; the body uses A, W and weightDecay, consistent with the index
// entry `AddL1RegularizationGradients(TMatrixT<AReal> &A,
// const TMatrixT<AReal> &W, AReal weightDecay)` — confirm.
43 template<typename Real_t>
45  const TMatrixT<Real_t> & W,
47 {
// Loop bounds come from W only; A is assumed to be at least W's size
// (not checked here).
48  size_t m,n;
49  m = W.GetNrows();
50  n = W.GetNcols();
51 
52  Real_t sign = 0.0;
53 
54  for (size_t i = 0; i < m; i++) {
55  for (size_t j = 0; j < n; j++) {
// W(i,j) == 0 (and NaN, since the comparison is false) gets sign -1, not 0.
// -1 lies in the subdifferential of |x| at 0, so this is a valid subgradient.
// However, it pushes exactly-zero weights negative.
// NOTE(review): confirm this asymmetry is intended.
56  sign = (W(i,j) > 0.0) ? 1.0 : -1.0;
57  A(i,j) += sign * weightDecay;
58  }
59  }
60 }
61 
62 //______________________________________________________________________________
// L2 regularization term: returns the sum of W(i,j)^2 over every element of W
// (the squared Frobenius norm, with no 1/2 factor). No weight-decay factor is
// applied here.
// NOTE(review): listing line 64 (the signature) is missing from this rendering;
// the index entry suggests `static AReal L2Regularization(const TMatrixT<AReal> &W)`
// — confirm.
63 template<typename Real_t>
65 {
66  size_t m,n;
67  m = W.GetNrows();
68  n = W.GetNcols();
69 
70  Real_t result = 0.0;
71 
72  for (size_t i = 0; i < m; i++) {
73  for (size_t j = 0; j < n; j++) {
74  result += W(i,j) * W(i,j);
75  }
76  }
77  return result;
78 }
79 
80 //______________________________________________________________________________
// Adds the L2 regularization gradient to A in place:
//    A(i,j) += weightDecay * 2 * W(i,j).
// The factor 2 is the derivative of W^2. It matches the L2 term above, which
// has no 1/2 factor. A is accumulated into (+=), not overwritten.
// NOTE(review): listing lines 82 and 84 (start and end of the signature) are
// missing here; the index entry suggests `AddL2RegularizationGradients(
// TMatrixT<AReal> &A, const TMatrixT<AReal> &W, AReal weightDecay)` — confirm.
81 template<typename Real_t>
83  const TMatrixT<Real_t> & W,
85 {
// Loop bounds come from W only; A is assumed to be at least W's size
// (not checked here).
86  size_t m,n;
87  m = W.GetNrows();
88  n = W.GetNcols();
89 
90  for (size_t i = 0; i < m; i++) {
91  for (size_t j = 0; j < n; j++) {
92  A(i,j) += weightDecay * 2.0 * W(i,j);
93  }
94  }
95 }
96 
97 } // namespace DNN
98 } // namespace TMVA
auto * m
Definition: textangle.C:8
Int_t GetNcols() const
Definition: TMatrixTBase.h:125
static double A[]
TMatrixT.
Definition: TMatrixDfwd.h:22
static void AddL2RegularizationGradients(TMatrixT< AReal > &A, const TMatrixT< AReal > &W, AReal weightDecay)
double weightDecay(double error, ItWeight itWeight, ItWeight itWeightEnd, double factorWeightDecay, EnumRegularization eRegularization)
compute the weight decay for regularization (L1 or L2)
Definition: NeuralNet.icc:496
static AReal L1Regularization(const TMatrixT< AReal > &W)
static void AddL1RegularizationGradients(TMatrixT< AReal > &A, const TMatrixT< AReal > &W, AReal weightDecay)
static AReal L2Regularization(const TMatrixT< AReal > &W)
Int_t GetNrows() const
Definition: TMatrixTBase.h:122
float Real_t
Definition: RtypesCore.h:64
Abstract ClassifierFactory template that handles arbitrary types.
const Int_t n
Definition: legend1.C:16