ROOT 6.14/05 Reference Guide
RecurrentPropagation.cxx
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Saurav Shekhar 23/06/17

/*************************************************************************
 * Copyright (C) 2017, Saurav Shekhar                                    *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/////////////////////////////////////////////////////////////////////
// Implementation of the functions required for the forward and   //
// backward propagation of activations through a recurrent neural //
// network in the TCpu architecture                                //
/////////////////////////////////////////////////////////////////////

#include "TMVA/DNN/Architectures/Cpu.h"
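
// Notation for the backward step below (B = batch size, D = input size,
// H = state size), as inferred from the code and its dimension annotations.
// For a vanilla recurrent layer with state update s_t = f(z_t), where
// z_t = x_t * W_in^T + s_{t-1} * W_state^T + b, and df holding f'(z_t):
//
//    df           <- f'(z_t) . dL/ds_t     (Hadamard product, B x H)
//    dL/dx_t       = df * W_in             ((B x H) . (H x D) = B x D)
//    dL/ds_{t-1}   = df * W_state          ((B x H) . (H x H) = B x H)
//    dL/dW_in     += df^T * x_t            ((H x B) . (B x D) = H x D)
//    dL/dW_state  += df^T * s_{t-1}        ((H x B) . (B x H) = H x H)
//    dL/db        += column sums of df     (sum over the batch, H)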
namespace TMVA
{
namespace DNN
{

template<typename AFloat>
auto TCpu<AFloat>::RecurrentLayerBackward(TCpuMatrix<AFloat> & state_gradients_backward, // B x H
                                          TCpuMatrix<AFloat> & input_weight_gradients,   // H x D
                                          TCpuMatrix<AFloat> & state_weight_gradients,   // H x H
                                          TCpuMatrix<AFloat> & bias_gradients,
                                          TCpuMatrix<AFloat> & df,                       // B x H
                                          const TCpuMatrix<AFloat> & state,              // B x H
                                          const TCpuMatrix<AFloat> & weights_input,      // H x D
                                          const TCpuMatrix<AFloat> & weights_state,      // H x H
                                          const TCpuMatrix<AFloat> & input,              // B x D
                                          TCpuMatrix<AFloat> & input_gradient)           // B x D
-> Matrix_t &
{

   // std::cout << "Recurrent Propagation" << std::endl;
   // PrintMatrix(df, "DF");
   // PrintMatrix(state_gradients_backward, "State grad");
   // PrintMatrix(input_weight_gradients, "input w grad");
   // PrintMatrix(state, "state");
   // PrintMatrix(input, "input");

   // Compute the element-wise product.
   Hadamard(df, state_gradients_backward); // B x H

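   // From here on df holds the combined term dL/dz_t = f'(z_t) . dL/ds_t;
   // every product below reuses it, so it is computed once up front.
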
   // Input gradients.
   if (input_gradient.GetNElements() > 0) Multiply(input_gradient, df, weights_input);

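   // dL/dx_t has shape (B x H) . (H x D) = B x D; the product is skipped when the
   // caller passes an empty matrix, i.e. when no input gradient is requested.
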
   // State gradients.
   if (state_gradients_backward.GetNElements() > 0) Multiply(state_gradients_backward, df, weights_state);

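   // dL/ds_{t-1} overwrites the incoming dL/ds_t in place, so the same matrix can
   // be fed directly into the backward step of the previous time step.
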
   // Weight gradients.
   // Perform the operation in place by re-adding the result onto the same
   // gradient matrix, e.g. W += D^T * X.
   if (input_weight_gradients.GetNElements() > 0) {
      TransposeMultiply(input_weight_gradients, df, input, 1., 1.); // (H x B) . (B x D)
   }
   if (state_weight_gradients.GetNElements() > 0) {
      TransposeMultiply(state_weight_gradients, df, state, 1., 1.); // (H x B) . (B x H)
   }

   // Bias gradients.
   if (bias_gradients.GetNElements() > 0) {
      SumColumns(bias_gradients, df, 1., 1.); // sum df over the batch dimension
   }

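   // The beta = 1 arguments above make each product accumulate onto the existing
   // gradients instead of overwriting them, which is what sums the contributions
   // of the individual time steps during backpropagation through time.
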
   // std::cout << "RecurrentPropagation: end" << std::endl;

   // PrintMatrix(state_gradients_backward, "State grad");
   // PrintMatrix(input_weight_gradients, "input w grad");
   // PrintMatrix(bias_gradients, "bias grad");
   // PrintMatrix(input_gradient, "input grad");

   return input_gradient;
}

} // namespace DNN
} // namespace TMVA
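
For illustration, the listing below re-expresses the same backward step as a self-contained, row-major reference implementation over plain std::vector buffers. It is a sketch of the mathematics only, not TMVA code: the function name rnnBackwardSketch and the flat-buffer layout are hypothetical, and TCpu's optimized kernels (Hadamard, Multiply, TransposeMultiply, SumColumns) are replaced by explicit loops.

#include <cstddef>
#include <vector>

// Hypothetical stand-alone sketch of the backward step above, using flat
// row-major buffers: element (i, j) of an R x C matrix is m[i * C + j].
// B = batch size, D = input size, H = state size.
void rnnBackwardSketch(std::vector<double> &dState,        // B x H: dL/ds_t in, dL/ds_{t-1} out
                       std::vector<double> &dWin,          // H x D: accumulated
                       std::vector<double> &dWstate,       // H x H: accumulated
                       std::vector<double> &dBias,         // H:     accumulated
                       std::vector<double> &df,            // B x H: f'(z_t) in, overwritten
                       const std::vector<double> &state,   // B x H: s_{t-1}
                       const std::vector<double> &Win,     // H x D
                       const std::vector<double> &Wstate,  // H x H
                       const std::vector<double> &input,   // B x D: x_t
                       std::vector<double> &dInput,        // B x D: dL/dx_t out
                       std::size_t B, std::size_t D, std::size_t H)
{
   // df <- f'(z_t) . dL/ds_t  (the Hadamard step)
   for (std::size_t i = 0; i < B * H; ++i) df[i] *= dState[i];

   // dL/dx_t = df * Win : (B x H) . (H x D)
   for (std::size_t b = 0; b < B; ++b)
      for (std::size_t d = 0; d < D; ++d) {
         double sum = 0.0;
         for (std::size_t h = 0; h < H; ++h) sum += df[b * H + h] * Win[h * D + d];
         dInput[b * D + d] = sum;
      }

   // dL/ds_{t-1} = df * Wstate : (B x H) . (H x H), written back into dState
   for (std::size_t b = 0; b < B; ++b)
      for (std::size_t j = 0; j < H; ++j) {
         double sum = 0.0;
         for (std::size_t h = 0; h < H; ++h) sum += df[b * H + h] * Wstate[h * H + j];
         dState[b * H + j] = sum;
      }

   // dWin += df^T * input : (H x B) . (B x D)
   for (std::size_t h = 0; h < H; ++h)
      for (std::size_t d = 0; d < D; ++d)
         for (std::size_t b = 0; b < B; ++b)
            dWin[h * D + d] += df[b * H + h] * input[b * D + d];

   // dWstate += df^T * state : (H x B) . (B x H)
   for (std::size_t h = 0; h < H; ++h)
      for (std::size_t j = 0; j < H; ++j)
         for (std::size_t b = 0; b < B; ++b)
            dWstate[h * H + j] += df[b * H + h] * state[b * H + j];

   // dBias += column sums of df (sum over the batch)
   for (std::size_t h = 0; h < H; ++h)
      for (std::size_t b = 0; b < B; ++b)
         dBias[h] += df[b * H + h];
}

Accumulating into dWin, dWstate, and dBias instead of assigning mirrors the beta = 1 arguments of the TMVA kernels: calling the function once per time step sums the per-step gradients, as backpropagation through time requires.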