Logo ROOT   6.14/05
Reference Guide
RecurrentPropagation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Saurav Shekhar 23/06/17
3 
4 /*************************************************************************
5  * Copyright (C) 2017, Saurav Shekhar *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 /////////////////////////////////////////////////////////////////////
13 // Implementation of the functions required for the forward and //
14 // backward propagation of activations through a recurrent neural //
15 // network in the reference implementation. //
16 /////////////////////////////////////////////////////////////////////
17 
#include "TMVA/DNN/Architectures/Reference.h"

20 namespace TMVA {
21 namespace DNN {
22 
23 
24 //______________________________________________________________________________
25 template<typename Scalar_t>
27  TMatrixT<Scalar_t> & input_weight_gradients,
28  TMatrixT<Scalar_t> & state_weight_gradients,
29  TMatrixT<Scalar_t> & bias_gradients,
30  TMatrixT<Scalar_t> & df, //BxH
31  const TMatrixT<Scalar_t> & state, // BxH
32  const TMatrixT<Scalar_t> & weights_input, // HxD
33  const TMatrixT<Scalar_t> & weights_state, // HxH
34  const TMatrixT<Scalar_t> & input, // BxD
35  TMatrixT<Scalar_t> & input_gradient)
36 -> Matrix_t &
37 {
38 
39  // std::cout << "Reference Recurrent Propo" << std::endl;
40  // std::cout << "df\n";
41  // df.Print();
42  // std::cout << "state gradient\n";
43  // state_gradients_backward.Print();
44  // std::cout << "inputw gradient\n";
45  // input_weight_gradients.Print();
46  // std::cout << "state\n";
47  // state.Print();
48  // std::cout << "input\n";
49  // input.Print();
50 
51  // Compute element-wise product.
52  for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
53  for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
54  df(i,j) *= state_gradients_backward(i,j); // B x H
55  }
56  }
57 
58  // Input gradients.
59  if (input_gradient.GetNoElements() > 0) {
60  input_gradient.Mult(df, weights_input); // B x H . H x D = B x D
61  }
62  // State gradients
63  if (state_gradients_backward.GetNoElements() > 0) {
64  state_gradients_backward.Mult(df, weights_state); // B x H . H x H = B x H
65  }
66 
67  // Weights gradients.
68  if (input_weight_gradients.GetNoElements() > 0) {
69  TMatrixT<Scalar_t> tmp(input_weight_gradients);
70  input_weight_gradients.TMult(df, input); // H x B . B x D
71  input_weight_gradients += tmp;
72  }
73  if (state_weight_gradients.GetNoElements() > 0) {
74  TMatrixT<Scalar_t> tmp(state_weight_gradients);
75  state_weight_gradients.TMult(df, state); // H x B . B x H
76  state_weight_gradients += tmp;
77  }
78 
79  // Bias gradients. B x H -> H x 1
80  if (bias_gradients.GetNoElements() > 0) {
81  // this loops on state size
82  for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
83  Scalar_t sum = 0.0;
84  // this loops on batch size summing all gradient contributions in a batch
85  for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
86  sum += df(i,j);
87  }
88  bias_gradients(j,0) += sum;
89  }
90  }
91 
92  // std::cout << "RecurrentPropo: end " << std::endl;
93 
94  // std::cout << "state gradient\n";
95  // state_gradients_backward.Print();
96  // std::cout << "inputw gradient\n";
97  // input_weight_gradients.Print();
98  // std::cout << "bias gradient\n";
99  // bias_gradients.Print();
100  // std::cout << "input gradient\n";
101  // input_gradient.Print();
102 
103 
104  return input_gradient;
105 }
106 
107 
108 } // namespace DNN
109 } // namespace TMVA
static long int sum(long int i)
Definition: Factory.cxx:2258
static Matrix_t & RecurrentLayerBackward(TMatrixT< Scalar_t > &state_gradients_backward, TMatrixT< Scalar_t > &input_weight_gradients, TMatrixT< Scalar_t > &state_weight_gradients, TMatrixT< Scalar_t > &bias_gradients, TMatrixT< Scalar_t > &df, const TMatrixT< Scalar_t > &state, const TMatrixT< Scalar_t > &weights_input, const TMatrixT< Scalar_t > &weights_state, const TMatrixT< Scalar_t > &input, TMatrixT< Scalar_t > &input_gradient)
Backpropagation step for a Recurrent Neural Network.
TMatrixT.
Definition: TMatrixDfwd.h:22
Abstract ClassifierFactory template that handles arbitrary types.