4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : TBatchNormLayer *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Batch Normalization Layer Class *
12 * *
13 * Authors (alphabetical): *
15 * *
16 * Copyright (c) 2005-2015: *
17 * CERN, Switzerland *
18 * U. of Victoria, Canada *
19 * MPI-K Heidelberg, Germany *
20 * U. of Bonn, Germany *
21 * *
22 * Redistribution and use in source and binary forms, with or without *
23 * modification, are permitted according to the terms listed in LICENSE *
25 **********************************************************************************/
26
27#ifndef TMVA_DNN_BatchNormLayer
28#define TMVA_DNN_BatchNormLayer
29
31#include "TMVA/DNN/Functions.h"
32
34
36
37#include <iostream>
38#include <iomanip>
39
40namespace TMVA {
41namespace DNN {
42
43/** \class TBatchNormLayer
44
45 Layer implementing Batch Normalization
46
47 The inputs from each batch are normalized during training to have zero mean and unit variance
48 and they are then scaled by two parameters, different for each input variable:
49 - a scale factor gamma
50 - an offset beta
51
52 In addition a running batch mean and variance are computed and stored in the class.
53 During inference the inputs are not normalized using the batch mean and variance but using the
54 previously computed running mean and variance.
55 If momentum is in [0,1) the running mean and variance are the exponential averages using the momentum value
56 running_mean = momentum * running_mean + (1-momentum) * batch_mean
57 If instead momentum<0 (the constructor default is -1) the cumulative average is computed
58 running_mean = (nb/(nb+1)) * running_mean + (1/(nb+1)) * batch_mean
59
60 See more at [https://arxiv.org/pdf/1502.03167v3.pdf]
61*/
62template <typename Architecture_t>
63class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
64public:
65
   // Scalar, matrix and tensor types are taken from the architecture backend.
66 using Scalar_t = typename Architecture_t::Scalar_t;
67 using Matrix_t = typename Architecture_t::Matrix_t;
68 using Tensor_t = typename Architecture_t::Tensor_t;
69
   // Descriptor types of the backend; Forward/Backward pass descr->HelperDescriptor to the backend calls.
70 using HelperDescriptor_t = typename Architecture_t::TensorDescriptor_t;
71 using BNormDescriptors_t = typename Architecture_t::BNormDescriptors_t;
72
73
74private:
75
76 Tensor_t fDerivatives; ///< First derivatives of the activations of this layer.
77
78 int fNormAxis; ///< Normalization axis. For each element of this axis we will compute mean and stddev
79
80 Scalar_t fMomentum; ///< Momentum used for the running mean/variance (see GetMomentum).
82
86
89
90 // cached tensor used for Cudnn to get correct shape
91 Tensor_t fReshapedData; // cached reshaped data tensor; Forward/Backward use its shape and layout
92
93 // counter of trained batches, used when computing the mean and variance applied at testing time
95
97
98public:
99 /*! Constructor. See CalculateNormDim for the axis values that are handled. */
100 TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
101 const std::vector<size_t> & shape, int axis = -1, Scalar_t momentum = -1., Scalar_t epsilon = 0.0001);
102
103 /*! Copy the batch normalization layer provided as a pointer */
105
106 /*! Copy Constructor */
108
109 /*! Destructor */
111
112 /*! Forward pass of the layer for the given input. When inTraining is true the
113 * training-mode batch normalization of the architecture backend is called and the
114 * counter of trained batches is incremented; otherwise the inference-mode
115 * normalization is called with the stored training mean and variance
116 * (GetMuVector / GetVarVector) and the counter of trained batches is reset to zero. */
117 void Forward(Tensor_t &input, bool inTraining = true);
118
119 /*! Backward pass. Hands the layer input (activations_backward), gamma and the
120 * batch mean / variance members (GetBatchMean, GetVariance, GetIVariance) to the
121 * backend BatchNormLayerBackward. Must only be called directly
122 * after the corresponding call to Forward(...). */
123 void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
124 // Tensor_t &inp1, Tensor_t &inp2);
125
126
127 /* reset the batch counter at the end of training */
129
130 /*! Printing the layer info. */
131 void Print() const;
132
133 /*! Writes the information and the weights about the layer in an XML node. */
135
136 /*! Read the information and the weights about the layer from XML node. */
138
139 /* initialize gamma to 1, beta to 0, the stored training mean / variance to 0 / 1 and reset the batch counter */
140 virtual void Initialize();
141
142 /* get number of trained batches (returned as a reference to the internal counter) */
143 const int & GetNTrainedBatches() const { return fTrainedBatches;}
145
146 /* get batch means for the training phase */
147 const Matrix_t & GetBatchMean() const { return fMu;}
148 Matrix_t & GetBatchMean() { return fMu;}
149
150 /* Get the normalized batch examples */
151 //const Matrix_t & GetNormedBatch() const { return fXhat;}
152 //Matrix_t & GetNormedBatch() { return fXhat;}
153
154 /* Get the batch variances (fVar) */
155 const Matrix_t & GetVariance() const { return fVar;}
156 Matrix_t & GetVariance() { return fVar;}
157
158 /* Get fIVar, which is handed to the backend together with the batch mean and variance.
   NOTE(review): the name suggests the inverse of the batch standard deviation rather than
   "the sqrt of the batch variances" as previously documented here - confirm against the backend. */
159 const Matrix_t & GetIVariance() const { return fIVar;}
161
162 /* get vector of averages computed in the training phase */
163 const Matrix_t & GetMuVector() const { return fMu_Training;}
165
166 /* get vector of variances computed in the training phase */
167 const Matrix_t & GetVarVector() const { return fVar_Training;}
169
170 // Scalar_t GetWeightDecay() const { return fWeightDecay; }
171
172 /* Get the momentum of the running mean/variance */
173 Scalar_t GetMomentum() const { return fMomentum;}
174
175 /* Get epsilon */
176 Scalar_t GetEpsilon() const { return fEpsilon;}
177
178 /* Get normalization axis (the one which will have each element normalized).
   NOTE(review): fNormAxis is an int but is returned converted to Scalar_t - confirm this is intended. */
179 Scalar_t GetNormAxis() const { return fNormAxis;}
180
   /* Get the cached reshaped data tensor.
      NOTE(review): fReshapedData is declared as Tensor_t while this accessor returns a Matrix_t reference - confirm. */
181 const Matrix_t &GetReshapedData() const { return fReshapedData; }
183
   /* Return the stored training mean and variance as a two-element vector:
      element 0 is GetMuVector(), element 1 is GetVarVector(). */
184 std::vector<Matrix_t> GetExtraLayerParameters() const {
185 std::vector<Matrix_t> params(2);
186 params[0] = this->GetMuVector();
187 params[1] = this->GetVarVector();
188 return params;
189 }
190
   /* Set the stored training mean (params[0]) and variance (params[1]).
      params is indexed without a size check, so it must hold at least two matrices. */
191 void SetExtraLayerParameters(const std::vector<Matrix_t> & params)
192 {
193 this->GetMuVector() = params[0];
194 this->GetVarVector() = params[1];
195 }
196
197protected:
   /* Number of normalized elements for the given axis and input depth (c), height (h) and width (w):
      axis -1 gives c * h * w, axis 1 gives c, axis 2 gives h, axis 3 gives w.
      Any other axis value gives 0. */
198 static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
199 {
200 if (axis == -1)
201 return c * h * w;
202 else if (axis == 1)
203 return c;
204 else if (axis == 2)
205 return h;
206 else if (axis == 3)
207 return w;
208 return 0; // unsupported axis value
209 }
210};
211
212
213//
214//
215// The Batch Normalization Layer Class - Implementation
216//______________________________________________________________________________
/// Constructor: forwards the batch size and the input shape to VGeneralLayer, using the same
/// depth / height / width for the output shape. The last weight dimension is
/// CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), which is 0 for an axis
/// other than -1, 1, 2 or 3.
/// \param shape     read as shape[0], shape[1] and shape[2] without a size check, so it must
///                  hold at least three entries
/// \param axis      normalization axis, stored in fNormAxis
/// \param momentum  stored in fMomentum
/// \param epsilon   stored in fEpsilon
217template <typename Architecture_t>
218TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
219 size_t inputWidth, const std::vector<size_t> &shape, int axis,
220 Scalar_t momentum, Scalar_t epsilon)
221 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // bs + input shape
222 inputDepth, inputHeight, inputWidth, // output shape
223 2, 1, // presumably two weight matrices (gamma and beta, see Initialize) with one row each - TODO confirm against VGeneralLayer
224 CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // weight tensor dim.
225 1, 1, 1, // bias
226 shape[2], shape[0], shape[1], // output tensor shape as bsize, depth, hw
228 fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
229 fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()), // dimension is same as weights
230 fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()), // the remaining matrices below use the same 1 x ncols size
231 fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
232 fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
233 fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
234 fReshapedData(1,1,1) // use a dummy single element tensor
235
236{
237
238}
239//______________________________________________________________________________
// Copying constructor (one of the two declared in the class): not implemented.
// Only the VGeneralLayer base is constructed from `layer`; no batch normalization
// specific member is copied here and an error message is printed instead.
240template <typename Architecture_t>
242 : VGeneralLayer<Architecture_t>(layer)
243{
244 // to be implemented: copying of the batch normalization specific members
245 printf("Error - copy ctor not implmented\n");
246}
247
248//______________________________________________________________________________
// Copying constructor (the other one declared in the class): not implemented.
// The body copies nothing and only prints an error message.
249template <typename Architecture_t>
251{
252 // to be implemented
253 printf("Error - copy ctor not implmented\n");
254}
255
256//______________________________________________________________________________
// Destructor body: if fDescriptors is set, lets the architecture backend release the
// batch normalization descriptors and then deletes the descriptor object.
257template <typename Architecture_t>
259{
260 // release descriptors
261 if (fDescriptors) {
262 Architecture_t::ReleaseBNormDescriptors(fDescriptors);
263 delete fDescriptors;
264 }
265}
266
// Initialize(): sets the layer parameters to their starting values.
//  - beta  (weights at index 1) is set to zero
//  - gamma (weights at index 0) is set to one
//  - the stored training mean / variance are set to 0 / 1
//  - dgamma and dbeta are set to zero
//  - the counter of trained batches is reset
// Finally the backend initializes the batch normalization descriptors for this layer.
267template <typename Architecture_t>
269{
270 Matrix_t &gamma = this->GetWeightsAt(0);
271 Matrix_t &beta = this->GetWeightsAt(1);
272 size_t bndim = gamma.GetNcols(); // number of normalized elements (columns of gamma)
273
274 initialize<Architecture_t>(beta, EInitialization::kZero);
275 for (size_t i = 0; i < bndim; ++i) {
276 gamma(0, i) = 1.;
277 // assign default values for the other parameters
278 fMu_Training(0,i) = 0;
279 fVar_Training(0,i) = 1;
280 }
281
   // NOTE(review): dgamma and dbeta are not declared in the lines of this body shown here;
   // presumably they refer to the weight gradients of gamma and beta - confirm.
284 initialize<Architecture_t>(dgamma, EInitialization::kZero);
285 initialize<Architecture_t>(dbeta, EInitialization::kZero);
286
287 fTrainedBatches = 0;
288
289 Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
290}
291
292//______________________________________________________________________________
// Forward pass: normalizes the input x into the layer output.
// In training mode the backend training routine is called with gamma, beta, the batch
// mean / variance members, the stored training mean / variance, the number of trained
// batches, momentum and epsilon, and the batch counter is then incremented.
// Otherwise the backend inference routine is called with gamma, beta and the stored
// training mean / variance, and the batch counter is reset to zero.
293template <typename Architecture_t>
295{
296 Tensor_t x2;
297 Tensor_t y2;
   // When the input layout differs from the layout of fReshapedData, build the working
   // tensors from the input / output device buffers with the shape and layout of
   // fReshapedData; otherwise use the input and the layer output directly.
298 if (x.GetLayout() != fReshapedData.GetLayout()) {
299 x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
300 y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
301 }
302 else{
303 x2 = x;
304 y2 = this->GetOutput();
305 }
306
307 auto descr = static_cast<BNormDescriptors_t *> (fDescriptors);
308 if (inTraining) {
309 Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
310 this->GetWeightsAt(0), this->GetWeightsAt(1),
311 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
312 this->GetMuVector(),
313 this->GetVarVector(), this->GetNTrainedBatches(),
314 this->GetMomentum(), this->GetEpsilon(),
315 descr->HelperDescriptor);
316 fTrainedBatches++; // one more batch has been used for training
317 }
318
319 else {
320 // if (fTrainedBatches > 0) {
321 // Architecture_t::PrintTensor(Tensor_t(this->GetWeightsAt(0)), "bnorm gamma");
322 // Architecture_t::PrintTensor(Tensor_t(this->GetWeightsAt(1)), "bnorm beta");
323 // Architecture_t::PrintTensor(Tensor_t(this->GetMuVector()), "bnorm mu");
324 // Architecture_t::PrintTensor(Tensor_t(this->GetVarVector()), "bnorm var");
325 // }
326 Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
327 y2, this->GetMuVector(), this->GetVarVector(),
328 this->GetEpsilon(), descr->HelperDescriptor);
329 fTrainedBatches = 0; // an inference call resets the counter of trained batches
330 }
331
332}
333
334//______________________________________________________________________________
// Backward pass: calls the backend BatchNormLayerBackward with the layer input
// (activations_backward), gamma, the batch mean / variance members and epsilon.
// When the layout of activations_backward differs from the layout of fReshapedData, the
// input, the backward gradients and the activation gradients of this layer are first
// wrapped into tensors with the shape and layout of fReshapedData.
335template <typename Architecture_t>
337 const Tensor_t & activations_backward ) -> void
338// Tensor_t &, Tensor_t &) -> void
339{
340 auto descr = static_cast<BNormDescriptors_t *> (fDescriptors);
341
342
343 if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
   // x: layer input, dx: gradients passed to the previous layer, dy: activation gradients of this layer
344 Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
345 Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
346 Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
347
348 Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
349 this->GetWeightsAt(0), // gamma (beta is not needed)
351 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
352 this->GetEpsilon(), descr->HelperDescriptor);
353
354 } else {
355
   // Same layout: the tensors are handed to the backend without re-wrapping.
356 Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward, // x
359 this->GetWeightsAt(0), // gamma (beta is not needed)
361 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
362 this->GetEpsilon(), descr->HelperDescriptor);
363 }
364}
365
366//______________________________________________________________________________
// Print(): writes a one-line summary of the layer to std::cout - the output tensor shape
// (labelled "Input/Output"), the normalization dimension (number of columns of gamma)
// and the normalization axis - followed by an empty line.
367template <typename Architecture_t>
369{
370 std::cout << " BATCH NORM Layer: \t";
371 std::cout << " Input/Output = ( " ;
372 auto &shape = this->GetOutput().GetShape();
   // shape entries are printed separated by " , "
373 for (size_t i = 0; i < shape.size(); ++i) {
374 if (i > 0) std::cout << " , ";
375 std::cout << shape[i];
376 }
377 std::cout << " ) ";
378 std::cout << "\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
379 std::cout << "\t axis = " << fNormAxis << std::endl;
380 std::cout << std::endl; // trailing empty line
381}
382
383//______________________________________________________________________________
384
// Writes the layer to XML: creates a "BatchNormLayer" child node under `parent` and stores
// in it the training mean and variance ("Training-mu", "Training-variance") followed by
// the weights ("Gamma", "Beta").
385template <typename Architecture_t>
387{
388
389 // write the layer parameters and the weight matrices (gamma and beta)
390
391 auto layerxml = gTools().xmlengine().NewChild(parent, 0, "BatchNormLayer");
392
393
396
397 // write stored mean and variances
398 //using Scalar_t = typename Architecture_t::Scalar_t;
399
400 this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
401 this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());
402
403 // write weights (gamma and beta)
404 this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
405 this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));
406
407}
408
409//______________________________________________________________________________
// Reads the layer back from XML (counterpart of the XML writer, see the class declaration).
// NOTE(review): no statements of this body are present in the lines shown here; only the
// comments below remain, so the read order cannot be confirmed from this listing.
410template <typename Architecture_t>
412{
413 // momentum and epsilon can be added after constructing the class
416 // Read layer weights and biases from XML
417
420
423}
424
425} // namespace DNN
426} // namespace TMVA
427
428#endif
