RecurrentPropagation.hxx
// @(#)root/tmva/tmva/dnn:$Id$
// Authors: Surya S Dwivedi 15/07/2019, Saurav Shekhar 23/06/17
/*************************************************************************
 * Copyright (C) 2019, Surya S Dwivedi, Saurav Shekhar                   *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/////////////////////////////////////////////////////////////////////
// Implementation of the functions required for the forward and    //
// backward propagation of activations through a recurrent neural  //
// network in the reference implementation.                         //
/////////////////////////////////////////////////////////////////////
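
// Shorthand used in the dimension comments throughout this file (inferred
// from the matrix annotations in the code): B = batch size, H = size of the
// hidden state, D = size of the input at a single time step.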

#include "TMVA/DNN/Architectures/Reference.h"

namespace TMVA
{
namespace DNN
{

//______________________________________________________________________________
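// Backward pass for a plain recurrent layer. On entry, df is expected to
// contain the derivative of the activation function for the current time
// step; it is combined element-wise with the gradient flowing back through
// the state and then propagated to the input, the previous state, the two
// weight matrices and the bias.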
template<typename Scalar_t>
auto TReference<Scalar_t>::RecurrentLayerBackward(TMatrixT<Scalar_t> & state_gradients_backward,
                                                  TMatrixT<Scalar_t> & input_weight_gradients,
                                                  TMatrixT<Scalar_t> & state_weight_gradients,
                                                  TMatrixT<Scalar_t> & bias_gradients,
                                                  TMatrixT<Scalar_t> & df,                  // BxH
                                                  const TMatrixT<Scalar_t> & state,         // BxH
                                                  const TMatrixT<Scalar_t> & weights_input, // HxD
                                                  const TMatrixT<Scalar_t> & weights_state, // HxH
                                                  const TMatrixT<Scalar_t> & input,         // BxD
                                                  TMatrixT<Scalar_t> & input_gradient)
-> Matrix_t &
{
   // Compute element-wise product.
   for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
      for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
         df(i,j) *= state_gradients_backward(i,j);   // B x H
      }
   }

   // Input gradients.
   if (input_gradient.GetNoElements() > 0) {
      input_gradient.Mult(df, weights_input);   // B x H . H x D = B x D
   }

   // State gradients.
   if (state_gradients_backward.GetNoElements() > 0) {
      state_gradients_backward.Mult(df, weights_state);   // B x H . H x H = B x H
   }

   // Weights gradients.
   if (input_weight_gradients.GetNoElements() > 0) {
      TMatrixT<Scalar_t> tmp(input_weight_gradients);
      input_weight_gradients.TMult(df, input);   // H x B . B x D
      input_weight_gradients += tmp;
   }
   if (state_weight_gradients.GetNoElements() > 0) {
      TMatrixT<Scalar_t> tmp(state_weight_gradients);
      state_weight_gradients.TMult(df, state);   // H x B . B x H
      state_weight_gradients += tmp;
   }

   // Bias gradients. B x H -> H x 1
   if (bias_gradients.GetNoElements() > 0) {
      // this loops on state size
      for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
         Scalar_t sum = 0.0;
         // this loops on batch size summing all gradient contributions in a batch
         for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
            sum += df(i,j);
         }
         bias_gradients(j,0) += sum;
      }
   }

   return input_gradient;
}

//______________________________________________________________________________
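// Backward pass for an LSTM layer. cell_gradient, combined with the output
// gate and the incoming state gradient, is first turned into the gradient of
// the cell state; from it the gradients of the candidate, input, forget and
// output gates are derived (di, df, dc and dout are understood to hold the
// activation derivatives of the corresponding gates) and then propagated to
// the input, the previous state, and the weight and bias gradients.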
template <typename Scalar_t>
auto TReference<Scalar_t>::LSTMLayerBackward(TMatrixT<Scalar_t> & state_gradients_backward,
                                             TMatrixT<Scalar_t> & cell_gradients_backward,
                                             TMatrixT<Scalar_t> & input_weight_gradients,
                                             TMatrixT<Scalar_t> & forget_weight_gradients,
                                             TMatrixT<Scalar_t> & candidate_weight_gradients,
                                             TMatrixT<Scalar_t> & output_weight_gradients,
                                             TMatrixT<Scalar_t> & input_state_weight_gradients,
                                             TMatrixT<Scalar_t> & forget_state_weight_gradients,
                                             TMatrixT<Scalar_t> & candidate_state_weight_gradients,
                                             TMatrixT<Scalar_t> & output_state_weight_gradients,
                                             TMatrixT<Scalar_t> & input_bias_gradients,
                                             TMatrixT<Scalar_t> & forget_bias_gradients,
                                             TMatrixT<Scalar_t> & candidate_bias_gradients,
                                             TMatrixT<Scalar_t> & output_bias_gradients,
                                             TMatrixT<Scalar_t> & di,
                                             TMatrixT<Scalar_t> & df,
                                             TMatrixT<Scalar_t> & dc,
                                             TMatrixT<Scalar_t> & dout,
                                             const TMatrixT<Scalar_t> & precStateActivations,
                                             const TMatrixT<Scalar_t> & precCellActivations,
                                             const TMatrixT<Scalar_t> & fInput,
                                             const TMatrixT<Scalar_t> & fForget,
                                             const TMatrixT<Scalar_t> & fCandidate,
                                             const TMatrixT<Scalar_t> & fOutput,
                                             const TMatrixT<Scalar_t> & weights_input,
                                             const TMatrixT<Scalar_t> & weights_forget,
                                             const TMatrixT<Scalar_t> & weights_candidate,
                                             const TMatrixT<Scalar_t> & weights_output,
                                             const TMatrixT<Scalar_t> & weights_input_state,
                                             const TMatrixT<Scalar_t> & weights_forget_state,
                                             const TMatrixT<Scalar_t> & weights_candidate_state,
                                             const TMatrixT<Scalar_t> & weights_output_state,
                                             const TMatrixT<Scalar_t> & input,
                                             TMatrixT<Scalar_t> & input_gradient,
                                             TMatrixT<Scalar_t> & cell_gradient,
                                             TMatrixT<Scalar_t> & cell_tanh)
-> Matrix_t &
{
   // cell gradient
   Hadamard(cell_gradient, fOutput);
   Hadamard(cell_gradient, state_gradients_backward);
   cell_gradient += cell_gradients_backward;
   cell_gradients_backward = cell_gradient;
   Hadamard(cell_gradients_backward, fForget);

   // candidate gradient
   TMatrixT<Scalar_t> candidate_gradient(cell_gradient);
   Hadamard(candidate_gradient, fInput);
   Hadamard(candidate_gradient, dc);

   // input gate gradient
   TMatrixT<Scalar_t> input_gate_gradient(cell_gradient);
   Hadamard(input_gate_gradient, fCandidate);
   Hadamard(input_gate_gradient, di);

   // forget gradient
   TMatrixT<Scalar_t> forget_gradient(cell_gradient);
   Hadamard(forget_gradient, precCellActivations);
   Hadamard(forget_gradient, df);

   // output gradient
   TMatrixT<Scalar_t> output_gradient(cell_tanh);
   Hadamard(output_gradient, state_gradients_backward);
   Hadamard(output_gradient, dout);

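   // Each gate gradient is now propagated through the corresponding weight
   // matrices and accumulated into the input gradient, the state gradient of
   // the previous time step, and the weight and bias gradients.
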
   // input gradient
   TMatrixT<Scalar_t> tmpInp(input_gradient);
   tmpInp.Mult(input_gate_gradient, weights_input);
   input_gradient = tmpInp;
   tmpInp.Mult(forget_gradient, weights_forget);
   input_gradient += tmpInp;
   tmpInp.Mult(candidate_gradient, weights_candidate);
   input_gradient += tmpInp;
   tmpInp.Mult(output_gradient, weights_output);
   input_gradient += tmpInp;

   // state gradient backwards
   TMatrixT<Scalar_t> tmpState(state_gradients_backward);
   tmpState.Mult(input_gate_gradient, weights_input_state);
   state_gradients_backward = tmpState;
   tmpState.Mult(forget_gradient, weights_forget_state);
   state_gradients_backward += tmpState;
   tmpState.Mult(candidate_gradient, weights_candidate_state);
   state_gradients_backward += tmpState;
   tmpState.Mult(output_gradient, weights_output_state);
   state_gradients_backward += tmpState;

   // input weight gradients
   TMatrixT<Scalar_t> tmp(input_weight_gradients);
   input_weight_gradients.TMult(input_gate_gradient, input);
   input_weight_gradients += tmp;
   tmp = forget_weight_gradients;
   forget_weight_gradients.TMult(forget_gradient, input);
   forget_weight_gradients += tmp;
   tmp = candidate_weight_gradients;
   candidate_weight_gradients.TMult(candidate_gradient, input);
   candidate_weight_gradients += tmp;
   tmp = output_weight_gradients;
   output_weight_gradients.TMult(output_gradient, input);
   output_weight_gradients += tmp;

   // state weight gradients
   TMatrixT<Scalar_t> tmp1(input_state_weight_gradients);
   input_state_weight_gradients.TMult(input_gate_gradient, precStateActivations);
   input_state_weight_gradients += tmp1;
   tmp1 = forget_state_weight_gradients;
   forget_state_weight_gradients.TMult(forget_gradient, precStateActivations);
   forget_state_weight_gradients += tmp1;
   tmp1 = candidate_state_weight_gradients;
   candidate_state_weight_gradients.TMult(candidate_gradient, precStateActivations);
   candidate_state_weight_gradients += tmp1;
   tmp1 = output_state_weight_gradients;
   output_state_weight_gradients.TMult(output_gradient, precStateActivations);
   output_state_weight_gradients += tmp1;

   // bias gradients
   for (size_t j = 0; j < (size_t) df.GetNcols(); j++) {
      Scalar_t sum_inp = 0.0, sum_forget = 0.0, sum_candidate = 0.0, sum_out = 0.0;
      // this loops on batch size summing all gradient contributions in a batch
      for (size_t i = 0; i < (size_t) df.GetNrows(); i++) {
         sum_inp += input_gate_gradient(i,j);
         sum_forget += forget_gradient(i,j);
         sum_candidate += candidate_gradient(i,j);
         sum_out += output_gradient(i,j);
      }
      input_bias_gradients(j,0) += sum_inp;
      forget_bias_gradients(j,0) += sum_forget;
      candidate_bias_gradients(j,0) += sum_candidate;
      output_bias_gradients(j,0) += sum_out;
   }

   return input_gradient;
}

//______________________________________________________________________________
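// Backward pass for a GRU layer. The gradients of the reset, update and
// candidate gates are built from the incoming state gradient (dr, du and dc
// are understood to hold the activation derivatives of the corresponding
// gates); the state gradient of the previous time step is then assembled
// term by term, and the weight and bias gradients are accumulated over the
// batch.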
template <typename Scalar_t>
auto TReference<Scalar_t>::GRULayerBackward(TMatrixT<Scalar_t> & state_gradients_backward,
                                            TMatrixT<Scalar_t> & reset_weight_gradients,
                                            TMatrixT<Scalar_t> & update_weight_gradients,
                                            TMatrixT<Scalar_t> & candidate_weight_gradients,
                                            TMatrixT<Scalar_t> & reset_state_weight_gradients,
                                            TMatrixT<Scalar_t> & update_state_weight_gradients,
                                            TMatrixT<Scalar_t> & candidate_state_weight_gradients,
                                            TMatrixT<Scalar_t> & reset_bias_gradients,
                                            TMatrixT<Scalar_t> & update_bias_gradients,
                                            TMatrixT<Scalar_t> & candidate_bias_gradients,
                                            TMatrixT<Scalar_t> & dr,
                                            TMatrixT<Scalar_t> & du,
                                            TMatrixT<Scalar_t> & dc,
                                            const TMatrixT<Scalar_t> & precStateActivations,
                                            const TMatrixT<Scalar_t> & fReset,
                                            const TMatrixT<Scalar_t> & fUpdate,
                                            const TMatrixT<Scalar_t> & fCandidate,
                                            const TMatrixT<Scalar_t> & weights_reset,
                                            const TMatrixT<Scalar_t> & weights_update,
                                            const TMatrixT<Scalar_t> & weights_candidate,
                                            const TMatrixT<Scalar_t> & weights_reset_state,
                                            const TMatrixT<Scalar_t> & weights_update_state,
                                            const TMatrixT<Scalar_t> & weights_candidate_state,
                                            const TMatrixT<Scalar_t> & input,
                                            TMatrixT<Scalar_t> & input_gradient)
-> Matrix_t &
{
   // reset gradient
   TMatrixT<Scalar_t> reset_gradient(fUpdate);
   for (size_t j = 0; j < (size_t) reset_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t) reset_gradient.GetNrows(); i++) {
         reset_gradient(i,j) = 1 - reset_gradient(i,j);
      }
   }
   Hadamard(reset_gradient, dc);
   Hadamard(reset_gradient, state_gradients_backward);
   TMatrixT<Scalar_t> tmpMul(fUpdate);
   tmpMul.Mult(reset_gradient, weights_candidate_state);
   Hadamard(tmpMul, precStateActivations);
   Hadamard(tmpMul, dr);
   reset_gradient = tmpMul;

   // update gradient
   TMatrixT<Scalar_t> update_gradient(precStateActivations);
   for (size_t j = 0; j < (size_t) update_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t) update_gradient.GetNrows(); i++) {
         update_gradient(i,j) = update_gradient(i,j) - fCandidate(i,j);
      }
   }
   Hadamard(update_gradient, du);
   Hadamard(update_gradient, state_gradients_backward);

   // candidate gradient
   TMatrixT<Scalar_t> candidate_gradient(fUpdate);
   for (size_t j = 0; j < (size_t) candidate_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t) candidate_gradient.GetNrows(); i++) {
         candidate_gradient(i,j) = 1 - candidate_gradient(i,j);
      }
   }
   Hadamard(candidate_gradient, dc);
   Hadamard(candidate_gradient, state_gradients_backward);

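   // The five terms below correspond to differentiating a GRU update of the
   // form h_t = u * h_(t-1) + (1 - u) * c (element-wise products), where both
   // the update gate u and the candidate c depend on h_(t-1), and c sees
   // h_(t-1) additionally through the reset gate.
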
   // calculating state_gradient_backwards term by term
   // term 1
   TMatrixT<Scalar_t> temp(state_gradients_backward);
   TMatrixT<Scalar_t> term(fUpdate); // H X 1
   Hadamard(term, temp);
   state_gradients_backward = term;

   // term 2
   term = precStateActivations;
   Hadamard(term, du);
   Hadamard(term, temp);
   TMatrixT<Scalar_t> var(term);
   var.Mult(term, weights_update_state);
   term = var;
   state_gradients_backward += term;

   // term 3
   term = fCandidate;
   for (size_t j = 0; j < (size_t) term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t) term.GetNrows(); i++) {
         term(i,j) = - term(i,j);
      }
   }
   Hadamard(term, du);
   Hadamard(term, temp);
   var.Mult(term, weights_update_state);
   term = var;
   state_gradients_backward += term;

   // term 4
   term = fUpdate;
   for (size_t j = 0; j < (size_t) term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t) term.GetNrows(); i++) {
         term(i,j) = 1 - term(i,j);
      }
   }
   Hadamard(term, dc);
   Hadamard(term, temp);
   var.Mult(term, weights_candidate_state);
   Hadamard(var, fReset);
   term = var;
   state_gradients_backward += term;

   // term 5
   term = fUpdate;
   for (size_t j = 0; j < (size_t) term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t) term.GetNrows(); i++) {
         term(i,j) = 1 - term(i,j);
      }
   }
   Hadamard(term, dc);
   Hadamard(term, temp);
   var.Mult(term, weights_candidate_state);
   Hadamard(var, precStateActivations);
   Hadamard(var, dr);
   term.Mult(var, weights_reset_state);
   state_gradients_backward += term;

   // input gradients
   TMatrixT<Scalar_t> tmpInp(input_gradient);
   tmpInp.Mult(reset_gradient, weights_reset);
   input_gradient = tmpInp;
   tmpInp.Mult(update_gradient, weights_update);
   input_gradient += tmpInp;
   tmpInp.Mult(candidate_gradient, weights_candidate);
   input_gradient += tmpInp;

   // input weight gradients
   TMatrixT<Scalar_t> tmp(reset_weight_gradients);
   reset_weight_gradients.TMult(reset_gradient, input);
   reset_weight_gradients += tmp;
   tmp = update_weight_gradients;
   update_weight_gradients.TMult(update_gradient, input);
   update_weight_gradients += tmp;
   tmp = candidate_weight_gradients;
   candidate_weight_gradients.TMult(candidate_gradient, input);
   candidate_weight_gradients += tmp;

   // state weight gradients
   TMatrixT<Scalar_t> tmp1(reset_state_weight_gradients);
   reset_state_weight_gradients.TMult(reset_gradient, precStateActivations);
   reset_state_weight_gradients += tmp1;
   tmp1 = update_state_weight_gradients;
   update_state_weight_gradients.TMult(update_gradient, precStateActivations);
   update_state_weight_gradients += tmp1;
   tmp1 = candidate_state_weight_gradients;
   TMatrixT<Scalar_t> tmp2(fReset);
   Hadamard(tmp2, precStateActivations);
   candidate_state_weight_gradients.TMult(candidate_gradient, tmp2);
   candidate_state_weight_gradients += tmp1;

   // bias gradients
   for (size_t j = 0; j < (size_t) du.GetNcols(); j++) {
      Scalar_t sum_reset = 0.0, sum_update = 0.0, sum_candidate = 0.0;
      // this loops on batch size summing all gradient contributions in a batch
      for (size_t i = 0; i < (size_t) du.GetNrows(); i++) {
         sum_reset += reset_gradient(i,j);
         sum_update += update_gradient(i,j);
         sum_candidate += candidate_gradient(i,j);
      }
      reset_bias_gradients(j,0) += sum_reset;
      update_bias_gradients(j,0) += sum_update;
      candidate_bias_gradients(j,0) += sum_candidate;
   }

   return input_gradient;
}

} // namespace DNN
} // namespace TMVA