Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Regularization.hxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 21/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12///////////////////////////////////////////////////////////////////////
13// Implementation of the regularization functionals and gradients //
14// for the multi-threaded CPU implementation using ROOT's TThreadExecutor. //
15///////////////////////////////////////////////////////////////////////
16
19
20namespace TMVA
21{
22namespace DNN
23{
24
25//______________________________________________________________________________
/// L1 regularization functional: returns the sum of |w| over every element of
/// `Weights`.
///
/// The elements are processed in chunks of `nSteps` consecutive entries. Each
/// chunk accumulates its absolute values into its own zero-initialized slot of
/// `temp`, so no two chunks write the same slot. The slots are then summed with
/// `std::accumulate` via the executor's Reduce.
///
/// Unlike the Add*RegularizationGradients functions in this file, this always
/// dispatches through the thread executor (there is no `DL_USE_MTE` guard).
///
/// NOTE(review): this is a rendered documentation listing. The signature line
/// (original line 27) and the line defining `nSteps` (original line 32) were
/// dropped during extraction. The page's cross-reference lists
/// `static Scalar_t L1Regularization(const Matrix_t &W)` and
/// `TCpuMatrix::GetNWorkItems(size_t)`. Presumably the parameter is
/// `const TCpuMatrix<AFloat> &Weights` and `nSteps` comes from GetNWorkItems —
/// confirm against the real source.
26template<typename AFloat>
28{
29 const AFloat *data = Weights.GetRawDataPointer();
30
31 size_t nElements = Weights.GetNoElements();
33
34 std::vector<AFloat> temp(nElements/nSteps + 1);   // one zero-initialized partial sum per chunk (enough slots for a trailing short chunk)
35
36 auto f = [&data, &temp, nElements, nSteps](UInt_t workerID)
37 {
38 size_t iMax = std::min(workerID+nSteps, nElements);   // exclusive end of this chunk, clamped to the element count
39 size_t iWorker = workerID/nSteps;   // workerID is a chunk start (multiple of nSteps), so this is the chunk's slot
40 for (size_t i = workerID; i < iMax; ++i) {
41 temp[iWorker] += fabs(data[i]);
42 }
43 };
44
45 auto reduction = [](const std::vector<AFloat> & v )   // sums all per-chunk partial results
46 {
47 return std::accumulate(v.begin(),v.end(),AFloat{});
48 };
49 // auto reduction = [](AFloat sum1, AFloat sum2)
50 // {
51 // return sum1 + sum2;
52 // };
53 Weights.GetThreadExecutor().Foreach(f, ROOT::TSeqI(0,nElements,nSteps) );   // one call of f per chunk start 0, nSteps, 2*nSteps, ...
54 return Weights.GetThreadExecutor().Reduce(temp, reduction);
55}
56
57
58//______________________________________________________________________________
/// Adds the L1 regularization gradient to `B` in place:
///   B[i] += weightDecay * sign(A[i]),
/// where sign is -1 for negative entries and +1 otherwise. `A` is only read.
/// `A` and `B` must have the same number of elements (checked with R__ASSERT).
///
/// The elements are processed in chunks of `nSteps`. If there is more than one
/// chunk, the chunks run either through the thread executor (when `DL_USE_MTE`
/// is defined) or in a sequential loop. Otherwise a single call handles all
/// elements.
///
/// NOTE(review): a weight that is exactly 0 receives +weightDecay here. A zero
/// subgradient at 0 is the other common convention — confirm which is intended.
///
/// NOTE(review): this rendered listing dropped the start of the signature
/// (original lines 60-61), the `nSteps` definition (original line 70,
/// presumably TCpuMatrix<AFloat>::GetNWorkItems(nElements)) and the line that
/// opens the lambda `f` (original line 74). The cross-reference lists
/// `AddL1RegularizationGradients(Matrix_t &A, const Matrix_t &W, Scalar_t weightDecay)`.
/// In this body the written matrix is named `B` and the read one `A` — confirm
/// against the real source.
59template<typename AFloat>
62 const TCpuMatrix<AFloat> & A,
63 AFloat weightDecay)
64{
65 AFloat *dataB = B.GetRawDataPointer();
66 const AFloat *dataA = A.GetRawDataPointer();
67
68 size_t nElements = B.GetNoElements();
69 R__ASSERT(A.GetNoElements() == nElements);
71
72
73
75 {
76 size_t iMax = std::min(workerID+nSteps, nElements);   // exclusive end of this chunk, clamped to the element count
77 for (size_t i = workerID; i < iMax; ++i) {
78 AFloat sign = (dataA[i] < 0.0) ? -1.0 : 1.0;   // zero is treated as positive
79 dataB[i] += weightDecay * sign;
80 }
81 return 0;   // the callers below ignore this value
82 };
83
84 if (nSteps < nElements) {   // more than one chunk
85#ifdef DL_USE_MTE
86 B.GetThreadExecutor().Foreach(f, ROOT::TSeqI(0,nElements, nSteps));   // parallel over chunk starts
87#else
88 for (size_t i = 0; i < nElements; i+=nSteps)   // sequential loop over the same chunk starts
89 f(i);
90#endif
91 } else {
92 f(0);   // nSteps >= nElements, so iMax clamps to nElements and one call covers everything
93 }
94}
95
96//______________________________________________________________________________
/// L2 regularization functional: returns the sum of w*w over every element of
/// `Weights`.
///
/// The elements are processed in chunks of `nSteps` consecutive entries. Each
/// chunk accumulates its squared values into its own zero-initialized slot of
/// `temp`, so no two chunks write the same slot. The slots are then summed with
/// `std::accumulate` via the executor's Reduce.
///
/// Unlike the Add*RegularizationGradients functions in this file, this always
/// dispatches through the thread executor (there is no `DL_USE_MTE` guard).
///
/// NOTE(review): this rendered listing dropped the signature line (original
/// line 98) and the line defining `nSteps` (original line 103). The page's
/// cross-reference lists `static Scalar_t L2Regularization(const Matrix_t &W)`.
/// Presumably the parameter is `const TCpuMatrix<AFloat> &Weights` and `nSteps`
/// comes from TCpuMatrix<AFloat>::GetNWorkItems — confirm against the real source.
97template<typename AFloat>
99{
100 const AFloat *data = Weights.GetRawDataPointer();
101
102 size_t nElements = Weights.GetNoElements();
104
105 std::vector<AFloat> temp(nElements/nSteps + 1);   // one zero-initialized partial sum per chunk (enough slots for a trailing short chunk)
106
107 auto f = [&data, &temp, nElements, nSteps](UInt_t workerID)
108 {
109 size_t iMax = std::min(workerID+nSteps, nElements);   // exclusive end of this chunk, clamped to the element count
110 size_t iWorker = workerID/nSteps;   // workerID is a chunk start (multiple of nSteps), so this is the chunk's slot
111
112 for (size_t i = workerID; i < iMax; ++i) {
113 temp[iWorker] += data[i] * data[i];
114 }
115 };
116
117 auto reduction = [](const std::vector<AFloat> & v )   // sums all per-chunk partial results
118 {
119 return std::accumulate(v.begin(),v.end(),AFloat{});
120 };
121 // auto reduction = [](AFloat sum1, AFloat sum2)
122 // {
123 // return sum1 + sum2;
124 // };
125
126 Weights.GetThreadExecutor().Foreach(f, ROOT::TSeqI(0,nElements,nSteps) );   // one call of f per chunk start 0, nSteps, 2*nSteps, ...
127 return Weights.GetThreadExecutor().Reduce(temp, reduction);
128}
129
130//______________________________________________________________________________
/// Adds the L2 regularization gradient to `B` in place:
///   B[i] += 2 * weightDecay * A[i],
/// which is the derivative of weightDecay * A[i]^2. `A` is only read.
/// `A` and `B` must have the same number of elements (checked with R__ASSERT).
///
/// The elements are processed in chunks of `nSteps`. If there is more than one
/// chunk, the chunks run either through the thread executor (when `DL_USE_MTE`
/// is defined) or in a sequential loop. Otherwise a single call handles all
/// elements.
///
/// NOTE(review): this rendered listing dropped the start of the signature
/// (original lines 132-133), the `nSteps` definition (original line 142,
/// presumably TCpuMatrix<AFloat>::GetNWorkItems(nElements)) and the line that
/// opens the lambda `f` (original line 144). The cross-reference lists
/// `AddL2RegularizationGradients(Matrix_t &A, const Matrix_t &W, Scalar_t weightDecay)`.
/// In this body the written matrix is named `B` and the read one `A` — confirm
/// against the real source.
131template<typename AFloat>
134 const TCpuMatrix<AFloat> & A,
135 AFloat weightDecay)
136{
137 AFloat *dataB = B.GetRawDataPointer();
138 const AFloat *dataA = A.GetRawDataPointer();
139
140 size_t nElements = B.GetNoElements();
141 R__ASSERT(A.GetNoElements() == nElements);
143
145 {
146 size_t iMax = std::min(workerID+nSteps, nElements);   // exclusive end of this chunk, clamped to the element count
147 for (size_t i = workerID; i < iMax; ++i) {
148 dataB[i] += 2.0 * weightDecay * dataA[i];
149 }
150 return 0;   // the callers below ignore this value
151 };
152
153 if (nSteps < nElements) {   // more than one chunk
154#ifdef DL_USE_MTE
155 B.GetThreadExecutor().Foreach(f, ROOT::TSeqI(0,nElements, nSteps));   // parallel over chunk starts
156#else
157 for (size_t i = 0; i < nElements; i+=nSteps)   // sequential loop over the same chunk starts
158 f(i);
159#endif
160 } else {
161 f(0);   // nSteps >= nElements, so iMax clamps to nElements and one call covers everything
162 }
163}
164
165
166} // namespace DNN
167} // namespace TMVA
#define f(i)
Definition RSha256.hxx:104
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
#define R__ASSERT(e)
Checks condition e and reports a fatal error if it's false.
Definition TError.h:125
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
static size_t GetNWorkItems(size_t nelements)
Definition CpuMatrix.h:191
static Scalar_t L1Regularization(const Matrix_t &W)
static void AddL1RegularizationGradients(Matrix_t &A, const Matrix_t &W, Scalar_t weightDecay)
static Scalar_t L2Regularization(const Matrix_t &W)
static void AddL2RegularizationGradients(Matrix_t &A, const Matrix_t &W, Scalar_t weightDecay)
double weightDecay(double error, ItWeight itWeight, ItWeight itWeightEnd, double factorWeightDecay, EnumRegularization eRegularization)
compute the weight decay for regularization (L1 or L2)
create variable transformations