#ifndef TMVA_DNN_BatchNormLayer
#define TMVA_DNN_BatchNormLayer
template <typename Architecture_t>
class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
public:
   using Scalar_t = typename Architecture_t::Scalar_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Tensor_t = typename Architecture_t::Tensor_t;
   using HelperDescriptor_t = typename Architecture_t::TensorDescriptor_t;
   using BNormDescriptors_t = typename Architecture_t::BNormDescriptors_t;
   /*! Constructor */
   TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                   const std::vector<size_t> &shape, int axis = -1,
                   Scalar_t momentum = -1., Scalar_t epsilon = 0.0001);

   // ...

   /* Return the extra layer parameters: the running mean and variance
      accumulated during training. */
   std::vector<Matrix_t> GetExtraLayerParameters() const
   {
      std::vector<Matrix_t> params(2);
      params[0] = this->GetMuVector();
      params[1] = this->GetVarVector();
      return params;
   }

   // ... (remaining members: see the synopsis at the end of this file)
};
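// For reference, a minimal sketch of the CalculateNormDim helper used by the
// constructor below; the signature is part of the class, while the body shown
// here is an assumption about its behavior (the number of elements that get
// their own mean/stddev for the chosen axis):
//
//    static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
//    {
//       if (axis == -1) return c * h * w; // dense case: one statistic per input feature
//       if (axis == 1)  return c;         // CNN case: one statistic per channel
//       if (axis == 2)  return h;
//       if (axis == 3)  return w;
//       return 0;
//    }
//
// E.g. for an input with depth c = 16 and h = w = 8, axis = -1 gives a norm
// dimension of 1024, while axis = 1 gives 16.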
//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
                                                 size_t inputWidth, const std::vector<size_t> &shape, int axis,
                                                 Scalar_t momentum, Scalar_t epsilon)
   : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // input shape
                                   inputDepth, inputHeight, inputWidth,            // output shape
                                   2, 1,
                                   CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // weight dimensions
                                   1, 1, 1,                                                     // bias dimensions
                                   shape[2], shape[0], shape[1], // output tensor shape as (bsize, depth, h*w)
                                   EInitialization::kZero),
     fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
     fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()), // same dimension as the weights
     fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fReshapedData(1, 1, 1) // dummy single-element tensor
{
}
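// Note on the parameter layout set up above: the layer owns two weight
// matrices, GetWeightsAt(0) = gamma (scale) and GetWeightsAt(1) = beta
// (shift), each of size 1 x normdim, and the running statistics
// fMu_Training / fVar_Training are allocated with the same number of columns.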
//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer)
   : VGeneralLayer<Architecture_t>(layer)
{
   // to be implemented
   printf("Error - copy ctor not implemented\n");
}
//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(const TBatchNormLayer &layer)
   : VGeneralLayer<Architecture_t>(layer)
{
   // to be implemented
   printf("Error - copy ctor not implemented\n");
}
//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::~TBatchNormLayer()
{
   // release the architecture-specific descriptors
   if (fDescriptors) {
      Architecture_t::ReleaseBNormDescriptors(fDescriptors);
      delete fDescriptors;
   }
}
//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Initialize()
{
   Matrix_t &gamma = this->GetWeightsAt(0);
   Matrix_t &beta = this->GetWeightsAt(1);
   size_t bndim = gamma.GetNcols();
   initialize<Architecture_t>(beta, EInitialization::kZero);
   for (size_t i = 0; i < bndim; ++i) {
      gamma(0, i) = 1.;
      // start the running statistics from zero mean and unit variance
      fMu_Training(0, i) = 0;
      fVar_Training(0, i) = 1;
   }
   Matrix_t &dgamma = this->GetWeightGradientsAt(0);
   Matrix_t &dbeta = this->GetWeightGradientsAt(1);
   initialize<Architecture_t>(dgamma, EInitialization::kZero);
   initialize<Architecture_t>(dbeta, EInitialization::kZero);
   fTrainedBatches = 0;

   Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
}
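// With gamma = 1, beta = 0 and running statistics mu = 0, var = 1, the layer
// initially applies (up to epsilon) the identity map:
//
//    y = 1 * (x - 0) / sqrt(1 + eps) + 0 ≈ x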
//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Forward(Tensor_t &x, bool inTraining)
{
   Tensor_t x2;
   Tensor_t y2;
   // if the input layout differs from the expected one, reinterpret the
   // input and output buffers with the reshaped-data shape and layout
   if (x.GetLayout() != fReshapedData.GetLayout()) {
      x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
   } else {
      x2 = x;
      y2 = this->GetOutput();
   }

   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);
   if (inTraining) {
      Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
                                                    this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                    this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                                    this->GetMuVector(),
                                                    this->GetVarVector(), this->GetNTrainedBatches(),
                                                    this->GetMomentum(), this->GetEpsilon(),
                                                    descr->HelperDescriptor);
      fTrainedBatches++;
   } else {
      Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                     y2, this->GetMuVector(), this->GetVarVector(),
                                                     this->GetEpsilon(), descr->HelperDescriptor);
      fTrainedBatches = 0;
   }
}
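// For reference, a sketch (not taken from this header) of what the
// training-mode forward pass computes per normalized element, for a batch
// of m samples:
//
//    mu_B  = (1/m) * sum_i x_i                batch mean
//    var_B = (1/m) * sum_i (x_i - mu_B)^2     batch variance
//    xhat  = (x - mu_B) / sqrt(var_B + eps)   normalized input
//    y     = gamma * xhat + beta              scale and shift
//
// The running statistics used in inference mode are then updated, assuming
// the conventional momentum update (with fMomentum < 0 a cumulative average
// over the trained batches is used instead):
//
//    mu_run  = momentum * mu_run  + (1 - momentum) * mu_B
//    var_run = momentum * var_run + (1 - momentum) * var_B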
//______________________________________________________________________________
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
                                               const Tensor_t &activations_backward) -> void
{
   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);
   if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
      // reinterpret the tensors with the reshaped-data shape and layout
      Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());

      Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
                                             this->GetWeightsAt(0), // gamma (beta is not needed)
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   } else {
      Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward,
                                             this->GetActivationGradients(), // dy
                                             gradients_backward,             // dx
                                             this->GetWeightsAt(0),          // gamma (beta is not needed)
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   }
}
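// For reference, the standard batch-normalization gradients that a
// BatchNormLayerBackward implementation computes (a textbook sketch, not
// taken from this header; sums run over the m samples of the batch):
//
//    dbeta  = sum_i dy_i
//    dgamma = sum_i dy_i * xhat_i
//    dx_i   = gamma / (m * sqrt(var_B + eps))
//             * (m * dy_i - sum_j dy_j - xhat_i * sum_j dy_j * xhat_j)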
//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Print() const
{
   std::cout << " BATCH NORM Layer: \t";
   std::cout << " Input/Output = ( ";
   auto &shape = this->GetOutput().GetShape();
   for (size_t i = 0; i < shape.size(); ++i) {
      if (i > 0)
         std::cout << " , ";
      std::cout << shape[i];
   }
   std::cout << " ) ";
   std::cout << "\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
   std::cout << "\t axis = " << fNormAxis << std::endl;
   std::cout << std::endl;
}
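// Illustrative output (values are hypothetical):
//
//    BATCH NORM Layer:    Input/Output = ( 32 , 1 , 64 )   Norm dim =    64   axis = -1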
//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "BatchNormLayer");

   // write the stored running mean and variance
   this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
   this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());

   // write the trainable weights (gamma and beta)
   this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));
}
//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   // read the stored running mean and variance
   this->ReadMatrixXML(parent, "Training-mu", this->GetMuVector());
   this->ReadMatrixXML(parent, "Training-variance", this->GetVarVector());
   // read the trainable weights (gamma and beta)
   this->ReadMatrixXML(parent, "Gamma", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "Beta", this->GetWeightsAt(1));
}
//______________________________________________________________________________
//
// Synopsis of the remaining TBatchNormLayer members (doc strings taken from
// the class reference):
//
//    Layer implementing Batch Normalization.
//
//    // Static helper
//    static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w);
//
//    // Data members
//    int           fNormAxis;    // Normalization axis. For each element of this
//                                // axis we will compute mean and stddev.
//    Scalar_t      fMomentum;    // Momentum for the moving average of the mean
//                                // and variance.
//    Tensor_t      fDerivatives; // First derivatives of the activations of this layer.
//    TDescriptors *fDescriptors;
//
//    // Accessors
//    int &GetNTrainedBatches();           const int &GetNTrainedBatches() const;
//    Matrix_t &GetBatchMean();            const Matrix_t &GetBatchMean() const;
//    const Matrix_t &GetVariance() const;
//    Matrix_t &GetIVariance();            const Matrix_t &GetIVariance() const;
//    const Matrix_t &GetMuVector() const;
//    Matrix_t &GetVarVector();            const Matrix_t &GetVarVector() const;
//    Matrix_t &GetReshapedData();         const Matrix_t &GetReshapedData() const;
//    Scalar_t GetMomentum() const;
//    Scalar_t GetEpsilon() const;
//    Scalar_t GetNormAxis() const;
//    std::vector<Matrix_t> GetExtraLayerParameters() const;
//    void SetExtraLayerParameters(const std::vector<Matrix_t> &params);
//
//    // Interface implemented from VGeneralLayer
//    TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
//                    size_t inputWidth, const std::vector<size_t> &shape,
//                    int axis = -1, Scalar_t momentum = -1.,
//                    Scalar_t epsilon = 0.0001);   // Constructor.
//    ~TBatchNormLayer();                           // Destructor.
//    virtual void Initialize();
//       // Initialize the weights and biases according to the given
//       // initialization method.
//    void ResetTraining();
//       // Reset some training flags after a loop over all batches; needed by
//       // layers such as batch normalization.
//    void Forward(Tensor_t &input, bool inTraining = true);
//       // Compute the activation of the layer for the given input.
//    void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
//       // Compute weight, bias and activation gradients.
//    void Print() const;
//       // Print the layer info.
//    virtual void AddWeightsXMLTo(void *parent);
//       // Write the information and the weights about the layer in an XML node.
//    virtual void ReadWeightsFromXML(void *parent);
//       // Read the information and the weights about the layer from an XML node.

#endif // TMVA_DNN_BatchNormLayer