Logo ROOT  
Reference Guide
BatchNormLayer.h
Go to the documentation of this file.
1
2// Author: Vladimir Ilievski
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : TBatchNormLayer *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Batch Normalization Layer Class *
12 * *
13 * Authors (alphabetical): *
14 * Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland *
15 * *
16 * Copyright (c) 2005-2015: *
17 * CERN, Switzerland *
18 * U. of Victoria, Canada *
19 * MPI-K Heidelberg, Germany *
20 * U. of Bonn, Germany *
21 * *
22 * Redistribution and use in source and binary forms, with or without *
23 * modification, are permitted according to the terms listed in LICENSE *
24 * (http://tmva.sourceforge.net/LICENSE) *
25 **********************************************************************************/
26
27#ifndef TMVA_DNN_BatchNormLayer
28#define TMVA_DNN_BatchNormLayer
29
31#include "TMVA/DNN/Functions.h"
32
34
36
37#include <iostream>
38#include <iomanip>
39#include <vector>
40
41namespace TMVA {
42namespace DNN {
43
44/** \class TBatchNormLayer
45
46 Layer implementing Batch Normalization
47
48 The inputs of each batch are normalized during training to have zero mean and unit variance,
49 and they are then scaled by two parameters, different for each input variable:
50 - a scale factor gamma
51 - an offset beta
52
53 In addition, a running mean and variance of the batches are computed and stored in the class.
54 During inference the inputs are not normalized using the batch mean and variance, but using the
55 previously computed running mean and variance.
56 If momentum is in [0,1), the running mean and variance are exponential averages using the momentum value:
57 running_mean = momentum * running_mean + (1-momentum) * batch_mean
58 If instead momentum < 0 (the default), the cumulative average over the nb batches trained so far is computed:
59 running_mean = nb/(nb+1) * running_mean + 1/(nb+1) * batch_mean
60
61 See more at [https://arxiv.org/pdf/1502.03167v3.pdf]
62*/
63template <typename Architecture_t>
64class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
65public:
66
67 using Scalar_t = typename Architecture_t::Scalar_t;
68 using Matrix_t = typename Architecture_t::Matrix_t;
69 using Tensor_t = typename Architecture_t::Tensor_t;
70
71 using HelperDescriptor_t = typename Architecture_t::TensorDescriptor_t;
72 using BNormDescriptors_t = typename Architecture_t::BNormDescriptors_t;
73
74
75private:
76
77 Tensor_t fDerivatives; ///< First derivatives of the activations of this layer.
78
79 int fNormAxis; ///< Normalization axis (-1: every c*h*w element, 1/2/3: depth/height/width). For each element of this axis we will compute mean and stddev
80
81 Scalar_t fMomentum; ///< Momentum of the running mean/variance update (default -1: cumulative average).
83
87
90
91 // cached tensor used for Cudnn to get correct shape
92 Tensor_t fReshapedData; // cached reshaped data tensor
93
94 // counter of trained batches, used to compute the cumulative running mean and variance
96
98
99public:
100 /*! Constructor */
101 TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
102 const std::vector<size_t> & shape, int axis = -1, Scalar_t momentum = -1., Scalar_t epsilon = 0.0001);
103
104 /*! Copy the batch normalization layer provided as a pointer */
106
107 /*! Copy Constructor */
109
110 /*! Destructor */
112
113 /*! Compute the output of the layer for the given input. If inTraining is
114 * true, the input is normalized with the mean and variance of the current
115 * batch, the running mean and variance are updated and the counter of
116 * trained batches is incremented; otherwise the stored running mean and
117 * variance are used and the batch counter is reset. */
118 void Forward(Tensor_t &input, bool inTraining = true);
119
120 /*! Compute the gradients of gamma and beta and the gradient with respect to
121 * the layer input. Uses the batch mean and variances stored by the preceding
122 * training-mode forward propagation. Must only be called directly
123 * after the corresponding call to Forward(...). */
124 void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
125 // Tensor_t &inp1, Tensor_t &inp2);
126
127
128 /* reset the batch counter at the end of training */
130
131 /*! Printing the layer info. */
132 void Print() const;
133
134 /*! Writes the information and the weights about the layer in an XML node. */
135 virtual void AddWeightsXMLTo(void *parent);
136
137 /*! Read the information and the weights about the layer from XML node. */
138 virtual void ReadWeightsFromXML(void *parent);
139
140 /* initialize weights */
141 virtual void Initialize();
142
143 /* get number of trained batches */
144 const int & GetNTrainedBatches() const { return fTrainedBatches;}
146
147 /* get batch means for the training phase */
148 const Matrix_t & GetBatchMean() const { return fMu;}
149 Matrix_t & GetBatchMean() { return fMu;}
150
151 /* Get the normalized batch examples */
152 //const Matrix_t & GetNormedBatch() const { return fXhat;}
153 //Matrix_t & GetNormedBatch() { return fXhat;}
154
155 /* Get the gradient of gamma for backpropagation */
156 const Matrix_t & GetVariance() const { return fVar;}
157 Matrix_t & GetVariance() { return fVar;}
158
159 /* Get the sqrt of the batch variances for the training phase */
160 const Matrix_t & GetIVariance() const { return fIVar;}
162
163 /* get vector of averages computed in the training phase */
164 const Matrix_t & GetMuVector() const { return fMu_Training;}
166
167 /* get vector of variances computed in the training phase */
168 const Matrix_t & GetVarVector() const { return fVar_Training;}
170
171 // Scalar_t GetWeightDecay() const { return fWeightDecay; }
172
173 /* Get the momentum of the running mean/variance */
174 Scalar_t GetMomentum() const { return fMomentum;}
175
176 /* Get epsilon */
177 Scalar_t GetEpsilon() const { return fEpsilon;}
178
179 /* Get normalization axis (the one which will have each element normalized) */
180 Scalar_t GetNormAxis() const { return fNormAxis;}
181
182 const Matrix_t &GetReshapedData() const { return fReshapedData; }
184
185 std::vector<Matrix_t> GetExtraLayerParameters() const {
186 std::vector<Matrix_t> params(2);
187 params[0] = this->GetMuVector();
188 params[1] = this->GetVarVector();
189 return params;
190 }
191
192 void SetExtraLayerParameters(const std::vector<Matrix_t> & params)
193 {
194 this->GetMuVector() = params[0];
195 this->GetVarVector() = params[1];
196 }
197
198protected:
/* Number of independently normalized elements, i.e. the column count of the
   gamma/beta weight matrices, for the requested normalization axis:
   -1 -> every element of the c*h*w input, 1 -> depth (c), 2 -> height (h),
   3 -> width (w). Any other axis value yields 0. */
static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
{
   switch (axis) {
   case -1: return c * h * w;
   case 1:  return c;
   case 2:  return h;
   case 3:  return w;
   default: return 0;
   }
}
211};
212
213
214//
215//
216// The Batch Normalization Layer Class - Implementation
217//______________________________________________________________________________
218template <typename Architecture_t>
219TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
220 size_t inputWidth, const std::vector<size_t> &shape, int axis,
221 Scalar_t momentum, Scalar_t epsilon)
222 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // bs + input shape
223 inputDepth, inputHeight, inputWidth, // output shape
224 2, 1,
225 CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // weight tensor dim.
226 1, 1, 1, // bias
227 shape[2], shape[0], shape[1], // output tensor shape as bsize, depth, hw
229 fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
230 fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()), // dimension is same as weights
231 fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
232 fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
233 fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
234 fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
235 fReshapedData(1,1,1) // use a dummy single element tensor
236
237{
238
239}
240//______________________________________________________________________________
241template <typename Architecture_t>
243 : VGeneralLayer<Architecture_t>(layer)
244{
245 // to be implemented
246 printf("Error - copy ctor not implmented\n");
247}
248
249//______________________________________________________________________________
250template <typename Architecture_t>
252{
253 // to be implmeented
254 printf("Error - copy ctor not implmented\n");
255}
256
257//______________________________________________________________________________
258template <typename Architecture_t>
260{
261 // release descriptors
262 if (fDescriptors) {
263 Architecture_t::ReleaseBNormDescriptors(fDescriptors);
264 delete fDescriptors;
265 }
266}
267
268template <typename Architecture_t>
270{
271 Matrix_t &gamma = this->GetWeightsAt(0);
272 Matrix_t &beta = this->GetWeightsAt(1);
273 size_t bndim = gamma.GetNcols();
274
275 initialize<Architecture_t>(beta, EInitialization::kZero);
276 for (size_t i = 0; i < bndim; ++i) {
277 gamma(0, i) = 1.;
278 // assign default values for the other parameters
279 fMu_Training(0,i) = 0;
280 fVar_Training(0,i) = 1;
281 }
282
283 Matrix_t &dgamma = this->GetWeightGradientsAt(0);
284 Matrix_t &dbeta = this->GetWeightGradientsAt(1);
285 initialize<Architecture_t>(dgamma, EInitialization::kZero);
286 initialize<Architecture_t>(dbeta, EInitialization::kZero);
287
288 fTrainedBatches = 0;
289
290 Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
291}
292
293//______________________________________________________________________________
294template <typename Architecture_t>
296{
297 Tensor_t x2;
298 Tensor_t y2;
299 if (x.GetLayout() != fReshapedData.GetLayout()) {
300 x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
301 y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
302 }
303 else{
304 x2 = x;
305 y2 = this->GetOutput();
306 }
307
308 auto descr = static_cast<BNormDescriptors_t *> (fDescriptors);
309 if (inTraining) {
310 Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
311 this->GetWeightsAt(0), this->GetWeightsAt(1),
312 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
313 this->GetMuVector(),
314 this->GetVarVector(), this->GetNTrainedBatches(),
315 this->GetMomentum(), this->GetEpsilon(),
316 descr->HelperDescriptor);
317 fTrainedBatches++;
318 }
319
320 else {
321 // if (fTrainedBatches > 0) {
322 // Architecture_t::PrintTensor(Tensor_t(this->GetWeightsAt(0)), "bnorm gamma");
323 // Architecture_t::PrintTensor(Tensor_t(this->GetWeightsAt(1)), "bnorm beta");
324 // Architecture_t::PrintTensor(Tensor_t(this->GetMuVector()), "bnorm mu");
325 // Architecture_t::PrintTensor(Tensor_t(this->GetVarVector()), "bnorm var");
326 // }
327 Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
328 y2, this->GetMuVector(), this->GetVarVector(),
329 this->GetEpsilon(), descr->HelperDescriptor);
330 fTrainedBatches = 0;
331 }
332
333}
334
335//______________________________________________________________________________
336template <typename Architecture_t>
338 const Tensor_t & activations_backward ) -> void
339// Tensor_t &, Tensor_t &) -> void
340{
341 auto descr = static_cast<BNormDescriptors_t *> (fDescriptors);
342
343
344 if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
345 Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
346 Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
347 Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
348
349 Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
350 this->GetWeightsAt(0), // gamma (beta is not needed)
351 this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
352 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
353 this->GetEpsilon(), descr->HelperDescriptor);
354
355 } else {
356
357 Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward, // x
358 this->GetActivationGradients(), // dy
359 gradients_backward, // dx
360 this->GetWeightsAt(0), // gamma (beta is not needed)
361 this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
362 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
363 this->GetEpsilon(), descr->HelperDescriptor);
364 }
365}
366
367//______________________________________________________________________________
368template <typename Architecture_t>
370{
371 std::cout << " BATCH NORM Layer: \t";
372 std::cout << " Input/Output = ( " ;
373 auto &shape = this->GetOutput().GetShape();
374 for (size_t i = 0; i < shape.size(); ++i) {
375 if (i > 0) std::cout << " , ";
376 std::cout << shape[i];
377 }
378 std::cout << " ) ";
379 std::cout << "\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
380 std::cout << "\t axis = " << fNormAxis << std::endl;
381 std::cout << std::endl;
382}
383
384//______________________________________________________________________________
385
386template <typename Architecture_t>
388{
389
390 // write layer width activation function + weigbht and bias matrices
391
392 auto layerxml = gTools().xmlengine().NewChild(parent, 0, "BatchNormLayer");
393
394
395 gTools().AddAttr(layerxml, "Momentum", fMomentum);
396 gTools().AddAttr(layerxml, "Epsilon", fEpsilon);
397
398 // write stored mean and variances
399 //using Scalar_t = typename Architecture_t::Scalar_t;
400
401 this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
402 this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());
403
404 // write weights (gamma and beta)
405 this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
406 this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));
407
408}
409
410//______________________________________________________________________________
411template <typename Architecture_t>
413{
414 // momentum and epsilon can be added after constructing the class
415 gTools().ReadAttr(parent, "Momentum", fMomentum);
416 gTools().ReadAttr(parent, "Epsilon", fEpsilon);
417 // Read layer weights and biases from XML
418
419 this->ReadMatrixXML(parent, "Training-mu", this->GetMuVector());
420 this->ReadMatrixXML(parent, "Training-variance", this->GetVarVector());
421
422 this->ReadMatrixXML(parent, "Gamma", this->GetWeightsAt(0));
423 this->ReadMatrixXML(parent, "Beta", this->GetWeightsAt(1));
424}
425
426} // namespace DNN
427} // namespace TMVA
428
429#endif
#define c(i)
Definition: RSha256.hxx:101
#define h(i)
Definition: RSha256.hxx:106
static const double x2[5]
Layer implementing Batch Normalization.
static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
const Matrix_t & GetMuVector() const
int fNormAxis
Normalization axis. For each element of this axis we will compute mean and stddev.
typename Architecture_t::Matrix_t Matrix_t
Scalar_t GetMomentum() const
Scalar_t GetEpsilon() const
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
Scalar_t GetNormAxis() const
void SetExtraLayerParameters(const std::vector< Matrix_t > &params)
void ResetTraining()
Reset some training flags after a loop on all batches Some layer (e.g.
std::vector< Matrix_t > GetExtraLayerParameters() const
typename Architecture_t::Tensor_t Tensor_t
virtual void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Scalar_t fMomentum
The weight decay.
const Matrix_t & GetVariance() const
const Matrix_t & GetReshapedData() const
void Print() const
Printing the layer info.
virtual void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
const int & GetNTrainedBatches() const
const Matrix_t & GetIVariance() const
typename Architecture_t::Scalar_t Scalar_t
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Compute weight, bias and activation gradients.
void Forward(Tensor_t &input, bool inTraining=true)
Compute activation of the layer for the given input.
typename Architecture_t::TensorDescriptor_t HelperDescriptor_t
typename Architecture_t::BNormDescriptors_t BNormDescriptors_t
const Matrix_t & GetVarVector() const
Tensor_t fDerivatives
First fDerivatives of the activations of this layer.
const Matrix_t & GetBatchMean() const
TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth, const std::vector< size_t > &shape, int axis=-1, Scalar_t momentum=-1., Scalar_t epsilon=0.0001)
Constructor.
Generic General Layer class.
Definition: GeneralLayer.h:51
TXMLEngine & xmlengine()
Definition: Tools.h:268
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
Definition: TXMLEngine.cxx:715
double beta(double x, double y)
Calculates the beta function.
Double_t x[n]
Definition: legend1.C:17
double gamma(double x)
EInitialization
Definition: Functions.h:72
create variable transformations
Tools & gTools()
REAL epsilon
Definition: triangle.c:618