ROOT   Reference Guide
Searching...
No Matches
BatchNormLayer.h
Go to the documentation of this file.
1
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : TBatchNormLayer *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Batch Normalization Layer Class *
12 * *
13 * Authors (alphabetical): *
15 * *
16 * Copyright (c) 2005-2015: *
17 * CERN, Switzerland *
18 * U. of Victoria, Canada *
19 * MPI-K Heidelberg, Germany *
20 * U. of Bonn, Germany *
21 * *
22 * Redistribution and use in source and binary forms, with or without *
23 * modification, are permitted according to the terms listed in LICENSE *
25 **********************************************************************************/
26
27#ifndef TMVA_DNN_BatchNormLayer
28#define TMVA_DNN_BatchNormLayer
29
31#include "TMVA/DNN/Functions.h"
32
34
36
37#include <iostream>
38#include <iomanip>
39#include <vector>
40
41namespace TMVA {
42namespace DNN {
43
44/** \class TBatchNormLayer
45
46 Layer implementing Batch Normalization
47
48 During training, the inputs of each batch are normalized to have zero mean and unit variance,
49 and they are then scaled and shifted by two parameters, different for each input variable:
50 - a scale factor gamma
51 - an offset beta
52
53 In addition, a running mean and variance of the batch statistics are computed and stored in the class.
54 During inference the inputs are not normalized with the batch mean and variance but with the previously
55 computed running mean and variance.
56 If momentum is in [0,1), the running mean and variance are exponential averages using the momentum value:
57 running_mean = momentum * running_mean + (1-momentum) * batch_mean
58 If instead momentum < 0 (e.g. the default value -1), the cumulative average is computed:
59 running_mean = nb/(nb+1) * running_mean + 1/(nb+1) * batch_mean, where nb is the number of batches already processed
60
61 See more at [https://arxiv.org/pdf/1502.03167v3.pdf]
62*/
63template <typename Architecture_t>
64class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
65public:
66
// Numeric, matrix and tensor types provided by the architecture backend
67 using Scalar_t = typename Architecture_t::Scalar_t;
68 using Matrix_t = typename Architecture_t::Matrix_t;
69 using Tensor_t = typename Architecture_t::Tensor_t;
70
71 using HelperDescriptor_t = typename Architecture_t::TensorDescriptor_t;
72 using BNormDescriptors_t = typename Architecture_t::BNormDescriptors_t; ///< descriptor set that fDescriptors is cast to in Forward/Backward
73
74
75private:
76
77 Tensor_t fDerivatives; ///< First fDerivatives of the activations of this layer. NOTE(review): not referenced by any code visible in this file; looks like a leftover from the dense layer.
78
79 int fNormAxis; ///< Normalization axis. For each element of this axis we will compute mean and stddev (-1: every c*h*w element, 1: c, 2: h, 3: w; see CalculateNormDim)
80
81 Scalar_t fMomentum; ///< Momentum of the running mean/variance update (not a weight decay): exponential average if in [0,1); presumably the cumulative average for negative values such as the default -1 (see class description).
83
87
90
91 // cached tensor used for Cudnn to get correct shape
92 Tensor_t fReshapedData; // cached reshaped data tensor; when the input layout differs from its layout, Forward/Backward wrap the data buffers with its shape and layout
93
94 // counter of batches processed in training mode since the last reset; passed to the architecture when updating the running mean and variance
96
98
99public:
100 /*! Constructor. axis selects the normalization axis (see CalculateNormDim), momentum the
 * running-average mode (see class description); epsilon is passed to the architecture's
 * normalization routines. */
101 TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
102 const std::vector<size_t> & shape, int axis = -1, Scalar_t momentum = -1., Scalar_t epsilon = 0.0001);
103
104 /*! Construct from a pointer to another batch normalization layer (NOTE(review): the implementation only builds the VGeneralLayer base and prints an error) */
106
107 /*! Copy Constructor (NOTE(review): not implemented, the implementation only prints an error message) */
109
110 /*! Destructor: releases the batch normalization descriptors */
112
113 /*! Compute the output of the layer for the given input.
114 * In training mode (inTraining == true) the input is normalized with the
115 * statistics of the current batch via the architecture's BatchNormLayerForwardTraining,
116 * which also receives the running mean/variance to update, and the trained-batch
117 * counter is incremented. Otherwise the stored running mean and variance are used
 * and the trained-batch counter is reset to zero. */
118 void Forward(Tensor_t &input, bool inTraining = true);
119
120 /*! Compute the gradient with respect to the layer input (written to
121 * gradients_backward) and, presumably, the gamma/beta gradients, using gamma
122 * and the batch statistics cached by the preceding training-mode Forward call.
123 * Must only be called directly after the corresponding call to Forward(...). */
124 void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
125 // Tensor_t &inp1, Tensor_t &inp2);
126
127
128 /* reset the counter of trained batches at the end of training */
130
131 /*! Print a summary of the layer (shape, number of normalized features, axis) to std::cout. */
132 void Print() const;
133
134 /*! Writes the running mean/variance and the weights (gamma, beta) of the layer in an XML node. */
136
137 /*! Read the information and the weights about the layer from XML node. */
139
140 /* initialize gamma = 1, beta = 0, running mean = 0, running variance = 1, the weight
 * gradients, the trained-batch counter and the batch normalization descriptors */
141 virtual void Initialize();
142
143 /* get number of batches processed in training mode since the last reset */
144 const int & GetNTrainedBatches() const { return fTrainedBatches;}
146
147 /* get the per-feature means of the current training batch (presumably filled by BatchNormLayerForwardTraining) */
148 const Matrix_t & GetBatchMean() const { return fMu;}
149 Matrix_t & GetBatchMean() { return fMu;}
150
151 /* Get the normalized batch examples (disabled) */
152 //const Matrix_t & GetNormedBatch() const { return fXhat;}
153 //Matrix_t & GetNormedBatch() { return fXhat;}
154
155 /* Get the per-feature variances of the current training batch (presumably filled by BatchNormLayerForwardTraining) */
156 const Matrix_t & GetVariance() const { return fVar;}
157 Matrix_t & GetVariance() { return fVar;}
158
159 /* Get the auxiliary per-feature variance term fIVar of the current training batch
 * (NOTE(review): the name suggests the inverse standard deviation rather than the
 * sqrt of the variance - confirm in the architecture implementation) */
160 const Matrix_t & GetIVariance() const { return fIVar;}
162
163 /* get vector of running averages computed in the training phase (used for inference) */
164 const Matrix_t & GetMuVector() const { return fMu_Training;}
166
167 /* get vector of running variances computed in the training phase (used for inference) */
168 const Matrix_t & GetVarVector() const { return fVar_Training;}
170
171 // Scalar_t GetWeightDecay() const { return fWeightDecay; }
172
173 /* Get the momentum of the running mean/variance (see class description) */
174 Scalar_t GetMomentum() const { return fMomentum;}
175
176 /* Get epsilon, passed to the architecture's normalization routines */
177 Scalar_t GetEpsilon() const { return fEpsilon;}
178
179 /* Get normalization axis (the one which will have each element normalized) */
180 Scalar_t GetNormAxis() const { return fNormAxis;}
181
182 const Matrix_t &GetReshapedData() const { return fReshapedData; }
184
185 std::vector<Matrix_t> GetExtraLayerParameters() const {
186 std::vector<Matrix_t> params(2);
187 params[0] = this->GetMuVector();
188 params[1] = this->GetVarVector();
189 return params;
190 }
191
192 void SetExtraLayerParameters(const std::vector<Matrix_t> & params)
193 {
194 this->GetMuVector() = params[0];
195 this->GetVarVector() = params[1];
196 }
197
198protected:
/* Number of normalized features, i.e. the number of columns of gamma/beta and of
 * the mean/variance matrices, for a given normalization axis:
 *   -1 -> every (c, h, w) element, 1 -> c, 2 -> h, 3 -> w.
 * Any other axis yields 0. */
static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
{
   switch (axis) {
   case -1: return c * h * w;
   case 1: return c;
   case 2: return h;
   case 3: return w;
   default: return 0;
   }
}
211};
212
213
214//
215//
216// The Dense Layer Class - Implementation
217//______________________________________________________________________________
// Constructor.
// The VGeneralLayer base is set up with identical input and output shapes, two weight
// matrices (gamma at index 0 and beta at index 1, see Initialize) with one row and
// CalculateNormDim(axis, ...) columns each, and a 1x1x1 bias. The batch statistics
// (fMu, fVar, fIVar) and running statistics (fMu_Training, fVar_Training) are allocated
// as 1 x (number of weight columns) matrices.
// NOTE(review): `shape` is indexed up to shape[2] without a size check; callers must
// pass at least three entries.
218template <typename Architecture_t>
219TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
220 size_t inputWidth, const std::vector<size_t> &shape, int axis,
221 Scalar_t momentum, Scalar_t epsilon)
222 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // bs + input shape
223 inputDepth, inputHeight, inputWidth, // output shape
224 2, 1, // two weight matrices (gamma, beta), one row each
225 CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // weight tensor dim.
226 1, 1, 1, // bias
227 shape[2], shape[0], shape[1], // output tensor shape as bsize, depth, hw
229 fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
230 fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()), // dimension is same as weights
231 fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
232 fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
233 fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
234 fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
235 fReshapedData(1,1,1) // use a dummy single element tensor
236
237{
238
239}
240//______________________________________________________________________________
// Construction from a pointer to another layer (presumably the pointer-copy constructor
// declared in the class): only the VGeneralLayer base is built from `layer`; none of the
// batch-norm specific state is copied, and an error message is printed instead.
// NOTE(review): printf needs <cstdio>, which is not among the includes visible here
// (it may only be available transitively).
241template <typename Architecture_t>
243 : VGeneralLayer<Architecture_t>(layer)
244{
245 // to be implemented
246 printf("Error - copy ctor not implemented\n");
247}
248
249//______________________________________________________________________________
// Copy constructor: not implemented. The body copies none of the batch-norm specific
// members and only prints an error message.
250template <typename Architecture_t>
252{
253 // to be implemented
254 printf("Error - copy ctor not implemented\n");
255}
256
257//______________________________________________________________________________
// Destructor: if descriptors were created (Initialize calls InitializeBNormDescriptors),
// let the architecture release them and then delete the descriptor object itself.
// NOTE(review): the copy constructors do not set fDescriptors; this relies on its
// declaration providing a null default - confirm.
258template <typename Architecture_t>
260{
261 // release descriptors
262 if (fDescriptors) {
263 Architecture_t::ReleaseBNormDescriptors(fDescriptors);
264 delete fDescriptors;
265 }
266}
267
// Initialize the layer:
//  - gamma (weights 0) = 1 and beta (weights 1) = 0 for every normalized feature,
//  - running mean = 0 and running variance = 1,
//  - the gamma/beta gradients (dgamma, dbeta) are zeroed,
//  - the trained-batch counter is reset,
//  - the architecture-specific batch-norm descriptors are set up.
// NOTE(review): if Initialize() can be called more than once, check whether
// InitializeBNormDescriptors releases a previously created fDescriptors.
268template <typename Architecture_t>
270{
271 Matrix_t &gamma = this->GetWeightsAt(0);
272 Matrix_t &beta = this->GetWeightsAt(1);
273 size_t bndim = gamma.GetNcols();
274
275 initialize<Architecture_t>(beta, EInitialization::kZero);
276 for (size_t i = 0; i < bndim; ++i) {
277 gamma(0, i) = 1.;
278 // running statistics start as mean 0 / variance 1 until training updates them
279 fMu_Training(0,i) = 0;
280 fVar_Training(0,i) = 1;
281 }
282
285 initialize<Architecture_t>(dgamma, EInitialization::kZero);
286 initialize<Architecture_t>(dbeta, EInitialization::kZero);
287
288 fTrainedBatches = 0;
289
290 Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
291}
292
293//______________________________________________________________________________
// Forward pass.
// If the input layout differs from the layout of the cached fReshapedData tensor, the
// input and output device buffers are wrapped in tensors with fReshapedData's shape and
// layout (presumably views without copying - NOTE(review): confirm in Tensor_t);
// otherwise the input and the layer output are used as they are.
// Training mode: normalize with the current batch statistics via
// BatchNormLayerForwardTraining, which also receives the running mean/variance, the
// trained-batch counter and the momentum; then increment the counter.
// Inference mode: normalize with the stored running mean/variance and reset the counter.
294template <typename Architecture_t>
296{
297 Tensor_t x2;
298 Tensor_t y2;
299 if (x.GetLayout() != fReshapedData.GetLayout()) {
300 x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
301 y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
302 }
303 else{
304 x2 = x;
305 y2 = this->GetOutput();
306 }
307
308 auto descr = static_cast<BNormDescriptors_t *> (fDescriptors);
309 if (inTraining) {
310 Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
311 this->GetWeightsAt(0), this->GetWeightsAt(1),
312 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
313 this->GetMuVector(),
314 this->GetVarVector(), this->GetNTrainedBatches(),
315 this->GetMomentum(), this->GetEpsilon(),
316 descr->HelperDescriptor);
317 fTrainedBatches++;
318 }
319
320 else {
321 // if (fTrainedBatches > 0) {
322 // Architecture_t::PrintTensor(Tensor_t(this->GetWeightsAt(0)), "bnorm gamma");
323 // Architecture_t::PrintTensor(Tensor_t(this->GetWeightsAt(1)), "bnorm beta");
324 // Architecture_t::PrintTensor(Tensor_t(this->GetMuVector()), "bnorm mu");
325 // Architecture_t::PrintTensor(Tensor_t(this->GetVarVector()), "bnorm var");
326 // }
327 Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
328 y2, this->GetMuVector(), this->GetVarVector(),
329 this->GetEpsilon(), descr->HelperDescriptor);
330 fTrainedBatches = 0; // an inference pass restarts the trained-batch count used for the running average
331 }
332
333}
334
335//______________________________________________________________________________
// Backward pass.
// Computes the gradient with respect to the layer input into `gradients_backward` (dx),
// given the layer input `activations_backward` (x) and the gradient with respect to the
// layer output, GetActivationGradients() (dy). It uses gamma and the batch statistics
// (mean, variance, fIVar) cached by the last training-mode Forward call. The gamma/beta
// gradients are presumably passed as further arguments - NOTE(review): confirm.
// As in Forward, when the input layout differs from fReshapedData's layout, the buffers
// are wrapped with fReshapedData's shape and layout before calling the architecture.
336template <typename Architecture_t>
338 const Tensor_t & activations_backward ) -> void
339// Tensor_t &, Tensor_t &) -> void
340{
341 auto descr = static_cast<BNormDescriptors_t *> (fDescriptors);
342
343
344 if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
345 Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
346 Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
347 Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
348
349 Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
350 this->GetWeightsAt(0), // gamma (beta is not needed)
352 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
353 this->GetEpsilon(), descr->HelperDescriptor);
354
355 } else {
356
357 Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward, // x
360 this->GetWeightsAt(0), // gamma (beta is not needed)
362 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
363 this->GetEpsilon(), descr->HelperDescriptor);
364 }
365}
366
367//______________________________________________________________________________
// Print a summary of the layer to std::cout: the output tensor shape (the constructor
// uses identical input and output shapes), the number of normalized features (columns
// of gamma) and the normalization axis, followed by an empty line.
368template <typename Architecture_t>
370{
371 std::cout << " BATCH NORM Layer: \t";
372 std::cout << " Input/Output = ( " ;
373 auto &shape = this->GetOutput().GetShape();
374 for (size_t i = 0; i < shape.size(); ++i) {
375 if (i > 0) std::cout << " , ";
376 std::cout << shape[i];
377 }
378 std::cout << " ) ";
379 std::cout << "\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
380 std::cout << "\t axis = " << fNormAxis << std::endl;
381 std::cout << std::endl;
382}
383
384//______________________________________________________________________________
385
// Serialize the layer into a new "BatchNormLayer" child of `parent`: the running
// statistics ("Training-mu", "Training-variance") followed by the trainable
// weights ("Gamma", "Beta").
386template <typename Architecture_t>
388{
389
390 // create the BatchNormLayer node; the running statistics and gamma/beta are written into it below
391
392 auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "BatchNormLayer");
393
394
397
398 // write stored running mean and variances (fMu_Training / fVar_Training, not the last batch statistics)
399 //using Scalar_t = typename Architecture_t::Scalar_t;
400
401 this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
402 this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());
403
404 // write weights (gamma and beta)
405 this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
406 this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));
407
408}
409
410//______________________________________________________________________________
// XML reader counterpart of the writer above ("Read the information and the weights
// about the layer from XML node" per the class declaration).
// NOTE(review): the statements of this body are missing from this listing; presumably
// it reads back the "Training-mu", "Training-variance", "Gamma" and "Beta" matrices
// written above - confirm.
411template <typename Architecture_t>
413{
414 // momentum and epsilon can be added after constructing the class
417 // Read layer weights and biases from XML
418
421
424}
425
426} // namespace DNN
427} // namespace TMVA
428
429#endif
#define c(i)
Definition RSha256.hxx:101
#define h(i)
Definition RSha256.hxx:106
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char x2
Option_t Option_t TPoint TPoint const char y2
Layer implementing Batch Normalization.
static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
const Matrix_t & GetMuVector() const
int fNormAxis
Normalization axis. For each element of this axis we will compute mean and stddev.
typename Architecture_t::Matrix_t Matrix_t
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
void SetExtraLayerParameters(const std::vector< Matrix_t > &params)
void ResetTraining()
Reset some training flags after a loop on all batches Some layer (e.g.
std::vector< Matrix_t > GetExtraLayerParameters() const
typename Architecture_t::Tensor_t Tensor_t
Writes the information and the weights about the layer in an XML node.
Scalar_t fMomentum
The weight decay.
const Matrix_t & GetVariance() const
const Matrix_t & GetReshapedData() const
void Print() const
Printing the layer info.
Read the information and the weights about the layer from XML node.
const int & GetNTrainedBatches() const
const Matrix_t & GetIVariance() const
typename Architecture_t::Scalar_t Scalar_t
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Compute weight, bias and activation gradients.
void Forward(Tensor_t &input, bool inTraining=true)
Compute activation of the layer for the given input.
typename Architecture_t::TensorDescriptor_t HelperDescriptor_t
typename Architecture_t::BNormDescriptors_t BNormDescriptors_t
const Matrix_t & GetVarVector() const
Tensor_t fDerivatives
First fDerivatives of the activations of this layer.
const Matrix_t & GetBatchMean() const
TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth, const std::vector< size_t > &shape, int axis=-1, Scalar_t momentum=-1., Scalar_t epsilon=0.0001)
Constructor.
Generic General Layer class.
TXMLEngine & xmlengine()
Definition Tools.h:262
void ReadAttr(void *node, const char *, T &value)