#ifndef TMVA_DNN_BatchNormLayer
#define TMVA_DNN_BatchNormLayer
template <typename Architecture_t>
class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
public:
   using Scalar_t = typename Architecture_t::Scalar_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Tensor_t = typename Architecture_t::Tensor_t;
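   // What this layer computes (standard batch normalization, stated here as a summary;
   // the actual kernels live in the Architecture_t backend):
   //
   //    y = gamma * (x - mu) / sqrt(var + epsilon) + beta
   //
   // where mu and var are the mean and variance of x over the batch along the chosen
   // normalization axis, and gamma (weights[0]) and beta (weights[1]) are learned
   // per-feature scale and shift parameters.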
   TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                   const std::vector<size_t> &shape, int axis, Scalar_t momentum, Scalar_t epsilon);
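   // A minimal construction sketch (illustrative only: the TCpu<float> architecture and the
   // concrete sizes are assumptions, and the {depth, height*width, batchSize} ordering of
   // `shape` is inferred from how the constructor below forwards shape[2], shape[0], shape[1]):
   //
   //    using Arch_t = TMVA::DNN::TCpu<float>;
   //    std::vector<size_t> shape = {1, 32, 100};   // depth, height*width, batch size
   //    TBatchNormLayer<Arch_t> bnorm(/*batchSize=*/100, /*inputDepth=*/1, /*inputHeight=*/1,
   //                                  /*inputWidth=*/32, shape, /*axis=*/-1,
   //                                  /*momentum=*/-1., /*epsilon=*/1.e-4);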
      std::vector<Matrix_t> params(2);
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
                                                 size_t inputWidth, const std::vector<size_t> &shape, int axis,
                                                 Scalar_t momentum, Scalar_t epsilon)
   : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // batch size + input shape
                                   inputDepth, inputHeight, inputWidth,            // output shape
                                   2, 1,                                           // two weight rows: gamma and beta
                                   CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // weight columns
                                   1, 1, 1,                                        // bias (unused)
                                   shape[2], shape[0], shape[1],                   // output tensor shape as bsize, depth, h*w
                                   EInitialization::kZero),
     fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
     fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()), // same dimension as the weights
     fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer)
   : VGeneralLayer<Architecture_t>(layer)
{
   // copy construction is not implemented yet; only an error is reported
   printf("Error - copy ctor not implemented\n");
}

template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(const TBatchNormLayer &layer) : VGeneralLayer<Architecture_t>(layer)
{
   // copy construction is not implemented yet; only an error is reported
   printf("Error - copy ctor not implemented\n");
}
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::~TBatchNormLayer()
{
   // release the architecture-specific batch-normalization descriptors
   Architecture_t::ReleaseBNormDescriptors(fDescriptors);
}
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Initialize()
{
   Matrix_t &gamma = this->GetWeightsAt(0);
   Matrix_t &beta = this->GetWeightsAt(1);
   size_t bndim = gamma.GetNcols();

   // scale starts at 1, shift at 0; running mean/variance start at 0/1
   initialize<Architecture_t>(beta, EInitialization::kZero);
   for (size_t i = 0; i < bndim; ++i) {
      gamma(0, i) = 1.;
      fMu_Training(0, i) = 0;
      fVar_Training(0, i) = 1;
   }

   // zero the weight gradients
   Matrix_t &dgamma = this->GetWeightGradientsAt(0);
   Matrix_t &dbeta = this->GetWeightGradientsAt(1);
   initialize<Architecture_t>(dgamma, EInitialization::kZero);
   initialize<Architecture_t>(dbeta, EInitialization::kZero);

   // initialize the architecture-specific batch-normalization descriptors
   Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
}
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Forward(Tensor_t &x, bool inTraining) -> void
{
   Tensor_t x2;
   Tensor_t y2;
   // reinterpret the input/output buffers with the layout expected by the backend if needed
   if (x.GetLayout() != fReshapedData.GetLayout()) {
      x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
   } else {
      x2 = x;
      y2 = this->GetOutput();
   }

   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);
   if (inTraining) {
      // normalize with the statistics of the current batch and update the running mean/variance
      Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
                                                    this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                    this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                                    this->GetMuVector(),
                                                    this->GetVarVector(), this->GetNTrainedBatches(),
                                                    this->GetMomentum(), this->GetEpsilon(),
                                                    descr->HelperDescriptor);
      fTrainedBatches++;
   } else {
      // normalize with the stored running mean and variance
      Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                     y2, this->GetMuVector(), this->GetVarVector(),
                                                     this->GetEpsilon(), descr->HelperDescriptor);
      fTrainedBatches = 0;
   }
}
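// Semantics of the two branches above (descriptive only; the exact arithmetic is implemented
// by the Architecture_t backend, not in this header): during training the batch mean and
// variance are computed and used for the normalization, and the running statistics returned
// by GetMuVector() / GetVarVector() are updated, typically as
//
//    running = momentum * running + (1 - momentum) * batch_statistic
//
// (with a cumulative average when no momentum is set). During inference the stored running
// mean and variance are used instead of per-batch statistics.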
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
                                               const Tensor_t &activations_backward) -> void
{
   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);

   if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
      // reinterpret the tensors with the layout expected by the backend
      Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());

      Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
                                             this->GetWeightsAt(0), // gamma (beta is not needed)
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   } else {
      Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward,
                                             this->GetActivationGradients(),
                                             gradients_backward,
                                             this->GetWeightsAt(0), // gamma (beta is not needed)
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   }
}
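// For reference, the gradients that a BatchNormLayerBackward implementation is expected to
// produce (standard batch-normalization backward pass, written with
// xhat = (x - mu) / sqrt(var + epsilon) and batch averages <.>):
//
//    dbeta  = sum_i dy_i
//    dgamma = sum_i dy_i * xhat_i
//    dx_i   = gamma / sqrt(var + epsilon) * (dy_i - <dy> - xhat_i * <dy * xhat>)
//
// The actual computation is delegated to the Architecture_t backend.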
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Print() const
{
   std::cout << " BATCH NORM Layer: \t";
   std::cout << " Input/Output = ( ";
   auto &shape = this->GetOutput().GetShape();
   for (size_t i = 0; i < shape.size(); ++i) {
      if (i > 0) std::cout << " , ";
      std::cout << shape[i];
   }
   std::cout << " ) ";
   std::cout << "\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
   std::cout << "\t axis = " << fNormAxis << std::endl;
   std::cout << std::endl;
}
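// Example of the resulting output line (the concrete shape values are made up for
// illustration and the whitespace is approximate):
//
//    BATCH NORM Layer:    Input/Output = ( 100 , 1 , 32 )   Norm dim =    32   axis = -1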
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "BatchNormLayer");

   // write the running mean and variance needed to evaluate the layer at inference time
   this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
   this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());

   // write the learned scale (gamma) and shift (beta) parameters
   this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));
}
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   // read back the stored running statistics ...
   this->ReadMatrixXML(parent, "Training-mu", this->GetMuVector());
   this->ReadMatrixXML(parent, "Training-variance", this->GetVarVector());

   // ... and the learned gamma / beta weights
   this->ReadMatrixXML(parent, "Gamma", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "Beta", this->GetWeightsAt(1));
}

#endif // TMVA_DNN_BatchNormLayer