Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
LSTMLayer.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn/lstm:$Id$
2// Author: Surya S Dwivedi 27/05/19
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : BasicLSTMLayer *
8 * *
9 * Description: *
10 * NeuralNetwork *
11 * *
12 * Authors (alphabetical): *
13 * Surya S Dwivedi <surya2191997@gmail.com> - IIT Kharagpur, India *
14 * *
15 * Copyright (c) 2005-2019: *
16 * All rights reserved. *
17 * CERN, Switzerland *
18 * *
19 * For the licensing terms see $ROOTSYS/LICENSE. *
20 * For the list of contributors see $ROOTSYS/README/CREDITS. *
21 **********************************************************************************/
22
23//#pragma once
24
25//////////////////////////////////////////////////////////////////////
26// This class implements the LSTM layer. LSTM is a variant of vanilla
27// RNN which is capable of learning long range dependencies.
28//////////////////////////////////////////////////////////////////////
29
30#ifndef TMVA_DNN_LSTM_LAYER
31#define TMVA_DNN_LSTM_LAYER
32
33#include <cmath>
34#include <iostream>
35#include <vector>
36
37#include "TMatrix.h"
38#include "TMVA/DNN/Functions.h"
39
40namespace TMVA
41{
42namespace DNN
43{
44namespace RNN
45{
46
47//______________________________________________________________________________
48//
49// Basic LSTM Layer
50//______________________________________________________________________________
51
52/** \class TBasicLSTMLayer
53 Generic implementation of an LSTM layer.
 On the CPU path, Forward() evaluates the input, forget, candidate and output
 gates one time step at a time and caches their values and activation
 derivatives for Backward(); when Architecture_t::IsCudnn() is true the whole
 sequence is delegated to Architecture_t::RNNForward / RNNBackward.
 The eight weight matrices are owned by the base class: GetWeightsAt(0..3) are
 the input weights and GetWeightsAt(4..7) the recurrent (state) weights of the
 input, forget, candidate and output gates, in that order; GetBiasesAt(0..3)
 are the matching biases.
54*/
55template<typename Architecture_t>
56 class TBasicLSTMLayer : public VGeneralLayer<Architecture_t>
57{
58
59public:
60
61 using Matrix_t = typename Architecture_t::Matrix_t;
62 using Scalar_t = typename Architecture_t::Scalar_t;
63 using Tensor_t = typename Architecture_t::Tensor_t;
64
65 using LayerDescriptor_t = typename Architecture_t::RecurrentDescriptor_t;
66 using WeightsDescriptor_t = typename Architecture_t::FilterDescriptor_t;
67 using TensorDescriptor_t = typename Architecture_t::TensorDescriptor_t;
68 using HelperDescriptor_t = typename Architecture_t::DropoutDescriptor_t;
69
70 using RNNWorkspace_t = typename Architecture_t::RNNWorkspace_t;
71 using RNNDescriptors_t = typename Architecture_t::RNNDescriptors_t;
72
73private:
74
75 size_t fStateSize; ///< Hidden state size for LSTM
76 size_t fCellSize; ///< Cell state size of LSTM (set equal to fStateSize by the constructor)
77 size_t fTimeSteps; ///< Timesteps for LSTM
78
79 bool fRememberState; ///< Remember state in next pass (if false, Forward() starts from zero states)
80 bool fReturnSequence = false; ///< Return in output full sequence or just last element
81
82 DNN::EActivationFunction fF1; ///< Activation of the input, forget and output gates (sigmoid)
83 DNN::EActivationFunction fF2; ///< Activation of the candidate and of the cell state in CellForward (tanh)
84
85 Matrix_t fInputValue; ///< Computed input gate values
86 Matrix_t fCandidateValue; ///< Computed candidate values
87 Matrix_t fForgetValue; ///< Computed forget gate values
88 Matrix_t fOutputValue; ///< Computed output gate values
89 Matrix_t fState; ///< Hidden state of LSTM
90 Matrix_t fCell; ///< Cell state of LSTM
91
 // References into the base-class weights/biases; indices follow the constructor's bindings.
92 Matrix_t &fWeightsInputGate; ///< Input Gate weights for input, GetWeightsAt(0)
93 Matrix_t &fWeightsInputGateState; ///< Input Gate weights for prev state, GetWeightsAt(4)
94 Matrix_t &fInputGateBias; ///< Input Gate bias, GetBiasesAt(0)
95
96 Matrix_t &fWeightsForgetGate; ///< Forget Gate weights for input, GetWeightsAt(1)
97 Matrix_t &fWeightsForgetGateState; ///< Forget Gate weights for prev state, GetWeightsAt(5)
98 Matrix_t &fForgetGateBias; ///< Forget Gate bias, GetBiasesAt(1)
99
100 Matrix_t &fWeightsCandidate; ///< Candidate Gate weights for input, GetWeightsAt(2)
101 Matrix_t &fWeightsCandidateState; ///< Candidate Gate weights for prev state, GetWeightsAt(6)
102 Matrix_t &fCandidateBias; ///< Candidate Gate bias, GetBiasesAt(2)
103
104 Matrix_t &fWeightsOutputGate; ///< Output Gate weights for input, GetWeightsAt(3)
105 Matrix_t &fWeightsOutputGateState; ///< Output Gate weights for prev state, GetWeightsAt(7)
106 Matrix_t &fOutputGateBias; ///< Output Gate bias, GetBiasesAt(3)
107
 // Per-time-step caches: fTimeSteps entries of batchSize x stateSize, filled during Forward().
108 std::vector<Matrix_t> input_gate_value; ///< input gate value for every time step
109 std::vector<Matrix_t> forget_gate_value; ///< forget gate value for every time step
110 std::vector<Matrix_t> candidate_gate_value; ///< candidate gate value for every time step
111 std::vector<Matrix_t> output_gate_value; ///< output gate value for every time step
112 std::vector<Matrix_t> cell_value; ///< cell value for every time step
113 std::vector<Matrix_t> fDerivativesInput; ///< First fDerivatives of the activations input gate
114 std::vector<Matrix_t> fDerivativesForget; ///< First fDerivatives of the activations forget gate
115 std::vector<Matrix_t> fDerivativesCandidate; ///< First fDerivatives of the activations candidate gate
116 std::vector<Matrix_t> fDerivativesOutput; ///< First fDerivatives of the activations output gate
117
 // References into the base-class weight/bias gradients, same index layout as the weights above.
118 Matrix_t &fWeightsInputGradients; ///< Gradients w.r.t the input gate - input weights
119 Matrix_t &fWeightsInputStateGradients; ///< Gradients w.r.t the input gate - hidden state weights
120 Matrix_t &fInputBiasGradients; ///< Gradients w.r.t the input gate - bias weights
121 Matrix_t &fWeightsForgetGradients; ///< Gradients w.r.t the forget gate - input weights
122 Matrix_t &fWeightsForgetStateGradients; ///< Gradients w.r.t the forget gate - hidden state weights
123 Matrix_t &fForgetBiasGradients; ///< Gradients w.r.t the forget gate - bias weights
124 Matrix_t &fWeightsCandidateGradients; ///< Gradients w.r.t the candidate gate - input weights
125 Matrix_t &fWeightsCandidateStateGradients; ///< Gradients w.r.t the candidate gate - hidden state weights
126 Matrix_t &fCandidateBiasGradients; ///< Gradients w.r.t the candidate gate - bias weights
127 Matrix_t &fWeightsOutputGradients; ///< Gradients w.r.t the output gate - input weights
128 Matrix_t &fWeightsOutputStateGradients; ///< Gradients w.r.t the output gate - hidden state weights
129 Matrix_t &fOutputBiasGradients; ///< Gradients w.r.t the output gate - bias weights
130
131 // Tensor representing all weights (used by cuDNN)
132 Tensor_t fWeightsTensor; ///< Tensor for all weights
133 Tensor_t fWeightGradientsTensor; ///< Tensor for all weight gradients
134
135 // tensors used internally for the forward and backward pass
136 Tensor_t fX; ///< cached input tensor as T x B x I
137 Tensor_t fY; ///< cached output tensor as T x B x S
138 Tensor_t fDx; ///< cached gradient on the input (output of backward) as T x B x I
139 Tensor_t fDy; ///< cached activation gradient (input of backward) as T x B x S
140
 // NOTE(review): fDescriptors / fWorkspace are raw pointers filled by Initialize()
 // (via Architecture_t::InitializeLSTMDescriptors / InitializeLSTMWorkspace); no code
 // releasing them is visible in this listing -- confirm ownership and cleanup.
141 TDescriptors *fDescriptors = nullptr; ///< Keeps all the RNN descriptors
142 TWorkspace *fWorkspace = nullptr; ///< Workspace needed for GPU computation (cuDNN)
143
144public:
145
146 /*! Constructor */
147 TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState = false,
148 bool returnSequence = false,
 // NOTE(review): the remaining parameters (f1, f2, training flag, initialization) are
 // missing from this listing; see the out-of-class constructor definition.
152
153 /*! Copy Constructor */
155
156 /*! Initialize the weights according to the given initialization
157 ** method, then set up the backend RNN descriptors and workspace. */
158 virtual void Initialize();
159
160 /*! Initialize the hidden state and cell state method. */
162
163 /*! Computes the next hidden state
164 * and next cell state with given input matrix. */
165 void Forward(Tensor_t &input, bool isTraining = true);
166
167 /*! Forward for a single cell (time unit): updates fCell and fState in place.
 * Note: inputGateValues is overwritten with the element-wise product input_gate * candidate. */
168 void CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues,
169 const Matrix_t &candidateValues, const Matrix_t &outputGateValues);
170
171 /*! Backpropagates the error. Must only be called directly at the corresponding
172 * call to Forward(...). */
173 void Backward(Tensor_t &gradients_backward,
174 const Tensor_t &activations_backward);
175
176 /*! Updates weights and biases, given the learning rate */
177 void Update(const Scalar_t learningRate);
178
179 /*! Backward for a single time unit, matching
180 * the corresponding call to Forward(...). */
181 Matrix_t & CellBackward(Matrix_t & state_gradients_backward,
182 Matrix_t & cell_gradients_backward,
183 const Matrix_t & precStateActivations, const Matrix_t & precCellActivations,
184 const Matrix_t & input_gate, const Matrix_t & forget_gate,
185 const Matrix_t & candidate_gate, const Matrix_t & output_gate,
186 const Matrix_t & input, Matrix_t & input_gradient,
187 Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t);
188
189 /*! Decides the values we'll update (NN with Sigmoid); writes the activation derivative into di */
190 void InputGate(const Matrix_t &input, Matrix_t &di);
191
192 /*! Forgets the past values (NN with Sigmoid); writes the activation derivative into df */
193 void ForgetGate(const Matrix_t &input, Matrix_t &df);
194
195 /*! Decides the new candidate values (NN with Tanh); writes the activation derivative into dc */
196 void CandidateValue(const Matrix_t &input, Matrix_t &dc);
197
198 /*! Computes output values (NN with Sigmoid); writes the activation derivative into dout */
199 void OutputGate(const Matrix_t &input, Matrix_t &dout);
200
201 /*! Prints the info about the layer */
202 void Print() const;
203
204 /*! Writes the information and the weights about the layer in an XML node. */
205 void AddWeightsXMLTo(void *parent);
206
207 /*! Read the information and the weights about the layer from XML node. */
208 void ReadWeightsFromXML(void *parent);
209
210 /*! Getters */
211 size_t GetInputSize() const { return this->GetInputWidth(); }
212 size_t GetTimeSteps() const { return fTimeSteps; }
213 size_t GetStateSize() const { return fStateSize; }
214 size_t GetCellSize() const { return fCellSize; }
215
216 inline bool DoesRememberState() const { return fRememberState; }
217 inline bool DoesReturnSequence() const { return fReturnSequence; }
218
221
222 const Matrix_t & GetInputGateValue() const { return fInputValue; }
224 const Matrix_t & GetCandidateValue() const { return fCandidateValue; }
226 const Matrix_t & GetForgetGateValue() const { return fForgetValue; }
228 const Matrix_t & GetOutputGateValue() const { return fOutputValue; }
230
231 const Matrix_t & GetState() const { return fState; }
232 Matrix_t & GetState() { return fState; }
233 const Matrix_t & GetCell() const { return fCell; }
234 Matrix_t & GetCell() { return fCell; }
235
252
 // Per-time-step activation derivatives, written by the gate functions during Forward().
253 const std::vector<Matrix_t> & GetDerivativesInput() const { return fDerivativesInput; }
254 std::vector<Matrix_t> & GetDerivativesInput() { return fDerivativesInput; }
255 const Matrix_t & GetInputDerivativesAt(size_t i) const { return fDerivativesInput[i]; }
257 const std::vector<Matrix_t> & GetDerivativesForget() const { return fDerivativesForget; }
258 std::vector<Matrix_t> & GetDerivativesForget() { return fDerivativesForget; }
259 const Matrix_t & GetForgetDerivativesAt(size_t i) const { return fDerivativesForget[i]; }
261 const std::vector<Matrix_t> & GetDerivativesCandidate() const { return fDerivativesCandidate; }
262 std::vector<Matrix_t> & GetDerivativesCandidate() { return fDerivativesCandidate; }
263 const Matrix_t & GetCandidateDerivativesAt(size_t i) const { return fDerivativesCandidate[i]; }
265 const std::vector<Matrix_t> & GetDerivativesOutput() const { return fDerivativesOutput; }
266 std::vector<Matrix_t> & GetDerivativesOutput() { return fDerivativesOutput; }
267 const Matrix_t & GetOutputDerivativesAt(size_t i) const { return fDerivativesOutput[i]; }
269
 // Per-time-step gate and cell values cached by Forward(); despite the "Tensor" name these
 // are std::vector<Matrix_t> with one entry per time step.
270 const std::vector<Matrix_t> & GetInputGateTensor() const { return input_gate_value; }
271 std::vector<Matrix_t> & GetInputGateTensor() { return input_gate_value; }
272 const Matrix_t & GetInputGateTensorAt(size_t i) const { return input_gate_value[i]; }
274 const std::vector<Matrix_t> & GetForgetGateTensor() const { return forget_gate_value; }
275 std::vector<Matrix_t> & GetForgetGateTensor() { return forget_gate_value; }
276 const Matrix_t & GetForgetGateTensorAt(size_t i) const { return forget_gate_value[i]; }
278 const std::vector<Matrix_t> & GetCandidateGateTensor() const { return candidate_gate_value; }
279 std::vector<Matrix_t> & GetCandidateGateTensor() { return candidate_gate_value; }
280 const Matrix_t & GetCandidateGateTensorAt(size_t i) const { return candidate_gate_value[i]; }
282 const std::vector<Matrix_t> & GetOutputGateTensor() const { return output_gate_value; }
283 std::vector<Matrix_t> & GetOutputGateTensor() { return output_gate_value; }
284 const Matrix_t & GetOutputGateTensorAt(size_t i) const { return output_gate_value[i]; }
286 const std::vector<Matrix_t> & GetCellTensor() const { return cell_value; }
287 std::vector<Matrix_t> & GetCellTensor() { return cell_value; }
288 const Matrix_t & GetCellTensorAt(size_t i) const { return cell_value[i]; }
289 Matrix_t & GetCellTensorAt(size_t i) { return cell_value[i]; }
290
291 const Matrix_t & GetInputGateBias() const { return fInputGateBias; }
293 const Matrix_t & GetForgetGateBias() const { return fForgetGateBias; }
295 const Matrix_t & GetCandidateBias() const { return fCandidateBias; }
297 const Matrix_t & GetOutputGateBias() const { return fOutputGateBias; }
323
 // Packed weight tensor used by the cuDNN path.
325 const Tensor_t &GetWeightsTensor() const { return fWeightsTensor; }
328
 // Cached sequence tensors (input, output and their gradients) used by Forward()/Backward().
329 Tensor_t &GetX() { return fX; }
330 Tensor_t &GetY() { return fY; }
331 Tensor_t &GetDX() { return fDx; }
332 Tensor_t &GetDY() { return fDy; }
333};
334
335//______________________________________________________________________________
336//
337// Basic LSTM-Layer Implementation
338//______________________________________________________________________________
339
340template <typename Architecture_t>
341TBasicLSTMLayer<Architecture_t>::TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
342 bool rememberState, bool returnSequence, DNN::EActivationFunction f1,
343 DNN::EActivationFunction f2, bool /* training */,
345 : VGeneralLayer<Architecture_t>(
346 batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize, 8,
347 {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
348 {inputSize, inputSize, inputSize, inputSize, stateSize, stateSize, stateSize, stateSize}, 4,
349 {stateSize, stateSize, stateSize, stateSize}, {1, 1, 1, 1}, batchSize, (returnSequence) ? timeSteps : 1,
350 stateSize, fA),
351 fStateSize(stateSize), fCellSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState),
352 fReturnSequence(returnSequence), fF1(f1), fF2(f2), fInputValue(batchSize, stateSize),
353 fCandidateValue(batchSize, stateSize), fForgetValue(batchSize, stateSize), fOutputValue(batchSize, stateSize),
354 fState(batchSize, stateSize), fCell(batchSize, stateSize), fWeightsInputGate(this->GetWeightsAt(0)),
355 fWeightsInputGateState(this->GetWeightsAt(4)), fInputGateBias(this->GetBiasesAt(0)),
356 fWeightsForgetGate(this->GetWeightsAt(1)), fWeightsForgetGateState(this->GetWeightsAt(5)),
357 fForgetGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
358 fWeightsCandidateState(this->GetWeightsAt(6)), fCandidateBias(this->GetBiasesAt(2)),
359 fWeightsOutputGate(this->GetWeightsAt(3)), fWeightsOutputGateState(this->GetWeightsAt(7)),
360 fOutputGateBias(this->GetBiasesAt(3)), fWeightsInputGradients(this->GetWeightGradientsAt(0)),
361 fWeightsInputStateGradients(this->GetWeightGradientsAt(4)), fInputBiasGradients(this->GetBiasGradientsAt(0)),
362 fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
363 fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)), fForgetBiasGradients(this->GetBiasGradientsAt(1)),
364 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
365 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
366 fCandidateBiasGradients(this->GetBiasGradientsAt(2)), fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
367 fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)), fOutputBiasGradients(this->GetBiasGradientsAt(3))
368{
369 for (size_t i = 0; i < timeSteps; ++i) {
370 fDerivativesInput.emplace_back(batchSize, stateSize);
371 fDerivativesForget.emplace_back(batchSize, stateSize);
372 fDerivativesCandidate.emplace_back(batchSize, stateSize);
373 fDerivativesOutput.emplace_back(batchSize, stateSize);
374 input_gate_value.emplace_back(batchSize, stateSize);
375 forget_gate_value.emplace_back(batchSize, stateSize);
376 candidate_gate_value.emplace_back(batchSize, stateSize);
377 output_gate_value.emplace_back(batchSize, stateSize);
378 cell_value.emplace_back(batchSize, stateSize);
379 }
380 Architecture_t::InitializeLSTMTensors(this);
381}
382
383 //______________________________________________________________________________
384template <typename Architecture_t>
386 : VGeneralLayer<Architecture_t>(layer),
387 fStateSize(layer.fStateSize),
388 fCellSize(layer.fCellSize),
389 fTimeSteps(layer.fTimeSteps),
390 fRememberState(layer.fRememberState),
391 fReturnSequence(layer.fReturnSequence),
392 fF1(layer.GetActivationFunctionF1()),
393 fF2(layer.GetActivationFunctionF2()),
394 fInputValue(layer.GetBatchSize(), layer.GetStateSize()),
395 fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
396 fForgetValue(layer.GetBatchSize(), layer.GetStateSize()),
397 fOutputValue(layer.GetBatchSize(), layer.GetStateSize()),
398 fState(layer.GetBatchSize(), layer.GetStateSize()),
399 fCell(layer.GetBatchSize(), layer.GetCellSize()),
400 fWeightsInputGate(this->GetWeightsAt(0)),
401 fWeightsInputGateState(this->GetWeightsAt(4)),
402 fInputGateBias(this->GetBiasesAt(0)),
403 fWeightsForgetGate(this->GetWeightsAt(1)),
404 fWeightsForgetGateState(this->GetWeightsAt(5)),
405 fForgetGateBias(this->GetBiasesAt(1)),
406 fWeightsCandidate(this->GetWeightsAt(2)),
407 fWeightsCandidateState(this->GetWeightsAt(6)),
408 fCandidateBias(this->GetBiasesAt(2)),
409 fWeightsOutputGate(this->GetWeightsAt(3)),
410 fWeightsOutputGateState(this->GetWeightsAt(7)),
411 fOutputGateBias(this->GetBiasesAt(3)),
412 fWeightsInputGradients(this->GetWeightGradientsAt(0)),
413 fWeightsInputStateGradients(this->GetWeightGradientsAt(4)),
414 fInputBiasGradients(this->GetBiasGradientsAt(0)),
415 fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
416 fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)),
417 fForgetBiasGradients(this->GetBiasGradientsAt(1)),
418 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
419 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
420 fCandidateBiasGradients(this->GetBiasGradientsAt(2)),
421 fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
422 fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)),
423 fOutputBiasGradients(this->GetBiasGradientsAt(3))
424{
425 for (size_t i = 0; i < fTimeSteps; ++i) {
426 fDerivativesInput.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
427 Architecture_t::Copy(fDerivativesInput[i], layer.GetInputDerivativesAt(i));
428
429 fDerivativesForget.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
430 Architecture_t::Copy(fDerivativesForget[i], layer.GetForgetDerivativesAt(i));
431
432 fDerivativesCandidate.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
433 Architecture_t::Copy(fDerivativesCandidate[i], layer.GetCandidateDerivativesAt(i));
434
435 fDerivativesOutput.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
436 Architecture_t::Copy(fDerivativesOutput[i], layer.GetOutputDerivativesAt(i));
437
438 input_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
439 Architecture_t::Copy(input_gate_value[i], layer.GetInputGateTensorAt(i));
440
441 forget_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
442 Architecture_t::Copy(forget_gate_value[i], layer.GetForgetGateTensorAt(i));
443
444 candidate_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
445 Architecture_t::Copy(candidate_gate_value[i], layer.GetCandidateGateTensorAt(i));
446
447 output_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
448 Architecture_t::Copy(output_gate_value[i], layer.GetOutputGateTensorAt(i));
449
450 cell_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
451 Architecture_t::Copy(cell_value[i], layer.GetCellTensorAt(i));
452 }
453
454 // Gradient matrices not copied
455 Architecture_t::Copy(fState, layer.GetState());
456 Architecture_t::Copy(fCell, layer.GetCell());
457
458 // Copy each gate values.
459 Architecture_t::Copy(fInputValue, layer.GetInputGateValue());
460 Architecture_t::Copy(fCandidateValue, layer.GetCandidateValue());
461 Architecture_t::Copy(fForgetValue, layer.GetForgetGateValue());
462 Architecture_t::Copy(fOutputValue, layer.GetOutputGateValue());
463
464 Architecture_t::InitializeLSTMTensors(this);
465}
466
467//______________________________________________________________________________
468template <typename Architecture_t>
470{
472
473 Architecture_t::InitializeLSTMDescriptors(fDescriptors, this);
474 Architecture_t::InitializeLSTMWorkspace(fWorkspace, fDescriptors, this);
475}
476
477//______________________________________________________________________________
478template <typename Architecture_t>
480-> void
481{
482 /*! Computes input gate values according to equation:
483 * input = act(W_input . input + W_state . state + bias)
484 * activation function: sigmoid. */
485 const DNN::EActivationFunction fInp = this->GetActivationFunctionF1();
486 Matrix_t tmpState(fInputValue.GetNrows(), fInputValue.GetNcols());
487 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsInputGateState);
488 Architecture_t::MultiplyTranspose(fInputValue, input, fWeightsInputGate);
489 Architecture_t::ScaleAdd(fInputValue, tmpState);
490 Architecture_t::AddRowWise(fInputValue, fInputGateBias);
491 DNN::evaluateDerivativeMatrix<Architecture_t>(di, fInp, fInputValue);
492 DNN::evaluateMatrix<Architecture_t>(fInputValue, fInp);
493}
494
495 //______________________________________________________________________________
496template <typename Architecture_t>
498-> void
499{
500 /*! Computes forget gate values according to equation:
501 * forget = act(W_input . input + W_state . state + bias)
502 * activation function: sigmoid. */
503 const DNN::EActivationFunction fFor = this->GetActivationFunctionF1();
504 Matrix_t tmpState(fForgetValue.GetNrows(), fForgetValue.GetNcols());
505 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsForgetGateState);
506 Architecture_t::MultiplyTranspose(fForgetValue, input, fWeightsForgetGate);
507 Architecture_t::ScaleAdd(fForgetValue, tmpState);
508 Architecture_t::AddRowWise(fForgetValue, fForgetGateBias);
509 DNN::evaluateDerivativeMatrix<Architecture_t>(df, fFor, fForgetValue);
510 DNN::evaluateMatrix<Architecture_t>(fForgetValue, fFor);
511}
512
513 //______________________________________________________________________________
514template <typename Architecture_t>
516-> void
517{
518 /*! Candidate value will be used to scale input gate values followed by Hadamard product.
519 * candidate_value = act(W_input . input + W_state . state + bias)
520 * activation function = tanh. */
521 const DNN::EActivationFunction fCan = this->GetActivationFunctionF2();
522 Matrix_t tmpState(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
523 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsCandidateState);
524 Architecture_t::MultiplyTranspose(fCandidateValue, input, fWeightsCandidate);
525 Architecture_t::ScaleAdd(fCandidateValue, tmpState);
526 Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
527 DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
528 DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
529}
530
531 //______________________________________________________________________________
532template <typename Architecture_t>
534-> void
535{
536 /*! Output gate values will be used to calculate next hidden state and output values.
537 * output = act(W_input . input + W_state . state + bias)
538 * activation function = sigmoid. */
539 const DNN::EActivationFunction fOut = this->GetActivationFunctionF1();
540 Matrix_t tmpState(fOutputValue.GetNrows(), fOutputValue.GetNcols());
541 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsOutputGateState);
542 Architecture_t::MultiplyTranspose(fOutputValue, input, fWeightsOutputGate);
543 Architecture_t::ScaleAdd(fOutputValue, tmpState);
544 Architecture_t::AddRowWise(fOutputValue, fOutputGateBias);
545 DNN::evaluateDerivativeMatrix<Architecture_t>(dout, fOut, fOutputValue);
546 DNN::evaluateMatrix<Architecture_t>(fOutputValue, fOut);
547}
548
549
550
551 //______________________________________________________________________________
552template <typename Architecture_t>
553auto inline TBasicLSTMLayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining )
554-> void
555{
556
557 // for Cudnn
558 if (Architecture_t::IsCudnn()) {
559
560 // input size is stride[1] of input tensor that is B x T x inputSize
561 assert(input.GetStrides()[1] == this->GetInputSize());
562
563 Tensor_t &x = this->fX;
564 Tensor_t &y = this->fY;
565 Architecture_t::Rearrange(x, input);
566
567 const auto &weights = this->GetWeightsAt(0);
568 // Tensor_t cx({1}); // not used for normal RNN
569 // Tensor_t cy({1}); // not used for normal RNN
570
571 // hx is fState - tensor are of right shape
572 auto &hx = this->fState;
573 //auto &cx = this->fCell;
574 auto &cx = this->fCell; // pass an empty cell state
575 // use same for hy and cy
576 auto &hy = this->fState;
577 auto &cy = this->fCell;
578
579 auto rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
580 auto rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);
581
582 Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);
583
584 if (fReturnSequence) {
585 Architecture_t::Rearrange(this->GetOutput(), y); // swap B and T from y to Output
586 } else {
587 // tmp is a reference to y (full cudnn output)
588 Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
589 Architecture_t::Copy(this->GetOutput(), tmp);
590 }
591
592 return;
593 }
594
595 // Standard CPU implementation
596
597 // D : input size
598 // H : state size
599 // T : time size
600 // B : batch size
601
602 Tensor_t arrInput( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
603 //Tensor_t &arrInput = this->GetX();
604
605 Architecture_t::Rearrange(arrInput, input); // B x T x D
606
607 Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize);
608
609
610 if (!this->fRememberState) {
612 }
613
614 /*! Pass each gate values to CellForward() to calculate
615 * next hidden state and next cell state. */
616 for (size_t t = 0; t < fTimeSteps; ++t) {
617 /* Feed forward network: value of each gate being computed at each timestep t. */
618 Matrix_t arrInputMt = arrInput[t];
619 InputGate(arrInputMt, fDerivativesInput[t]);
620 ForgetGate(arrInputMt, fDerivativesForget[t]);
621 CandidateValue(arrInputMt, fDerivativesCandidate[t]);
622 OutputGate(arrInputMt, fDerivativesOutput[t]);
623
624 Architecture_t::Copy(this->GetInputGateTensorAt(t), fInputValue);
625 Architecture_t::Copy(this->GetForgetGateTensorAt(t), fForgetValue);
626 Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
627 Architecture_t::Copy(this->GetOutputGateTensorAt(t), fOutputValue);
628
629 CellForward(fInputValue, fForgetValue, fCandidateValue, fOutputValue);
630 Matrix_t arrOutputMt = arrOutput[t];
631 Architecture_t::Copy(arrOutputMt, fState);
632 Architecture_t::Copy(this->GetCellTensorAt(t), fCell);
633 }
634
635 // check if full output needs to be returned
636 if (fReturnSequence)
637 Architecture_t::Rearrange(this->GetOutput(), arrOutput); // B x T x D
638 else {
639 // get T[end[]]
640 Tensor_t tmp = arrOutput.At(fTimeSteps - 1); // take last time step
641 // shape of tmp is for CPU (columnwise) B x D , need to reshape to make a B x D x 1
642 // and transpose it to 1 x D x B (this is how output is expected in columnmajor format)
643 tmp = tmp.Reshape( {tmp.GetShape()[0], tmp.GetShape()[1], 1});
644 assert(tmp.GetSize() == this->GetOutput().GetSize());
645 assert( tmp.GetShape()[0] == this->GetOutput().GetShape()[2]); // B is last dim in output and first in tmp
646 Architecture_t::Rearrange(this->GetOutput(), tmp);
647 // keep array output
648 fY = arrOutput;
649 }
650}
651
652 //______________________________________________________________________________
653template <typename Architecture_t>
654auto inline TBasicLSTMLayer<Architecture_t>::CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues,
655 const Matrix_t &candidateValues, const Matrix_t &outputGateValues)
656-> void
657{
658
659 // Update cell state.
660 Architecture_t::Hadamard(fCell, forgetGateValues);
661 Architecture_t::Hadamard(inputGateValues, candidateValues);
662 Architecture_t::ScaleAdd(fCell, inputGateValues);
663
664 Matrix_t cache(fCell.GetNrows(), fCell.GetNcols());
665 Architecture_t::Copy(cache, fCell);
666
667 // Update hidden state.
668 const DNN::EActivationFunction fAT = this->GetActivationFunctionF2();
669 DNN::evaluateMatrix<Architecture_t>(cache, fAT);
670
671 /*! The Hadamard product of output_gate_value . tanh(cell_state)
672 * will be copied to next hidden state (passed to next LSTM cell)
673 * and we will update our outputGateValues also. */
674 Architecture_t::Copy(fState, cache);
675 Architecture_t::Hadamard(fState, outputGateValues);
676}
677
678 //____________________________________________________________________________
679template <typename Architecture_t>
680auto inline TBasicLSTMLayer<Architecture_t>::Backward(Tensor_t &gradients_backward, // B x T x D
681 const Tensor_t &activations_backward) // B x T x D
682-> void
683{
684
685 // BACKWARD for CUDNN
686 if (Architecture_t::IsCudnn()) {
687
688 Tensor_t &x = this->fX;
689 Tensor_t &y = this->fY;
690 Tensor_t &dx = this->fDx;
691 Tensor_t &dy = this->fDy;
692
693 // input size is stride[1] of input tensor that is B x T x inputSize
694 assert(activations_backward.GetStrides()[1] == this->GetInputSize());
695
696 Architecture_t::Rearrange(x, activations_backward);
697
698 if (!fReturnSequence) {
699
700 // Architecture_t::InitializeZero(dy);
701 Architecture_t::InitializeZero(dy);
702
703 // Tensor_t tmp1 = y.At(y.GetShape()[0] - 1).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
704 // dy is a tensor of shape (rowmajor for Cudnn): T x B x S
705 // and this->ActivatuonGradients is B x (T=1) x S
706 Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
707
708 // Architecture_t::Copy(tmp1, this->GetOutput());
709 Architecture_t::Copy(tmp2, this->GetActivationGradients());
710 } else {
711 Architecture_t::Rearrange(y, this->GetOutput());
712 Architecture_t::Rearrange(dy, this->GetActivationGradients());
713 }
714
715 // Architecture_t::PrintTensor(this->GetOutput(), "output before bwd");
716
717 // for cudnn Matrix_t and Tensor_t are same type
718 const auto &weights = this->GetWeightsTensor();
719 auto &weightGradients = this->GetWeightGradientsTensor();
720 // note that cudnnRNNBackwardWeights accumulate the weight gradients.
721 // We need then to initialize the tensor to zero every time
722 Architecture_t::InitializeZero(weightGradients);
723
724 // hx is fState
725 auto &hx = this->GetState();
726 auto &cx = this->GetCell();
727 //auto &cx = this->GetCell();
728 // use same for hy and cy
729 auto &dhy = hx;
730 auto &dcy = cx;
731 auto &dhx = hx;
732 auto &dcx = cx;
733
734 auto rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
735 auto rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);
736
737 Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
738
739 // Architecture_t::PrintTensor(this->GetOutput(), "output after bwd");
740
741 if (gradients_backward.GetSize() != 0)
742 Architecture_t::Rearrange(gradients_backward, dx);
743
744 return;
745 }
746 // CPU implementation
747
748 // gradients_backward is activationGradients of layer before it, which is input layer.
749 // Currently, gradients_backward is for input(x) and not for state.
750 // For the state it can be:
751 Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
752 DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero); // B x H
753
754
755 Matrix_t cell_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
756 DNN::initialize<Architecture_t>(cell_gradients_backward, DNN::EInitialization::kZero); // B x H
757
758 // if dummy is false gradients_backward will be written back on the matrix
759 bool dummy = false;
760 if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
761 dummy = true;
762 }
763
764
765 Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
766
767
768 //Architecture_t::Rearrange(arr_gradients_backward, gradients_backward); // B x T x D
769 // activations_backward is input.
770 Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
771
772 Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D
773
774 /*! For backpropagation, we need to calculate loss. For loss, output must be known.
775 * We obtain outputs during forward propagation and place the results in arr_output tensor. */
776 Tensor_t arr_output ( fTimeSteps, this->GetBatchSize(), fStateSize);
777
778 Matrix_t initState(this->GetBatchSize(), fCellSize); // B x H
779 DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero); // B x H
780
781 // This will take partial derivative of state[t] w.r.t state[t-1]
782
783 Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);
784
785 if (fReturnSequence) {
786 Architecture_t::Rearrange(arr_output, this->GetOutput());
787 Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
788 } else {
789 // here for CPU need to transpose the input activatuon gradients into the right format
790 arr_output = fY;
791 Architecture_t::InitializeZero(arr_actgradients);
792 // need to reshape to pad a time dimension = 1 (note here is columnmajor tensors)
793 Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape( {this->GetBatchSize(), fStateSize, 1});
794 assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
795 assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]); // B in tmp is [0] and [2] in input act. gradients
796
797 Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
798 }
799
800 /*! There are total 8 different weight matrices and 4 bias vectors.
801 * Re-initialize them with zero because it should have some value. (can't be garbage values) */
802
803 // Input Gate.
804 fWeightsInputGradients.Zero();
805 fWeightsInputStateGradients.Zero();
806 fInputBiasGradients.Zero();
807
808 // Forget Gate.
809 fWeightsForgetGradients.Zero();
810 fWeightsForgetStateGradients.Zero();
811 fForgetBiasGradients.Zero();
812
813 // Candidate Gate.
814 fWeightsCandidateGradients.Zero();
815 fWeightsCandidateStateGradients.Zero();
816 fCandidateBiasGradients.Zero();
817
818 // Output Gate.
819 fWeightsOutputGradients.Zero();
820 fWeightsOutputStateGradients.Zero();
821 fOutputBiasGradients.Zero();
822
823
824 for (size_t t = fTimeSteps; t > 0; t--) {
825 // Store the sum of gradients obtained at each timestep during backward pass.
826 Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
827 if (t > 1) {
828 const Matrix_t &prevStateActivations = arr_output[t-2];
829 const Matrix_t &prevCellActivations = this->GetCellTensorAt(t-2);
830 // During forward propagation, each gate value calculates their gradients.
831 Matrix_t dx = arr_gradients_backward[t-1];
832 CellBackward(state_gradients_backward, cell_gradients_backward,
833 prevStateActivations, prevCellActivations,
834 this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
835 this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
836 arr_activations_backward[t-1], dx,
837 fDerivativesInput[t-1], fDerivativesForget[t-1],
838 fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
839 } else {
840 const Matrix_t &prevStateActivations = initState;
841 const Matrix_t &prevCellActivations = initState;
842 Matrix_t dx = arr_gradients_backward[t-1];
843 CellBackward(state_gradients_backward, cell_gradients_backward,
844 prevStateActivations, prevCellActivations,
845 this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
846 this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
847 arr_activations_backward[t-1], dx,
848 fDerivativesInput[t-1], fDerivativesForget[t-1],
849 fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
850 }
851 }
852
853 if (!dummy) {
854 Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
855 }
856
857}
858
859
860 //______________________________________________________________________________
861template <typename Architecture_t>
862auto inline TBasicLSTMLayer<Architecture_t>::CellBackward(Matrix_t & state_gradients_backward,
863 Matrix_t & cell_gradients_backward,
864 const Matrix_t & precStateActivations, const Matrix_t & precCellActivations,
865 const Matrix_t & input_gate, const Matrix_t & forget_gate,
866 const Matrix_t & candidate_gate, const Matrix_t & output_gate,
867 const Matrix_t & input, Matrix_t & input_gradient,
868 Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout,
869 size_t t)
870-> Matrix_t &
871{
872 /*! Call here LSTMLayerBackward() to pass parameters i.e. gradient
873 * values obtained from each gate during forward propagation. */
874
875
876 // cell gradient for current time step
877 const DNN::EActivationFunction fAT = this->GetActivationFunctionF2();
878 Matrix_t cell_gradient(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
879 DNN::evaluateDerivativeMatrix<Architecture_t>(cell_gradient, fAT, this->GetCellTensorAt(t));
880
881 // cell tanh value for current time step
882 Matrix_t cell_tanh(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
883 Architecture_t::Copy(cell_tanh, this->GetCellTensorAt(t));
884 DNN::evaluateMatrix<Architecture_t>(cell_tanh, fAT);
885
886 return Architecture_t::LSTMLayerBackward(state_gradients_backward, cell_gradients_backward,
887 fWeightsInputGradients, fWeightsForgetGradients, fWeightsCandidateGradients,
888 fWeightsOutputGradients, fWeightsInputStateGradients, fWeightsForgetStateGradients,
889 fWeightsCandidateStateGradients, fWeightsOutputStateGradients, fInputBiasGradients, fForgetBiasGradients,
890 fCandidateBiasGradients, fOutputBiasGradients, di, df, dc, dout,
891 precStateActivations, precCellActivations,
892 input_gate, forget_gate, candidate_gate, output_gate,
893 fWeightsInputGate, fWeightsForgetGate, fWeightsCandidate, fWeightsOutputGate,
894 fWeightsInputGateState, fWeightsForgetGateState, fWeightsCandidateState,
895 fWeightsOutputGateState, input, input_gradient,
896 cell_gradient, cell_tanh);
897}
898
899 //______________________________________________________________________________
900template <typename Architecture_t>
902-> void
903{
904 DNN::initialize<Architecture_t>(this->GetState(), DNN::EInitialization::kZero);
905 DNN::initialize<Architecture_t>(this->GetCell(), DNN::EInitialization::kZero);
906}
907
908 //______________________________________________________________________________
909template<typename Architecture_t>
911-> void
912{
913 std::cout << " LSTM Layer: \t ";
914 std::cout << " (NInput = " << this->GetInputSize(); // input size
915 std::cout << ", NState = " << this->GetStateSize(); // hidden state size
916 std::cout << ", NTime = " << this->GetTimeSteps() << " )"; // time size
917 std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput()[0].GetNrows() << " , " << this->GetOutput()[0].GetNcols() << " )\n";
918}
919
920 //______________________________________________________________________________
921template <typename Architecture_t>
923-> void
924{
925 auto layerxml = gTools().xmlengine().NewChild(parent, 0, "LSTMLayer");
926
927 // Write all other info like outputSize, cellSize, inputSize, timeSteps, rememberState
928 gTools().xmlengine().NewAttr(layerxml, 0, "StateSize", gTools().StringFromInt(this->GetStateSize()));
929 gTools().xmlengine().NewAttr(layerxml, 0, "CellSize", gTools().StringFromInt(this->GetCellSize()));
930 gTools().xmlengine().NewAttr(layerxml, 0, "InputSize", gTools().StringFromInt(this->GetInputSize()));
931 gTools().xmlengine().NewAttr(layerxml, 0, "TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
932 gTools().xmlengine().NewAttr(layerxml, 0, "RememberState", gTools().StringFromInt(this->DoesRememberState()));
933 gTools().xmlengine().NewAttr(layerxml, 0, "ReturnSequence", gTools().StringFromInt(this->DoesReturnSequence()));
934
935 // write weights and bias matrices
936 this->WriteMatrixToXML(layerxml, "InputWeights", this->GetWeightsAt(0));
937 this->WriteMatrixToXML(layerxml, "InputStateWeights", this->GetWeightsAt(1));
938 this->WriteMatrixToXML(layerxml, "InputBiases", this->GetBiasesAt(0));
939 this->WriteMatrixToXML(layerxml, "ForgetWeights", this->GetWeightsAt(2));
940 this->WriteMatrixToXML(layerxml, "ForgetStateWeights", this->GetWeightsAt(3));
941 this->WriteMatrixToXML(layerxml, "ForgetBiases", this->GetBiasesAt(1));
942 this->WriteMatrixToXML(layerxml, "CandidateWeights", this->GetWeightsAt(4));
943 this->WriteMatrixToXML(layerxml, "CandidateStateWeights", this->GetWeightsAt(5));
944 this->WriteMatrixToXML(layerxml, "CandidateBiases", this->GetBiasesAt(2));
945 this->WriteMatrixToXML(layerxml, "OuputWeights", this->GetWeightsAt(6));
946 this->WriteMatrixToXML(layerxml, "OutputStateWeights", this->GetWeightsAt(7));
947 this->WriteMatrixToXML(layerxml, "OutputBiases", this->GetBiasesAt(3));
948}
949
950 //______________________________________________________________________________
951template <typename Architecture_t>
953-> void
954{
955 // Read weights and biases
956 this->ReadMatrixXML(parent, "InputWeights", this->GetWeightsAt(0));
957 this->ReadMatrixXML(parent, "InputStateWeights", this->GetWeightsAt(1));
958 this->ReadMatrixXML(parent, "InputBiases", this->GetBiasesAt(0));
959 this->ReadMatrixXML(parent, "ForgetWeights", this->GetWeightsAt(2));
960 this->ReadMatrixXML(parent, "ForgetStateWeights", this->GetWeightsAt(3));
961 this->ReadMatrixXML(parent, "ForgetBiases", this->GetBiasesAt(1));
962 this->ReadMatrixXML(parent, "CandidateWeights", this->GetWeightsAt(4));
963 this->ReadMatrixXML(parent, "CandidateStateWeights", this->GetWeightsAt(5));
964 this->ReadMatrixXML(parent, "CandidateBiases", this->GetBiasesAt(2));
965 this->ReadMatrixXML(parent, "OuputWeights", this->GetWeightsAt(6));
966 this->ReadMatrixXML(parent, "OutputStateWeights", this->GetWeightsAt(7));
967 this->ReadMatrixXML(parent, "OutputBiases", this->GetBiasesAt(3));
968}
969
970} // namespace RNN
971} // namespace DNN
972} // namespace TMVA
973
974#endif // TMVA_DNN_LSTM_LAYER
void InputGate(const Matrix_t &input, Matrix_t &di)
Decides the values we'll update (NN with Sigmoid)
Definition LSTMLayer.h:479
const Matrix_t & GetForgetGateTensorAt(size_t i) const
Definition LSTMLayer.h:276
Matrix_t & GetWeightsOutputGateState()
Definition LSTMLayer.h:251
const std::vector< Matrix_t > & GetOutputGateTensor() const
Definition LSTMLayer.h:282
Tensor_t fWeightsTensor
Tensor for all weights.
Definition LSTMLayer.h:132
const std::vector< Matrix_t > & GetInputGateTensor() const
Definition LSTMLayer.h:270
std::vector< Matrix_t > & GetDerivativesOutput()
Definition LSTMLayer.h:266
const Matrix_t & GetWeigthsForgetStateGradients() const
Definition LSTMLayer.h:307
typename Architecture_t::Matrix_t Matrix_t
Definition LSTMLayer.h:61
Matrix_t & GetCandidateGateTensorAt(size_t i)
Definition LSTMLayer.h:281
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initializes the hidden state and the cell state.
Definition LSTMLayer.h:901
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
Definition LSTMLayer.h:124
const Matrix_t & GetOutputGateBias() const
Definition LSTMLayer.h:297
Matrix_t & GetWeightsCandidateStateGradients()
Definition LSTMLayer.h:314
Matrix_t & GetWeightsInputGateState()
Definition LSTMLayer.h:245
const std::vector< Matrix_t > & GetCandidateGateTensor() const
Definition LSTMLayer.h:278
const Matrix_t & GetInputGateTensorAt(size_t i) const
Definition LSTMLayer.h:272
std::vector< Matrix_t > & GetForgetGateTensor()
Definition LSTMLayer.h:275
std::vector< Matrix_t > cell_value
cell value for every time step
Definition LSTMLayer.h:112
Matrix_t & fWeightsOutputGradients
Gradients w.r.t the output gate - input weights.
Definition LSTMLayer.h:127
Matrix_t & fOutputBiasGradients
Gradients w.r.t the output gate - bias weights.
Definition LSTMLayer.h:129
DNN::EActivationFunction fF1
Activation function: sigmoid.
Definition LSTMLayer.h:82
virtual void Initialize()
Initialize the weights according to the given initialization method.
Definition LSTMLayer.h:469
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
Definition LSTMLayer.h:139
Matrix_t & fWeightsOutputGate
Output Gate weights for input, fWeights[6].
Definition LSTMLayer.h:104
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
Definition LSTMLayer.h:125
void Forward(Tensor_t &input, bool isTraining=true)
Computes the next hidden state and next cell state with given input matrix.
Definition LSTMLayer.h:553
const Matrix_t & GetInputGateBias() const
Definition LSTMLayer.h:291
typename Architecture_t::Scalar_t Scalar_t
Definition LSTMLayer.h:62
size_t GetInputSize() const
Getters.
Definition LSTMLayer.h:211
Matrix_t & GetForgetGateTensorAt(size_t i)
Definition LSTMLayer.h:277
const Matrix_t & GetOutputGateTensorAt(size_t i) const
Definition LSTMLayer.h:284
const Matrix_t & GetCellTensorAt(size_t i) const
Definition LSTMLayer.h:288
Tensor_t fX
cached input tensor as T x B x I
Definition LSTMLayer.h:136
DNN::EActivationFunction GetActivationFunctionF2() const
Definition LSTMLayer.h:220
Matrix_t & GetCellTensorAt(size_t i)
Definition LSTMLayer.h:289
Matrix_t & fWeightsInputStateGradients
Gradients w.r.t the input gate - hidden state weights.
Definition LSTMLayer.h:119
void CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues, const Matrix_t &candidateValues, const Matrix_t &outputGateValues)
Forward for a single cell (time unit)
Definition LSTMLayer.h:654
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, Matrix_t &cell_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &precCellActivations, const Matrix_t &input_gate, const Matrix_t &forget_gate, const Matrix_t &candidate_gate, const Matrix_t &output_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t)
Backward pass for a single time unit, mirroring the corresponding call to Forward(...).
Definition LSTMLayer.h:862
const Matrix_t & GetWeightsInputStateGradients() const
Definition LSTMLayer.h:301
std::vector< Matrix_t > fDerivativesOutput
First derivatives of the output-gate activations.
Definition LSTMLayer.h:116
Matrix_t & fWeightsForgetGateState
Forget Gate weights for prev state, fWeights[3].
Definition LSTMLayer.h:97
Matrix_t & fOutputGateBias
Output Gate bias.
Definition LSTMLayer.h:106
std::vector< Matrix_t > fDerivativesCandidate
First derivatives of the candidate-gate activations.
Definition LSTMLayer.h:115
const Matrix_t & GetInputDerivativesAt(size_t i) const
Definition LSTMLayer.h:255
Matrix_t & fWeightsForgetGate
Forget Gate weights for input, fWeights[2].
Definition LSTMLayer.h:96
Matrix_t & fWeightsInputGradients
Gradients w.r.t the input gate - input weights.
Definition LSTMLayer.h:118
typename Architecture_t::Tensor_t Tensor_t
Definition LSTMLayer.h:63
const std::vector< Matrix_t > & GetDerivativesInput() const
Definition LSTMLayer.h:253
Matrix_t & fForgetGateBias
Forget Gate bias.
Definition LSTMLayer.h:98
Matrix_t & GetWeightsInputGradients()
Definition LSTMLayer.h:300
Matrix_t & GetCandidateBiasGradients()
Definition LSTMLayer.h:316
Matrix_t & GetWeightsOutputGradients()
Definition LSTMLayer.h:318
Matrix_t & fCandidateBias
Candidate Gate bias.
Definition LSTMLayer.h:102
Matrix_t fCandidateValue
Computed candidate values.
Definition LSTMLayer.h:86
Tensor_t & GetWeightGradientsTensor()
Definition LSTMLayer.h:326
const Matrix_t & GetWeightsOutputGradients() const
Definition LSTMLayer.h:317
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
Definition LSTMLayer.h:65
const Matrix_t & GetWeightsInputGradients() const
Definition LSTMLayer.h:299
Matrix_t & GetWeightsCandidateState()
Definition LSTMLayer.h:249
const Matrix_t & GetInputBiasGradients() const
Definition LSTMLayer.h:303
DNN::EActivationFunction fF2
Activation function: tanh.
Definition LSTMLayer.h:83
Matrix_t & fInputBiasGradients
Gradients w.r.t the input gate - bias weights.
Definition LSTMLayer.h:120
Matrix_t & GetWeightsOutputStateGradients()
Definition LSTMLayer.h:320
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[5].
Definition LSTMLayer.h:101
void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Definition LSTMLayer.h:922
std::vector< Matrix_t > fDerivativesForget
First derivatives of the forget-gate activations.
Definition LSTMLayer.h:114
const Tensor_t & GetWeightGradientsTensor() const
Definition LSTMLayer.h:327
Matrix_t & GetForgetDerivativesAt(size_t i)
Definition LSTMLayer.h:260
const Matrix_t & GetWeightsInputGateState() const
Definition LSTMLayer.h:244
Matrix_t & GetWeightsInputStateGradients()
Definition LSTMLayer.h:302
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
Definition LSTMLayer.h:68
Matrix_t & fForgetBiasGradients
Gradients w.r.t the forget gate - bias weights.
Definition LSTMLayer.h:123
const Matrix_t & GetCandidateBias() const
Definition LSTMLayer.h:295
std::vector< Matrix_t > output_gate_value
output gate value for every time step
Definition LSTMLayer.h:111
const std::vector< Matrix_t > & GetDerivativesCandidate() const
Definition LSTMLayer.h:261
size_t fStateSize
Hidden state size for LSTM.
Definition LSTMLayer.h:75
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh)
Definition LSTMLayer.h:515
std::vector< Matrix_t > fDerivativesInput
First derivatives of the input-gate activations.
Definition LSTMLayer.h:113
const Matrix_t & GetWeightsForgetGateState() const
Definition LSTMLayer.h:246
Matrix_t & GetWeightsForgetGateState()
Definition LSTMLayer.h:247
const Matrix_t & GetWeightsInputGate() const
Definition LSTMLayer.h:236
const Matrix_t & GetInputGateValue() const
Definition LSTMLayer.h:222
void Update(const Scalar_t learningRate)
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
Definition LSTMLayer.h:138
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Definition LSTMLayer.h:70
TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Definition LSTMLayer.h:341
Matrix_t & GetWeightsForgetStateGradients()
Definition LSTMLayer.h:308
const Matrix_t & GetOutputBiasGradients() const
Definition LSTMLayer.h:321
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
Definition LSTMLayer.h:67
const Matrix_t & GetWeightsOutputStateGradients() const
Definition LSTMLayer.h:319
Matrix_t & fWeightsOutputStateGradients
Gradients w.r.t the output gate - hidden state weights.
Definition LSTMLayer.h:128
bool fReturnSequence
Return in output full sequence or just last element.
Definition LSTMLayer.h:80
Matrix_t & GetWeightsForgetGradients()
Definition LSTMLayer.h:306
Matrix_t & GetWeightsCandidateGradients()
Definition LSTMLayer.h:312
const Matrix_t & GetWeightsForgetGradients() const
Definition LSTMLayer.h:305
Matrix_t fCell
Cell state of LSTM.
Definition LSTMLayer.h:90
std::vector< Matrix_t > & GetDerivativesCandidate()
Definition LSTMLayer.h:262
const Matrix_t & GetForgetBiasGradients() const
Definition LSTMLayer.h:309
std::vector< Matrix_t > & GetOutputGateTensor()
Definition LSTMLayer.h:283
const Matrix_t & GetForgetDerivativesAt(size_t i) const
Definition LSTMLayer.h:259
Matrix_t fState
Hidden state of LSTM.
Definition LSTMLayer.h:89
void OutputGate(const Matrix_t &input, Matrix_t &dout)
Computes output values (NN with Sigmoid)
Definition LSTMLayer.h:533
const Matrix_t & GetForgetGateValue() const
Definition LSTMLayer.h:226
std::vector< Matrix_t > candidate_gate_value
candidate gate value for every time step
Definition LSTMLayer.h:110
const Matrix_t & GetState() const
Definition LSTMLayer.h:231
const Matrix_t & GetWeightsCandidateState() const
Definition LSTMLayer.h:248
const std::vector< Matrix_t > & GetForgetGateTensor() const
Definition LSTMLayer.h:274
const std::vector< Matrix_t > & GetDerivativesOutput() const
Definition LSTMLayer.h:265
const std::vector< Matrix_t > & GetCellTensor() const
Definition LSTMLayer.h:286
const Tensor_t & GetWeightsTensor() const
Definition LSTMLayer.h:325
Matrix_t & fWeightsInputGate
Input Gate weights for input, fWeights[0].
Definition LSTMLayer.h:92
std::vector< Matrix_t > & GetCandidateGateTensor()
Definition LSTMLayer.h:279
const Matrix_t & GetOutputDerivativesAt(size_t i) const
Definition LSTMLayer.h:267
void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
Definition LSTMLayer.h:952
const Matrix_t & GetCell() const
Definition LSTMLayer.h:233
Matrix_t & fWeightsForgetStateGradients
Gradients w.r.t the forget gate - hidden state weights.
Definition LSTMLayer.h:122
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
Definition LSTMLayer.h:280
Matrix_t fOutputValue
Computed output gate values.
Definition LSTMLayer.h:88
size_t fCellSize
Cell state size of LSTM.
Definition LSTMLayer.h:76
Matrix_t & GetOutputDerivativesAt(size_t i)
Definition LSTMLayer.h:268
Matrix_t & GetInputGateTensorAt(size_t i)
Definition LSTMLayer.h:273
std::vector< Matrix_t > & GetDerivativesInput()
Definition LSTMLayer.h:254
Matrix_t & fWeightsOutputGateState
Output Gate weights for prev state, fWeights[7].
Definition LSTMLayer.h:105
const std::vector< Matrix_t > & GetDerivativesForget() const
Definition LSTMLayer.h:257
const Matrix_t & GetForgetGateBias() const
Definition LSTMLayer.h:293
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
Definition LSTMLayer.h:263
Matrix_t & GetOutputGateTensorAt(size_t i)
Definition LSTMLayer.h:285
size_t fTimeSteps
Timesteps for LSTM.
Definition LSTMLayer.h:77
const Matrix_t & GetCandidateBiasGradients() const
Definition LSTMLayer.h:315
const Matrix_t & GetCandidateValue() const
Definition LSTMLayer.h:224
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Definition LSTMLayer.h:66
Matrix_t & fInputGateBias
Input Gate bias.
Definition LSTMLayer.h:94
const Matrix_t & GetWeightsForgetGate() const
Definition LSTMLayer.h:240
std::vector< Matrix_t > input_gate_value
input gate value for every time step
Definition LSTMLayer.h:108
const Matrix_t & GetWeightsCandidateStateGradients() const
Definition LSTMLayer.h:313
Matrix_t & fWeightsForgetGradients
Gradients w.r.t the forget gate - input weights.
Definition LSTMLayer.h:121
std::vector< Matrix_t > & GetDerivativesForget()
Definition LSTMLayer.h:258
const Matrix_t & GetWeightsOutputGate() const
Definition LSTMLayer.h:242
void ForgetGate(const Matrix_t &input, Matrix_t &df)
Forgets the past values (NN with Sigmoid)
Definition LSTMLayer.h:497
std::vector< Matrix_t > & GetInputGateTensor()
Definition LSTMLayer.h:271
const Matrix_t & GetOutputGateValue() const
Definition LSTMLayer.h:228
const Matrix_t & GetWeightsOutputGateState() const
Definition LSTMLayer.h:250
Matrix_t & GetCandidateDerivativesAt(size_t i)
Definition LSTMLayer.h:264
Matrix_t fInputValue
Computed input gate values.
Definition LSTMLayer.h:85
const Matrix_t & GetWeightsCandidate() const
Definition LSTMLayer.h:238
void Print() const
Prints the info about the layer.
Definition LSTMLayer.h:910
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
Definition LSTMLayer.h:680
const Matrix_t & GetWeightsCandidateGradients() const
Definition LSTMLayer.h:311
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
Definition LSTMLayer.h:133
Matrix_t & GetInputDerivativesAt(size_t i)
Definition LSTMLayer.h:256
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
Definition LSTMLayer.h:71
DNN::EActivationFunction GetActivationFunctionF1() const
Definition LSTMLayer.h:219
Tensor_t fY
cached output tensor as T x B x S
Definition LSTMLayer.h:137
std::vector< Matrix_t > forget_gate_value
forget gate value for every time step
Definition LSTMLayer.h:109
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[4].
Definition LSTMLayer.h:100
bool fRememberState
Remember state in next pass.
Definition LSTMLayer.h:79
Matrix_t & fWeightsInputGateState
Input Gate weights for prev state, fWeights[1].
Definition LSTMLayer.h:93
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
Definition LSTMLayer.h:141
std::vector< Matrix_t > & GetCellTensor()
Definition LSTMLayer.h:287
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
Definition LSTMLayer.h:126
Matrix_t fForgetValue
Computed forget gate values.
Definition LSTMLayer.h:87
Generic General Layer class.
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
size_t GetBatchSize() const
Getters.
size_t GetInputWidth() const
TXMLEngine & xmlengine()
Definition Tools.h:262
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
TF1 * f1
Definition legend1.C:11
EActivationFunction
Enum that represents layer activation functions.
Definition Functions.h:32
create variable transformations
Tools & gTools()
auto * m
Definition textangle.C:8