RecurrentPropagation.hxx
// @(#)root/tmva/tmva/dnn:$Id$
// Authors: Surya S Dwivedi 01/08/2019, Saurav Shekhar 23/06/17
/*************************************************************************
 * Copyright (C) 2019, Surya S Dwivedi, Saurav Shekhar                   *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/////////////////////////////////////////////////////////////////////
// Implementation of the functions required for the forward and    //
// backward propagation of activations through a recurrent neural  //
// network in the TCpu architecture                                //
/////////////////////////////////////////////////////////////////////

#include "TMVA/DNN/Architectures/Cpu.h"

namespace TMVA
{
namespace DNN
{

template<typename AFloat>
auto TCpu<AFloat>::RecurrentLayerBackward(TCpuMatrix<AFloat> & state_gradients_backward, // BxH
                                          TCpuMatrix<AFloat> & input_weight_gradients,
                                          TCpuMatrix<AFloat> & state_weight_gradients,
                                          TCpuMatrix<AFloat> & bias_gradients,
                                          TCpuMatrix<AFloat> & df, // BxH
                                          const TCpuMatrix<AFloat> & state, // BxH
                                          const TCpuMatrix<AFloat> & weights_input, // HxD
                                          const TCpuMatrix<AFloat> & weights_state, // HxH
                                          const TCpuMatrix<AFloat> & input, // BxD
                                          TCpuMatrix<AFloat> & input_gradient) -> Matrix_t &
{
   // std::cout << "Recurrent Propo" << std::endl;
   // TMVA_DNN_PrintTCpuMatrix(df,"DF");
   // TMVA_DNN_PrintTCpuMatrix(state_gradients_backward,"State grad");
   // TMVA_DNN_PrintTCpuMatrix(input_weight_gradients,"input w grad");
   // TMVA_DNN_PrintTCpuMatrix(state,"state");
   // TMVA_DNN_PrintTCpuMatrix(input,"input");

   // Compute element-wise product.
   // Hadamard(df, state_gradients_backward); // B x H

   // Input gradients.
   if (input_gradient.GetNoElements() > 0) {
      Multiply(input_gradient, df, weights_input);
   }

   // State gradients.
   if (state_gradients_backward.GetNoElements() > 0) {
      Multiply(state_gradients_backward, df, weights_state);
   }

   // Compute the weight gradients.
   // Perform the operation in place by adding the result onto the same gradient matrix,
   // e.g. W += D * X

   // Weight gradients.
   if (input_weight_gradients.GetNoElements() > 0) {
      TransposeMultiply(input_weight_gradients, df, input, 1., 1.); // H x B . B x D
   }

   if (state_weight_gradients.GetNoElements() > 0) {
      TransposeMultiply(state_weight_gradients, df, state, 1., 1.); // H x B . B x H
   }

   // Bias gradients.
   if (bias_gradients.GetNoElements() > 0) {
      SumColumns(bias_gradients, df, 1., 1.); // could probably do all of this here
   }

   return input_gradient;
}
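
// For reference, a rough summary of what the calls above compute, assuming df
// arrives already multiplied element-wise by the activation derivative at this
// time step (see the commented-out Hadamard above):
//   input_gradient             = df . W_input    (B x H . H x D -> B x D)
//   state_gradients_backward   = df . W_state    (B x H . H x H -> B x H)
//   input_weight_gradients    += df^T . input    (H x B . B x D -> H x D)
//   state_weight_gradients    += df^T . state    (H x B . B x H -> H x H)
//   bias_gradients            += column sums of df over the batch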

//______________________________________________________________________________
template <typename Scalar_t>
auto inline TCpu<Scalar_t>::LSTMLayerBackward(TCpuMatrix<Scalar_t> & state_gradients_backward,
                                              TCpuMatrix<Scalar_t> & cell_gradients_backward,
                                              TCpuMatrix<Scalar_t> & input_weight_gradients,
                                              TCpuMatrix<Scalar_t> & forget_weight_gradients,
                                              TCpuMatrix<Scalar_t> & candidate_weight_gradients,
                                              TCpuMatrix<Scalar_t> & output_weight_gradients,
                                              TCpuMatrix<Scalar_t> & input_state_weight_gradients,
                                              TCpuMatrix<Scalar_t> & forget_state_weight_gradients,
                                              TCpuMatrix<Scalar_t> & candidate_state_weight_gradients,
                                              TCpuMatrix<Scalar_t> & output_state_weight_gradients,
                                              TCpuMatrix<Scalar_t> & input_bias_gradients,
                                              TCpuMatrix<Scalar_t> & forget_bias_gradients,
                                              TCpuMatrix<Scalar_t> & candidate_bias_gradients,
                                              TCpuMatrix<Scalar_t> & output_bias_gradients,
                                              TCpuMatrix<Scalar_t> & di,
                                              TCpuMatrix<Scalar_t> & df,
                                              TCpuMatrix<Scalar_t> & dc,
                                              TCpuMatrix<Scalar_t> & dout,
                                              const TCpuMatrix<Scalar_t> & precStateActivations,
                                              const TCpuMatrix<Scalar_t> & precCellActivations,
                                              const TCpuMatrix<Scalar_t> & fInput,
                                              const TCpuMatrix<Scalar_t> & fForget,
                                              const TCpuMatrix<Scalar_t> & fCandidate,
                                              const TCpuMatrix<Scalar_t> & fOutput,
                                              const TCpuMatrix<Scalar_t> & weights_input,
                                              const TCpuMatrix<Scalar_t> & weights_forget,
                                              const TCpuMatrix<Scalar_t> & weights_candidate,
                                              const TCpuMatrix<Scalar_t> & weights_output,
                                              const TCpuMatrix<Scalar_t> & weights_input_state,
                                              const TCpuMatrix<Scalar_t> & weights_forget_state,
                                              const TCpuMatrix<Scalar_t> & weights_candidate_state,
                                              const TCpuMatrix<Scalar_t> & weights_output_state,
                                              const TCpuMatrix<Scalar_t> & input,
                                              TCpuMatrix<Scalar_t> & input_gradient,
                                              TCpuMatrix<Scalar_t> & cell_gradient,
                                              TCpuMatrix<Scalar_t> & cell_tanh) -> Matrix_t &
{
   // some temporary variables used later
   TCpuMatrix<Scalar_t> tmpInp(input_gradient.GetNrows(), input_gradient.GetNcols());
   TCpuMatrix<Scalar_t> tmpState(state_gradients_backward.GetNrows(), state_gradients_backward.GetNcols());

   TCpuMatrix<Scalar_t> input_gate_gradient(fInput.GetNrows(), fInput.GetNcols());
   TCpuMatrix<Scalar_t> forget_gradient(fForget.GetNrows(), fForget.GetNcols());
   TCpuMatrix<Scalar_t> candidate_gradient(fCandidate.GetNrows(), fCandidate.GetNcols());
   TCpuMatrix<Scalar_t> output_gradient(fOutput.GetNrows(), fOutput.GetNcols());

   // cell gradient
   Hadamard(cell_gradient, fOutput);
   Hadamard(cell_gradient, state_gradients_backward);
   ScaleAdd(cell_gradient, cell_gradients_backward);
   Copy(cell_gradients_backward, cell_gradient);
   Hadamard(cell_gradients_backward, fForget);

   // candidate gradient
   Copy(candidate_gradient, cell_gradient);
   Hadamard(candidate_gradient, fInput);
   Hadamard(candidate_gradient, dc);

   // input gate gradient
   Copy(input_gate_gradient, cell_gradient);
   Hadamard(input_gate_gradient, fCandidate);
   Hadamard(input_gate_gradient, di);

   // forget gradient
   Copy(forget_gradient, cell_gradient);
   Hadamard(forget_gradient, precCellActivations);
   Hadamard(forget_gradient, df);

   // output gradient
   Copy(output_gradient, cell_tanh);
   Hadamard(output_gradient, state_gradients_backward);
   Hadamard(output_gradient, dout);
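
   // Rough summary of the gate gradients computed above, assuming cell_gradient
   // arrives holding the derivative of tanh(C), cell_tanh holds tanh(C), and
   // di, df, dc, dout hold the activation derivatives of the input, forget,
   // candidate and output gates ((.) = element-wise product):
   //   dL/dC            = dL/dh (.) output_gate (.) dtanh(C) + dL/dC_next
   //   dL/dC_prev       = dL/dC (.) forget_gate        (cell_gradients_backward)
   //   candidate grad.  = dL/dC (.) input_gate  (.) dc
   //   input gate grad. = dL/dC (.) candidate   (.) di
   //   forget grad.     = dL/dC (.) C_prev      (.) df
   //   output grad.     = tanh(C) (.) dL/dh     (.) dout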

   // input gradient
   Multiply(tmpInp, input_gate_gradient, weights_input);
   Copy(input_gradient, tmpInp);
   Multiply(tmpInp, forget_gradient, weights_forget);
   ScaleAdd(input_gradient, tmpInp);
   Multiply(tmpInp, candidate_gradient, weights_candidate);
   ScaleAdd(input_gradient, tmpInp);
   Multiply(tmpInp, output_gradient, weights_output);
   ScaleAdd(input_gradient, tmpInp);

   // state gradient backwards
   Multiply(tmpState, input_gate_gradient, weights_input_state);
   Copy(state_gradients_backward, tmpState);
   Multiply(tmpState, forget_gradient, weights_forget_state);
   ScaleAdd(state_gradients_backward, tmpState);
   Multiply(tmpState, candidate_gradient, weights_candidate_state);
   ScaleAdd(state_gradients_backward, tmpState);
   Multiply(tmpState, output_gradient, weights_output_state);
   ScaleAdd(state_gradients_backward, tmpState);

   // input weight gradients
   TransposeMultiply(input_weight_gradients, input_gate_gradient, input, 1., 1.); // H x B . B x D
   TransposeMultiply(forget_weight_gradients, forget_gradient, input, 1., 1.);
   TransposeMultiply(candidate_weight_gradients, candidate_gradient, input, 1., 1.);
   TransposeMultiply(output_weight_gradients, output_gradient, input, 1., 1.);

   // state weight gradients
   TransposeMultiply(input_state_weight_gradients, input_gate_gradient, precStateActivations, 1., 1.); // H x B . B x H
   TransposeMultiply(forget_state_weight_gradients, forget_gradient, precStateActivations, 1., 1.);
   TransposeMultiply(candidate_state_weight_gradients, candidate_gradient, precStateActivations, 1., 1.);
   TransposeMultiply(output_state_weight_gradients, output_gradient, precStateActivations, 1., 1.);

   // bias gradients
   SumColumns(input_bias_gradients, input_gate_gradient, 1., 1.);
   SumColumns(forget_bias_gradients, forget_gradient, 1., 1.);
   SumColumns(candidate_bias_gradients, candidate_gradient, 1., 1.);
   SumColumns(output_bias_gradients, output_gradient, 1., 1.);

   return input_gradient;
}


//______________________________________________________________________________
template <typename Scalar_t>
auto inline TCpu<Scalar_t>::GRULayerBackward(TCpuMatrix<Scalar_t> & state_gradients_backward,
                                             TCpuMatrix<Scalar_t> & reset_weight_gradients,
                                             TCpuMatrix<Scalar_t> & update_weight_gradients,
                                             TCpuMatrix<Scalar_t> & candidate_weight_gradients,
                                             TCpuMatrix<Scalar_t> & reset_state_weight_gradients,
                                             TCpuMatrix<Scalar_t> & update_state_weight_gradients,
                                             TCpuMatrix<Scalar_t> & candidate_state_weight_gradients,
                                             TCpuMatrix<Scalar_t> & reset_bias_gradients,
                                             TCpuMatrix<Scalar_t> & update_bias_gradients,
                                             TCpuMatrix<Scalar_t> & candidate_bias_gradients,
                                             TCpuMatrix<Scalar_t> & dr,
                                             TCpuMatrix<Scalar_t> & du,
                                             TCpuMatrix<Scalar_t> & dc,
                                             const TCpuMatrix<Scalar_t> & precStateActivations,
                                             const TCpuMatrix<Scalar_t> & fReset,
                                             const TCpuMatrix<Scalar_t> & fUpdate,
                                             const TCpuMatrix<Scalar_t> & fCandidate,
                                             const TCpuMatrix<Scalar_t> & weights_reset,
                                             const TCpuMatrix<Scalar_t> & weights_update,
                                             const TCpuMatrix<Scalar_t> & weights_candidate,
                                             const TCpuMatrix<Scalar_t> & weights_reset_state,
                                             const TCpuMatrix<Scalar_t> & weights_update_state,
                                             const TCpuMatrix<Scalar_t> & weights_candidate_state,
                                             const TCpuMatrix<Scalar_t> & input,
                                             TCpuMatrix<Scalar_t> & input_gradient,
                                             bool resetGateAfter) -> Matrix_t &
{
   // reset gradient
   int r = fUpdate.GetNrows(), c = fUpdate.GetNcols();
   TCpuMatrix<Scalar_t> reset_gradient(r, c);
   Copy(reset_gradient, fUpdate);
   for (size_t j = 0; j < (size_t)reset_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)reset_gradient.GetNrows(); i++) {
         reset_gradient(i, j) = 1 - reset_gradient(i, j);
      }
   }
   Hadamard(reset_gradient, dc);
   Hadamard(reset_gradient, state_gradients_backward);
   TCpuMatrix<Scalar_t> tmpMul(r, c);

   if (!resetGateAfter) {
      // case resetGateAfter is false: U * (r * h)
      // dr = h * (U^T * dy)
      Multiply(tmpMul, reset_gradient, weights_candidate_state);
      Hadamard(tmpMul, precStateActivations);
   } else {
      // case resetGateAfter is true: r * (U * h)  -->  dr = dy * (U * h)
      MultiplyTranspose(tmpMul, precStateActivations, weights_candidate_state);
      Hadamard(tmpMul, reset_gradient);
   }
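   // Note: resetGateAfter = false is the standard GRU formulation, where the reset
   // gate multiplies the previous state before the candidate matrix multiplication,
   // U * (r (.) h); resetGateAfter = true applies it after the multiplication,
   // r (.) (U * h), which matches the "reset after" convention used e.g. by cuDNN
   // (Keras reset_after=True).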
   Hadamard(tmpMul, dr);
   Copy(reset_gradient, tmpMul);

   // update gradient
   TCpuMatrix<Scalar_t> update_gradient(r, c); // H X 1
   Copy(update_gradient, precStateActivations);
   for (size_t j = 0; j < (size_t)update_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)update_gradient.GetNrows(); i++) {
         update_gradient(i, j) = update_gradient(i, j) - fCandidate(i, j);
      }
   }
   Hadamard(update_gradient, du);
   Hadamard(update_gradient, state_gradients_backward);

   // candidate gradient
   TCpuMatrix<Scalar_t> candidate_gradient(r, c);
   Copy(candidate_gradient, fUpdate);
   for (size_t j = 0; j < (size_t)candidate_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)candidate_gradient.GetNrows(); i++) {
         candidate_gradient(i, j) = 1 - candidate_gradient(i, j);
      }
   }
   Hadamard(candidate_gradient, dc);
   Hadamard(candidate_gradient, state_gradients_backward);

   // calculating state gradient backwards term by term
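   // The previous state h_prev enters h = z (.) h_prev + (1 - z) (.) candidate
   // through several paths; roughly, the five terms below correspond to:
   //   term 1: direct path,              z (.) dL/dh
   //   term 2: update-gate path (h_prev part),     (h_prev (.) du (.) dL/dh) . U_update
   //   term 3: update-gate path (-candidate part), (-candidate (.) du (.) dL/dh) . U_update
   //   term 4: candidate path via U_candidate, with the reset gate applied
   //   term 5: reset-gate path via U_reset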
   // term 1
   TCpuMatrix<Scalar_t> temp(r, c);
   Copy(temp, state_gradients_backward);
   TCpuMatrix<Scalar_t> term(r, c); // H X 1
   Copy(term, fUpdate);
   Hadamard(term, temp);
   Copy(state_gradients_backward, term);

   // term 2
   Copy(term, precStateActivations);
   Hadamard(term, du);
   Hadamard(term, temp);
   TCpuMatrix<Scalar_t> var(r, c);
   Multiply(var, term, weights_update_state);
   Copy(term, var);
   ScaleAdd(state_gradients_backward, term);

   // term 3
   Copy(term, fCandidate);
   for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
         term(i, j) = -term(i, j);
      }
   }
   Hadamard(term, du);
   Hadamard(term, temp);
   Multiply(var, term, weights_update_state);
   Copy(term, var);
   ScaleAdd(state_gradients_backward, term);

   // term 4
   Copy(term, fUpdate);
   for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
         term(i, j) = 1 - term(i, j);
      }
   }
   Hadamard(term, dc);
   Hadamard(term, temp);

   if (!resetGateAfter) {
      // case resetGateAfter is false: U * (r * h)
      // dh = r * (U^T * dy)
      Multiply(var, term, weights_candidate_state);
      Hadamard(var, fReset);
   } else {
      // case resetGateAfter = true
      // dh = U^T * (r * dy)
      Hadamard(term, fReset);
      Multiply(var, term, weights_candidate_state);
   }
   Copy(term, var);
   ScaleAdd(state_gradients_backward, term);

   // term 5
   Copy(term, fUpdate);
   for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
         term(i, j) = 1 - term(i, j);
      }
   }
   // here we re-compute dr (we could probably be more efficient)
   Hadamard(term, dc);
   Hadamard(term, temp);
   if (!resetGateAfter) {
      // case resetGateAfter = false
      // recompute dr/dh (as above for dr): dr = h * (U^T * dy)
      Multiply(var, term, weights_candidate_state);
      Hadamard(var, precStateActivations);
   } else {
      // case resetGateAfter = true: dr = dy * (U * h)
      MultiplyTranspose(var, precStateActivations, weights_candidate_state);
      Hadamard(var, term);
   }
   Hadamard(var, dr);
   Multiply(term, var, weights_reset_state);
   ScaleAdd(state_gradients_backward, term);

   // input gradients
   TCpuMatrix<Scalar_t> tmpInp(input_gradient.GetNrows(), input_gradient.GetNcols());
   Multiply(tmpInp, reset_gradient, weights_reset);
   Copy(input_gradient, tmpInp);
   Multiply(tmpInp, update_gradient, weights_update);
   ScaleAdd(input_gradient, tmpInp);
   Multiply(tmpInp, candidate_gradient, weights_candidate);
   ScaleAdd(input_gradient, tmpInp);

   // input weight gradients
   TransposeMultiply(reset_weight_gradients, reset_gradient, input, 1., 1.); // H x B . B x D
   TransposeMultiply(update_weight_gradients, update_gradient, input, 1., 1.);
   TransposeMultiply(candidate_weight_gradients, candidate_gradient, input, 1., 1.);

   // state weight gradients
   TransposeMultiply(reset_state_weight_gradients, reset_gradient, precStateActivations, 1., 1.); // H x B . B x H
   TransposeMultiply(update_state_weight_gradients, update_gradient, precStateActivations, 1., 1.);
   TCpuMatrix<Scalar_t> tempvar(r, c);

   // candidate weight gradients
   if (!resetGateAfter) {
      // case resetGateAfter = false
      // dU = (h * r) * dy
      Copy(tempvar, precStateActivations);
      Hadamard(tempvar, fReset);
      TransposeMultiply(candidate_state_weight_gradients, candidate_gradient, tempvar, 1., 1.);
   } else {
      // case resetGateAfter = true
      // dU = h * (r * dy)
      Copy(tempvar, candidate_gradient);
      Hadamard(tempvar, fReset);
      TransposeMultiply(candidate_state_weight_gradients, tempvar, precStateActivations, 1., 1.);
   }

   // bias gradients
   SumColumns(reset_bias_gradients, reset_gradient, 1., 1.); // could probably do all of this here
   SumColumns(update_bias_gradients, update_gradient, 1., 1.);
   SumColumns(candidate_bias_gradients, candidate_gradient, 1., 1.);

   return input_gradient;
}
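
// Rough summary of the GRU gate gradients above ((.) = element-wise product),
// assuming dr, du, dc hold the activation derivatives of the reset, update and
// candidate gates and the state follows h = z (.) h_prev + (1 - z) (.) candidate:
//   candidate grad. = (1 - z) (.) dc (.) dL/dh
//   update grad.    = (h_prev - candidate) (.) du (.) dL/dh
//   reset grad.     = dr (.) h_prev (.) (candidate grad. . U_candidate)   [resetGateAfter = false]
// The input, weight and bias gradients then follow the same pattern as in the
// plain recurrent and LSTM cases above.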

} // namespace DNN
} // namespace TMVA