RecurrentPropagation.hxx
// @(#)root/tmva/tmva/dnn:$Id$
// Authors: Surya S Dwivedi 01/08/2019, Saurav Shekhar 23/06/17
/*************************************************************************
 * Copyright (C) 2019, Surya S Dwivedi, Saurav Shekhar                   *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/////////////////////////////////////////////////////////////////////
// Implementation of the functions required for the forward and    //
// backward propagation of activations through a recurrent neural  //
// network in the TCpu architecture.                                //
/////////////////////////////////////////////////////////////////////
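//
// Notation sketch for the backward passes in this file (inferred from the
// shape comments and the code below, not from separate documentation):
// B = batch size, D = input size, H = state (hidden) size.
// For the plain recurrent layer, writing a_t = x_t . W_input^T + h_{t-1} . W_state^T + b
// and h_t = f(a_t), the code below computes
//    dx_t      = df . W_input        (B x D)
//    dh_{t-1}  = df . W_state        (B x H)
//    dW_input += df^T . x_t          (H x D)
//    dW_state += df^T . h_{t-1}      (H x H)
//    db       += column-sum(df)      (H)
// where df is assumed to already hold the incoming state gradient multiplied
// element-wise by f'(a_t), and '.' denotes a matrix product.
//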

#include "TMVA/DNN/Architectures/Cpu.h"
#include "Blas.h"

namespace TMVA
{
namespace DNN
{

template<typename AFloat>
auto TCpu<AFloat>::RecurrentLayerBackward(TCpuMatrix<AFloat> & state_gradients_backward, // BxH
                                          TCpuMatrix<AFloat> & input_weight_gradients,
                                          TCpuMatrix<AFloat> & state_weight_gradients,
                                          TCpuMatrix<AFloat> & bias_gradients,
                                          TCpuMatrix<AFloat> & df,                   // BxH
                                          const TCpuMatrix<AFloat> & state,          // BxH
                                          const TCpuMatrix<AFloat> & weights_input,  // HxD
                                          const TCpuMatrix<AFloat> & weights_state,  // HxH
                                          const TCpuMatrix<AFloat> & input,          // BxD
                                          TCpuMatrix<AFloat> & input_gradient)
-> Matrix_t &
{

   // Compute element-wise product.
   //Hadamard(df, state_gradients_backward); // B x H
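   // Note: since the call above is commented out, df is assumed to arrive here
   // already multiplied element-wise by the incoming state gradient, i.e. it
   // already holds dL/da_t for this time step.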

   // Input gradients.
   if (input_gradient.GetNoElements() > 0) {
      Multiply(input_gradient, df, weights_input);
   }

   // State gradients.
   if (state_gradients_backward.GetNoElements() > 0) {
      Multiply(state_gradients_backward, df, weights_state);
   }

   // Compute the weight gradients.
   // Perform the operation in place by re-adding the result to the same gradient matrix,
   // e.g. W += D * X

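   // The trailing (1., 1.) arguments below are taken to be the usual
   // (alpha, beta) GEMM factors, so each call adds its result to the existing
   // gradient matrix instead of overwriting it -- presumably so contributions
   // from all unrolled time steps accumulate.
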
   // Weight gradients.
   if (input_weight_gradients.GetNoElements() > 0) {
      TransposeMultiply(input_weight_gradients, df, input, 1., 1.); // H x B . B x D
   }

   if (state_weight_gradients.GetNoElements() > 0) {
      TransposeMultiply(state_weight_gradients, df, state, 1., 1.); // H x B . B x H
   }

   // Bias gradients.
   if (bias_gradients.GetNoElements() > 0) {
      SumColumns(bias_gradients, df, 1., 1.); // could probably do it all here
   }

   return input_gradient;
}

//______________________________________________________________________________
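// Backward pass for an LSTM network.
// Sketch of the gradient relations implemented below (read off the code rather
// than taken from separate documentation; di, df, dc and dout are understood
// as the activation derivatives of the input, forget, candidate and output
// gates, cell_gradient as holding tanh'(C_t) and cell_tanh as holding
// tanh(C_t) on entry; '*' is the element-wise product, '.' a matrix product):
//    dC_t       = dh_t * o_t * tanh'(C_t) + dC_{t+1}
//    dC_{t-1}   = dC_t * f_t              (passed back via cell_gradients_backward)
//    dcand_t    = dC_t * i_t        * dc
//    dinput_t   = dC_t * cand_t     * di
//    dforget_t  = dC_t * C_{t-1}    * df
//    doutput_t  = dh_t * tanh(C_t)  * dout
//    dx_t      = sum_g dg . W_g          dh_{t-1} = sum_g dg . W_g_state
//    dW_g     += dg^T . x_t              dU_g    += dg^T . h_{t-1}          db_g += colsum(dg)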
template <typename Scalar_t>
auto inline TCpu<Scalar_t>::LSTMLayerBackward(TCpuMatrix<Scalar_t> & state_gradients_backward,
                                              TCpuMatrix<Scalar_t> & cell_gradients_backward,
                                              TCpuMatrix<Scalar_t> & input_weight_gradients,
                                              TCpuMatrix<Scalar_t> & forget_weight_gradients,
                                              TCpuMatrix<Scalar_t> & candidate_weight_gradients,
                                              TCpuMatrix<Scalar_t> & output_weight_gradients,
                                              TCpuMatrix<Scalar_t> & input_state_weight_gradients,
                                              TCpuMatrix<Scalar_t> & forget_state_weight_gradients,
                                              TCpuMatrix<Scalar_t> & candidate_state_weight_gradients,
                                              TCpuMatrix<Scalar_t> & output_state_weight_gradients,
                                              TCpuMatrix<Scalar_t> & input_bias_gradients,
                                              TCpuMatrix<Scalar_t> & forget_bias_gradients,
                                              TCpuMatrix<Scalar_t> & candidate_bias_gradients,
                                              TCpuMatrix<Scalar_t> & output_bias_gradients,
                                              TCpuMatrix<Scalar_t> & di,
                                              TCpuMatrix<Scalar_t> & df,
                                              TCpuMatrix<Scalar_t> & dc,
                                              TCpuMatrix<Scalar_t> & dout,
                                              const TCpuMatrix<Scalar_t> & precStateActivations,
                                              const TCpuMatrix<Scalar_t> & precCellActivations,
                                              const TCpuMatrix<Scalar_t> & fInput,
                                              const TCpuMatrix<Scalar_t> & fForget,
                                              const TCpuMatrix<Scalar_t> & fCandidate,
                                              const TCpuMatrix<Scalar_t> & fOutput,
                                              const TCpuMatrix<Scalar_t> & weights_input,
                                              const TCpuMatrix<Scalar_t> & weights_forget,
                                              const TCpuMatrix<Scalar_t> & weights_candidate,
                                              const TCpuMatrix<Scalar_t> & weights_output,
                                              const TCpuMatrix<Scalar_t> & weights_input_state,
                                              const TCpuMatrix<Scalar_t> & weights_forget_state,
                                              const TCpuMatrix<Scalar_t> & weights_candidate_state,
                                              const TCpuMatrix<Scalar_t> & weights_output_state,
                                              const TCpuMatrix<Scalar_t> & input,
                                              TCpuMatrix<Scalar_t> & input_gradient,
                                              TCpuMatrix<Scalar_t> & cell_gradient,
                                              TCpuMatrix<Scalar_t> & cell_tanh)
-> Matrix_t &
{
   // some temporary variables used later
   TCpuMatrix<Scalar_t> tmpInp(input_gradient.GetNrows(), input_gradient.GetNcols());
   TCpuMatrix<Scalar_t> tmpState(state_gradients_backward.GetNrows(), state_gradients_backward.GetNcols());

   TCpuMatrix<Scalar_t> input_gate_gradient(fInput.GetNrows(), fInput.GetNcols());
   TCpuMatrix<Scalar_t> forget_gradient(fForget.GetNrows(), fForget.GetNcols());
   TCpuMatrix<Scalar_t> candidate_gradient(fCandidate.GetNrows(), fCandidate.GetNcols());
   TCpuMatrix<Scalar_t> output_gradient(fOutput.GetNrows(), fOutput.GetNcols());

   // cell gradient
   Hadamard(cell_gradient, fOutput);
   Hadamard(cell_gradient, state_gradients_backward);
   ScaleAdd(cell_gradient, cell_gradients_backward);
   Copy(cell_gradients_backward, cell_gradient);
   Hadamard(cell_gradients_backward, fForget);

   // candidate gradient
   Copy(candidate_gradient, cell_gradient);
   Hadamard(candidate_gradient, fInput);
   Hadamard(candidate_gradient, dc);

   // input gate gradient
   Copy(input_gate_gradient, cell_gradient);
   Hadamard(input_gate_gradient, fCandidate);
   Hadamard(input_gate_gradient, di);

   // forget gradient
   Copy(forget_gradient, cell_gradient);
   Hadamard(forget_gradient, precCellActivations);
   Hadamard(forget_gradient, df);

   // output gradient
   Copy(output_gradient, cell_tanh);
   Hadamard(output_gradient, state_gradients_backward);
   Hadamard(output_gradient, dout);

   // input gradient
   Multiply(tmpInp, input_gate_gradient, weights_input);
   Copy(input_gradient, tmpInp);
   Multiply(tmpInp, forget_gradient, weights_forget);
   ScaleAdd(input_gradient, tmpInp);
   Multiply(tmpInp, candidate_gradient, weights_candidate);
   ScaleAdd(input_gradient, tmpInp);
   Multiply(tmpInp, output_gradient, weights_output);
   ScaleAdd(input_gradient, tmpInp);

   // state gradient backwards
   Multiply(tmpState, input_gate_gradient, weights_input_state);
   Copy(state_gradients_backward, tmpState);
   Multiply(tmpState, forget_gradient, weights_forget_state);
   ScaleAdd(state_gradients_backward, tmpState);
   Multiply(tmpState, candidate_gradient, weights_candidate_state);
   ScaleAdd(state_gradients_backward, tmpState);
   Multiply(tmpState, output_gradient, weights_output_state);
   ScaleAdd(state_gradients_backward, tmpState);

   // input weight gradients
   TransposeMultiply(input_weight_gradients, input_gate_gradient, input, 1., 1.); // H x B . B x D
   TransposeMultiply(forget_weight_gradients, forget_gradient, input, 1., 1.);
   TransposeMultiply(candidate_weight_gradients, candidate_gradient, input, 1., 1.);
   TransposeMultiply(output_weight_gradients, output_gradient, input, 1., 1.);

   // state weight gradients
   TransposeMultiply(input_state_weight_gradients, input_gate_gradient, precStateActivations, 1., 1.); // H x B . B x H
   TransposeMultiply(forget_state_weight_gradients, forget_gradient, precStateActivations, 1., 1.);
   TransposeMultiply(candidate_state_weight_gradients, candidate_gradient, precStateActivations, 1., 1.);
   TransposeMultiply(output_state_weight_gradients, output_gradient, precStateActivations, 1., 1.);

   // bias gradients
   SumColumns(input_bias_gradients, input_gate_gradient, 1., 1.);
   SumColumns(forget_bias_gradients, forget_gradient, 1., 1.);
   SumColumns(candidate_bias_gradients, candidate_gradient, 1., 1.);
   SumColumns(output_bias_gradients, output_gradient, 1., 1.);

   return input_gradient;
}


//______________________________________________________________________________
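// Backward pass for a GRU network.
// Sketch of the gradient relations implemented below (read off the code; dr,
// du and dc are understood as the activation derivatives of the reset, update
// and candidate gates, '*' is the element-wise product and '.' a matrix
// product). With h_t = u_t * h_{t-1} + (1 - u_t) * c_t:
//    dupd_t  = (h_{t-1} - c_t) * du * dh_t
//    dcand_t = (1 - u_t)       * dc * dh_t
//    drst_t  = (dcand_t . U_c) * h_{t-1} * dr       if resetGateAfter == false
//            = (h_{t-1} . U_c^T) * dcand_t * dr     if resetGateAfter == true
// dh_{t-1} is then accumulated term by term: the direct u_t * dh_t path, two
// update-gate terms through U_update, the candidate path through U_candidate
// (with the reset gate applied before or after the matrix product, depending
// on resetGateAfter), and the reset-gate path through U_reset.
// resetGateAfter == true presumably corresponds to the cuDNN-style GRU
// variant, where the reset gate multiplies U_c h_{t-1} rather than h_{t-1}.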
template <typename Scalar_t>
auto inline TCpu<Scalar_t>::GRULayerBackward(TCpuMatrix<Scalar_t> & state_gradients_backward,
                                             TCpuMatrix<Scalar_t> & reset_weight_gradients,
                                             TCpuMatrix<Scalar_t> & update_weight_gradients,
                                             TCpuMatrix<Scalar_t> & candidate_weight_gradients,
                                             TCpuMatrix<Scalar_t> & reset_state_weight_gradients,
                                             TCpuMatrix<Scalar_t> & update_state_weight_gradients,
                                             TCpuMatrix<Scalar_t> & candidate_state_weight_gradients,
                                             TCpuMatrix<Scalar_t> & reset_bias_gradients,
                                             TCpuMatrix<Scalar_t> & update_bias_gradients,
                                             TCpuMatrix<Scalar_t> & candidate_bias_gradients,
                                             TCpuMatrix<Scalar_t> & dr,
                                             TCpuMatrix<Scalar_t> & du,
                                             TCpuMatrix<Scalar_t> & dc,
                                             const TCpuMatrix<Scalar_t> & precStateActivations,
                                             const TCpuMatrix<Scalar_t> & fReset,
                                             const TCpuMatrix<Scalar_t> & fUpdate,
                                             const TCpuMatrix<Scalar_t> & fCandidate,
                                             const TCpuMatrix<Scalar_t> & weights_reset,
                                             const TCpuMatrix<Scalar_t> & weights_update,
                                             const TCpuMatrix<Scalar_t> & weights_candidate,
                                             const TCpuMatrix<Scalar_t> & weights_reset_state,
                                             const TCpuMatrix<Scalar_t> & weights_update_state,
                                             const TCpuMatrix<Scalar_t> & weights_candidate_state,
                                             const TCpuMatrix<Scalar_t> & input,
                                             TCpuMatrix<Scalar_t> & input_gradient,
                                             bool resetGateAfter)
-> Matrix_t &
{
   // reset gradient
   int r = fUpdate.GetNrows(), c = fUpdate.GetNcols();
   TCpuMatrix<Scalar_t> reset_gradient(r, c);
   Copy(reset_gradient, fUpdate);
   for (size_t j = 0; j < (size_t)reset_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)reset_gradient.GetNrows(); i++) {
         reset_gradient(i, j) = 1 - reset_gradient(i, j);
      }
   }
   Hadamard(reset_gradient, dc);
   Hadamard(reset_gradient, state_gradients_backward);
   TCpuMatrix<Scalar_t> tmpMul(r, c);

   if (!resetGateAfter) {
      // case resetGateAfter is false: U * (r * h)
      // dr = h * (UT * dy)
      Multiply(tmpMul, reset_gradient, weights_candidate_state);
      Hadamard(tmpMul, precStateActivations);
   } else {
      // case resetGateAfter is true: r * (U * h) --> dr = dy * (U * h)
      MultiplyTranspose(tmpMul, precStateActivations, weights_candidate_state);
      Hadamard(tmpMul, reset_gradient);
   }
   Hadamard(tmpMul, dr);
   Copy(reset_gradient, tmpMul);

   // update gradient
   TCpuMatrix<Scalar_t> update_gradient(r, c); // B x H
   Copy(update_gradient, precStateActivations);
   for (size_t j = 0; j < (size_t)update_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)update_gradient.GetNrows(); i++) {
         update_gradient(i, j) = update_gradient(i, j) - fCandidate(i, j);
      }
   }
   Hadamard(update_gradient, du);
   Hadamard(update_gradient, state_gradients_backward);

   // candidate gradient
   TCpuMatrix<Scalar_t> candidate_gradient(r, c);
   Copy(candidate_gradient, fUpdate);
   for (size_t j = 0; j < (size_t)candidate_gradient.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)candidate_gradient.GetNrows(); i++) {
         candidate_gradient(i, j) = 1 - candidate_gradient(i, j);
      }
   }
   Hadamard(candidate_gradient, dc);
   Hadamard(candidate_gradient, state_gradients_backward);

   // calculating the state gradient backwards term by term
   // term 1
   TCpuMatrix<Scalar_t> temp(r, c);
   Copy(temp, state_gradients_backward);
   TCpuMatrix<Scalar_t> term(r, c); // B x H
   Copy(term, fUpdate);
   Hadamard(term, temp);
   Copy(state_gradients_backward, term);

   // term 2
   Copy(term, precStateActivations);
   Hadamard(term, du);
   Hadamard(term, temp);
   TCpuMatrix<Scalar_t> var(r, c);
   Multiply(var, term, weights_update_state);
   Copy(term, var);
   ScaleAdd(state_gradients_backward, term);

   // term 3
   Copy(term, fCandidate);
   for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
         term(i, j) = -term(i, j);
      }
   }
   Hadamard(term, du);
   Hadamard(term, temp);
   Multiply(var, term, weights_update_state);
   Copy(term, var);
   ScaleAdd(state_gradients_backward, term);

   // term 4
   Copy(term, fUpdate);
   for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
         term(i, j) = 1 - term(i, j);
      }
   }
   Hadamard(term, dc);
   Hadamard(term, temp);

   if (!resetGateAfter) {
      // case resetGateAfter is false: U * (r * h)
      // dh = r * (UT * dy)
      Multiply(var, term, weights_candidate_state);
      Hadamard(var, fReset);
   } else {
      // case resetGateAfter = true
      // dh = UT * (r * dy)
      Hadamard(term, fReset);
      Multiply(var, term, weights_candidate_state);
   }
   Copy(term, var);
   ScaleAdd(state_gradients_backward, term);

   // term 5
   Copy(term, fUpdate);
   for (size_t j = 0; j < (size_t)term.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t)term.GetNrows(); i++) {
         term(i, j) = 1 - term(i, j);
      }
   }
   // here we re-compute dr (we could probably be more efficient)
   Hadamard(term, dc);
   Hadamard(term, temp);
   if (!resetGateAfter) {
      // case resetGateAfter = false
      // recompute dr/dh (as above for dr): dr = h * (UT * dy)
      Multiply(var, term, weights_candidate_state);
      Hadamard(var, precStateActivations);
   } else {
      // case resetGateAfter = true: dr = dy * (U * h)
      MultiplyTranspose(var, precStateActivations, weights_candidate_state);
      Hadamard(var, term);
   }
   Hadamard(var, dr);
   Multiply(term, var, weights_reset_state);
   ScaleAdd(state_gradients_backward, term);

   // input gradients
   TCpuMatrix<Scalar_t> tmpInp(input_gradient.GetNrows(), input_gradient.GetNcols());
   Multiply(tmpInp, reset_gradient, weights_reset);
   Copy(input_gradient, tmpInp);
   Multiply(tmpInp, update_gradient, weights_update);
   ScaleAdd(input_gradient, tmpInp);
   Multiply(tmpInp, candidate_gradient, weights_candidate);
   ScaleAdd(input_gradient, tmpInp);

   // input weight gradients
   TransposeMultiply(reset_weight_gradients, reset_gradient, input, 1., 1.); // H x B . B x D
   TransposeMultiply(update_weight_gradients, update_gradient, input, 1., 1.);
   TransposeMultiply(candidate_weight_gradients, candidate_gradient, input, 1., 1.);

   // state weight gradients
   TransposeMultiply(reset_state_weight_gradients, reset_gradient, precStateActivations, 1., 1.); // H x B . B x H
   TransposeMultiply(update_state_weight_gradients, update_gradient, precStateActivations, 1., 1.);
   TCpuMatrix<Scalar_t> tempvar(r, c);

   // candidate weight gradients
   // implementation for the case resetGateAfter = false
   if (!resetGateAfter) {
      // dU = ( h * r) * dy
      Copy(tempvar, precStateActivations);
      Hadamard(tempvar, fReset);
      TransposeMultiply(candidate_state_weight_gradients, candidate_gradient, tempvar, 1., 1.);
   } else {
      // case resetGateAfter = true
      // dU = h * ( r * dy)
      Copy(tempvar, candidate_gradient);
      Hadamard(tempvar, fReset);
      TransposeMultiply(candidate_state_weight_gradients, tempvar, precStateActivations, 1., 1.);
   }

   // bias gradients
   SumColumns(reset_bias_gradients, reset_gradient, 1., 1.); // could probably do it all here
   SumColumns(update_bias_gradients, update_gradient, 1., 1.);
   SumColumns(candidate_bias_gradients, candidate_gradient, 1., 1.);

   return input_gradient;
}

} // namespace DNN
} // namespace TMVA