Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Regularization.hxx
Go to the documentation of this file.
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Simon Pfreundschuh 10/07/16

/*************************************************************************
 * Copyright (C) 2016, Simon Pfreundschuh                                *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

 //////////////////////////////////////////////////////////////////////
 // Implementation of the regularization functions for the reference //
 // implementation.                                                  //
 //////////////////////////////////////////////////////////////////////

19namespace TMVA
20{
21namespace DNN
22{
23
24//______________________________________________________________________________
25template<typename Real_t>
27{
28 size_t m,n;
29 m = W.GetNrows();
30 n = W.GetNcols();
31
32 Real_t result = 0.0;
33
34 for (size_t i = 0; i < m; i++) {
35 for (size_t j = 0; j < n; j++) {
36 result += std::abs(W(i,j));
37 }
38 }
39 return result;
40}
41
42//______________________________________________________________________________
43template<typename Real_t>
45 const TMatrixT<Real_t> & W,
47{
48 size_t m,n;
49 m = W.GetNrows();
50 n = W.GetNcols();
51
52 Real_t sign = 0.0;
53
54 for (size_t i = 0; i < m; i++) {
55 for (size_t j = 0; j < n; j++) {
56 sign = (W(i,j) > 0.0) ? 1.0 : -1.0;
57 A(i,j) += sign * weightDecay;
58 }
59 }
60}
61
62//______________________________________________________________________________
63template<typename Real_t>
65{
66 size_t m,n;
67 m = W.GetNrows();
68 n = W.GetNcols();
69
70 Real_t result = 0.0;
71
72 for (size_t i = 0; i < m; i++) {
73 for (size_t j = 0; j < n; j++) {
74 result += W(i,j) * W(i,j);
75 }
76 }
77 return result;
78}
79
80//______________________________________________________________________________
81template<typename Real_t>
83 const TMatrixT<Real_t> & W,
85{
86 size_t m,n;
87 m = W.GetNrows();
88 n = W.GetNcols();
89
90 for (size_t i = 0; i < m; i++) {
91 for (size_t j = 0; j < n; j++) {
92 A(i,j) += weightDecay * 2.0 * W(i,j);
93 }
94 }
95}
96
97} // namespace DNN
98} // namespace TMVA
float Real_t
Definition RtypesCore.h:68
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
static void AddL1RegularizationGradients(TMatrixT< AReal > &A, const TMatrixT< AReal > &W, AReal weightDecay)
static AReal L2Regularization(const TMatrixT< AReal > &W)
static AReal L1Regularization(const TMatrixT< AReal > &W)
static void AddL2RegularizationGradients(TMatrixT< AReal > &A, const TMatrixT< AReal > &W, AReal weightDecay)
Int_t GetNrows() const
Int_t GetNcols() const
TMatrixT.
Definition TMatrixT.h:39
const Int_t n
Definition legend1.C:16
double weightDecay(double error, ItWeight itWeight, ItWeight itWeightEnd, double factorWeightDecay, EnumRegularization eRegularization)
compute the weight decay for regularization (L1 or L2)
create variable transformations
TMarker m
Definition textangle.C:8