30#ifndef TMVA_DNN_LSTM_LAYER 
   31#define TMVA_DNN_LSTM_LAYER 
   55template<
typename Architecture_t>
 
   61   using Matrix_t = 
typename Architecture_t::Matrix_t;
 
   62   using Scalar_t = 
typename Architecture_t::Scalar_t;
 
   63   using Tensor_t = 
typename Architecture_t::Tensor_t;
 
 
  340template <
typename Architecture_t>
 
  349        {
stateSize, 
stateSize, 
stateSize, 
stateSize}, {1, 1, 1, 1}, batchSize, (
returnSequence) ? 
timeSteps : 1,
 
  355     fWeightsInputGateState(
this->GetWeightsAt(4)), fInputGateBias(
this->GetBiasesAt(0)),
 
  356     fWeightsForgetGate(
this->GetWeightsAt(1)), fWeightsForgetGateState(
this->GetWeightsAt(5)),
 
  357     fForgetGateBias(
this->GetBiasesAt(1)), fWeightsCandidate(
this->GetWeightsAt(2)),
 
  358     fWeightsCandidateState(
this->GetWeightsAt(6)), fCandidateBias(
this->GetBiasesAt(2)),
 
  359     fWeightsOutputGate(
this->GetWeightsAt(3)), fWeightsOutputGateState(
this->GetWeightsAt(7)),
 
  360     fOutputGateBias(
this->GetBiasesAt(3)), fWeightsInputGradients(
this->GetWeightGradientsAt(0)),
 
  361     fWeightsInputStateGradients(
this->GetWeightGradientsAt(4)), fInputBiasGradients(
this->GetBiasGradientsAt(0)),
 
  362     fWeightsForgetGradients(
this->GetWeightGradientsAt(1)),
 
  363     fWeightsForgetStateGradients(
this->GetWeightGradientsAt(5)), fForgetBiasGradients(
this->GetBiasGradientsAt(1)),
 
  364     fWeightsCandidateGradients(
this->GetWeightGradientsAt(2)),
 
  365     fWeightsCandidateStateGradients(
this->GetWeightGradientsAt(6)),
 
  366     fCandidateBiasGradients(
this->GetBiasGradientsAt(2)), fWeightsOutputGradients(
this->GetWeightGradientsAt(3)),
 
  367     fWeightsOutputStateGradients(
this->GetWeightGradientsAt(7)), fOutputBiasGradients(
this->GetBiasGradientsAt(3))
 
  380   Architecture_t::InitializeLSTMTensors(
this);
 
 
  384template <
typename Architecture_t>
 
  387      fStateSize(
layer.fStateSize),
 
  388      fCellSize(
layer.fCellSize),
 
  389      fTimeSteps(
layer.fTimeSteps),
 
  390      fRememberState(
layer.fRememberState),
 
  391      fReturnSequence(
layer.fReturnSequence),
 
  392      fF1(
layer.GetActivationFunctionF1()),
 
  393      fF2(
layer.GetActivationFunctionF2()),
 
  394      fInputValue(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  395      fCandidateValue(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  396      fForgetValue(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  397      fOutputValue(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  398      fState(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  399      fCell(
layer.GetBatchSize(), 
layer.GetCellSize()),
 
  400      fWeightsInputGate(
this->GetWeightsAt(0)),
 
  401      fWeightsInputGateState(
this->GetWeightsAt(4)),
 
  402      fInputGateBias(
this->GetBiasesAt(0)),
 
  403      fWeightsForgetGate(
this->GetWeightsAt(1)),
 
  404      fWeightsForgetGateState(
this->GetWeightsAt(5)),
 
  405      fForgetGateBias(
this->GetBiasesAt(1)),
 
  406      fWeightsCandidate(
this->GetWeightsAt(2)),
 
  407      fWeightsCandidateState(
this->GetWeightsAt(6)),
 
  408      fCandidateBias(
this->GetBiasesAt(2)),
 
  409      fWeightsOutputGate(
this->GetWeightsAt(3)),
 
  410      fWeightsOutputGateState(
this->GetWeightsAt(7)),
 
  411      fOutputGateBias(
this->GetBiasesAt(3)),
 
  412      fWeightsInputGradients(
this->GetWeightGradientsAt(0)),
 
  413      fWeightsInputStateGradients(
this->GetWeightGradientsAt(4)),
 
  414      fInputBiasGradients(
this->GetBiasGradientsAt(0)),
 
  415      fWeightsForgetGradients(
this->GetWeightGradientsAt(1)),
 
  416      fWeightsForgetStateGradients(
this->GetWeightGradientsAt(5)),
 
  417      fForgetBiasGradients(
this->GetBiasGradientsAt(1)),
 
  418      fWeightsCandidateGradients(
this->GetWeightGradientsAt(2)),
 
  419      fWeightsCandidateStateGradients(
this->GetWeightGradientsAt(6)),
 
  420      fCandidateBiasGradients(
this->GetBiasGradientsAt(2)),
 
  421      fWeightsOutputGradients(
this->GetWeightGradientsAt(3)),
 
  422      fWeightsOutputStateGradients(
this->GetWeightGradientsAt(7)),
 
  423      fOutputBiasGradients(
this->GetBiasGradientsAt(3))
 
  464   Architecture_t::InitializeLSTMTensors(
this);
 
 
  468template <
typename Architecture_t>
 
  473   Architecture_t::InitializeLSTMDescriptors(fDescriptors, 
this);
 
  474   Architecture_t::InitializeLSTMWorkspace(fWorkspace, fDescriptors, 
this);
 
 
  478template <
typename Architecture_t>
 
  487   Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsInputGateState);
 
  488   Architecture_t::MultiplyTranspose(fInputValue, 
input, fWeightsInputGate);
 
  489   Architecture_t::ScaleAdd(fInputValue, 
tmpState);
 
  490   Architecture_t::AddRowWise(fInputValue, fInputGateBias);
 
  491   DNN::evaluateDerivativeMatrix<Architecture_t>(
di, fInp, fInputValue);
 
  492   DNN::evaluateMatrix<Architecture_t>(fInputValue, fInp);
 
 
  496template <
typename Architecture_t>
 
  505   Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsForgetGateState);
 
  506   Architecture_t::MultiplyTranspose(fForgetValue, 
input, fWeightsForgetGate);
 
  507   Architecture_t::ScaleAdd(fForgetValue, 
tmpState);
 
  508   Architecture_t::AddRowWise(fForgetValue, fForgetGateBias);
 
  509   DNN::evaluateDerivativeMatrix<Architecture_t>(df, 
fFor, fForgetValue);
 
  510   DNN::evaluateMatrix<Architecture_t>(fForgetValue, 
fFor);
 
 
  514template <
typename Architecture_t>
 
  523   Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsCandidateState);
 
  524   Architecture_t::MultiplyTranspose(fCandidateValue, 
input, fWeightsCandidate);
 
  525   Architecture_t::ScaleAdd(fCandidateValue, 
tmpState);
 
  526   Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
 
  527   DNN::evaluateDerivativeMatrix<Architecture_t>(
dc, fCan, fCandidateValue);
 
  528   DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
 
 
  532template <
typename Architecture_t>
 
  541   Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsOutputGateState);
 
  542   Architecture_t::MultiplyTranspose(fOutputValue, 
input, fWeightsOutputGate);
 
  543   Architecture_t::ScaleAdd(fOutputValue, 
tmpState);
 
  544   Architecture_t::AddRowWise(fOutputValue, fOutputGateBias);
 
  545   DNN::evaluateDerivativeMatrix<Architecture_t>(
dout, fOut, fOutputValue);
 
  546   DNN::evaluateMatrix<Architecture_t>(fOutputValue, fOut);
 
 
  552template <
typename Architecture_t>
 
  558   if (Architecture_t::IsCudnn()) {
 
  565      Architecture_t::Rearrange(
x, 
input);
 
  568      const auto &weights = this->GetWeightsTensor();
 
  573      auto &
hx = this->fState;
 
  575      auto &
cx = this->fCell; 
 
  577      auto &
hy = this->fState;
 
  578      auto &
cy = this->fCell;
 
  585      if (fReturnSequence) {
 
  586         Architecture_t::Rearrange(this->GetOutput(), 
y); 
 
  589         Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1, 
y.GetShape()[2]});
 
  590         Architecture_t::Copy(this->GetOutput(), tmp);
 
  603   Tensor_t arrInput( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
 
  611   if (!this->fRememberState) {
 
  617   for (
size_t t = 0; t < fTimeSteps; ++t) {
 
  621      ForgetGate(
arrInputMt, fDerivativesForget[t]);
 
  622      CandidateValue(
arrInputMt, fDerivativesCandidate[t]);
 
  623      OutputGate(
arrInputMt, fDerivativesOutput[t]);
 
  625      Architecture_t::Copy(this->GetInputGateTensorAt(t), fInputValue);
 
  626      Architecture_t::Copy(this->GetForgetGateTensorAt(t), fForgetValue);
 
  627      Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
 
  628      Architecture_t::Copy(this->GetOutputGateTensorAt(t), fOutputValue);
 
  630      CellForward(fInputValue, fForgetValue, fCandidateValue, fOutputValue);
 
  633      Architecture_t::Copy(this->GetCellTensorAt(t), fCell);
 
  638      Architecture_t::Rearrange(this->GetOutput(), 
arrOutput); 
 
  644      tmp = tmp.Reshape( {tmp.GetShape()[0], tmp.GetShape()[1], 1});
 
  645      assert(tmp.GetSize() == 
this->GetOutput().GetSize());
 
  646      assert( tmp.GetShape()[0] == 
this->GetOutput().GetShape()[2]);  
 
  647      Architecture_t::Rearrange(this->GetOutput(), tmp);
 
 
  654template <
typename Architecture_t>
 
  665   Matrix_t cache(fCell.GetNrows(), fCell.GetNcols());
 
  666   Architecture_t::Copy(cache, fCell);
 
  670   DNN::evaluateMatrix<Architecture_t>(cache, 
fAT);
 
  675   Architecture_t::Copy(fState, cache);
 
 
  680template <
typename Architecture_t>
 
  687   if (Architecture_t::IsCudnn()) {
 
  699      if (!fReturnSequence) {
 
  702         Architecture_t::InitializeZero(
dy);
 
  710         Architecture_t::Copy(
tmp2, this->GetActivationGradients());
 
  712         Architecture_t::Rearrange(
y, this->GetOutput());
 
  713         Architecture_t::Rearrange(
dy, this->GetActivationGradients());
 
  719      const auto &weights = this->GetWeightsTensor();
 
  726      auto &
hx = this->GetState();
 
  727      auto &
cx = this->GetCell();
 
  738      Architecture_t::RNNBackward(
x, 
hx, 
cx, 
y, 
dy, 
dhy, 
dcy, weights, 
dx, 
dhx, 
dcx, 
weightGradients, 
rnnDesc, 
rnnWork);
 
  786   if (fReturnSequence) {
 
  787      Architecture_t::Rearrange(
arr_output, this->GetOutput());
 
  788      Architecture_t::Rearrange(
arr_actgradients, this->GetActivationGradients());
 
  798      Architecture_t::Rearrange(
tmp_grad, this->GetActivationGradients());
 
  805   fWeightsInputGradients.Zero();
 
  806   fWeightsInputStateGradients.Zero();
 
  807   fInputBiasGradients.Zero();
 
  810   fWeightsForgetGradients.Zero();
 
  811   fWeightsForgetStateGradients.Zero();
 
  812   fForgetBiasGradients.Zero();
 
  815   fWeightsCandidateGradients.Zero();
 
  816   fWeightsCandidateStateGradients.Zero();
 
  817   fCandidateBiasGradients.Zero();
 
  820   fWeightsOutputGradients.Zero();
 
  821   fWeightsOutputStateGradients.Zero();
 
  822   fOutputBiasGradients.Zero();
 
  825   for (
size_t t = fTimeSteps; t > 0; t--) {
 
  835                      this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
 
  836                      this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
 
  838                      fDerivativesInput[t-1], fDerivativesForget[t-1],
 
  839                      fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
 
  846                      this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
 
  847                      this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
 
  849                      fDerivativesInput[t-1], fDerivativesForget[t-1],
 
  850                      fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
 
 
  862template <
typename Architecture_t>
 
  880   DNN::evaluateDerivativeMatrix<Architecture_t>(
cell_gradient, 
fAT, this->GetCellTensorAt(t));
 
  883   Matrix_t cell_tanh(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
 
  884   Architecture_t::Copy(
cell_tanh, this->GetCellTensorAt(t));
 
  888                                            fWeightsInputGradients, fWeightsForgetGradients, fWeightsCandidateGradients,
 
  889                                            fWeightsOutputGradients, fWeightsInputStateGradients, fWeightsForgetStateGradients,
 
  890                                            fWeightsCandidateStateGradients, fWeightsOutputStateGradients, fInputBiasGradients, fForgetBiasGradients,
 
  891                                            fCandidateBiasGradients, fOutputBiasGradients, 
di, df, 
dc, 
dout,
 
  894                                            fWeightsInputGate, fWeightsForgetGate, fWeightsCandidate, fWeightsOutputGate,
 
  895                                            fWeightsInputGateState, fWeightsForgetGateState, fWeightsCandidateState,
 
 
  901template <
typename Architecture_t>
 
  910template<
typename Architecture_t>
 
  914   std::cout << 
" LSTM Layer: \t ";
 
  915   std::cout << 
" (NInput = " << this->GetInputSize();  
 
  916   std::cout << 
", NState = " << this->GetStateSize();  
 
  917   std::cout << 
", NTime  = " << this->GetTimeSteps() << 
" )";  
 
  918   std::cout << 
"\tOutput = ( " << this->GetOutput().GetFirstSize() << 
" , " << this->GetOutput()[0].GetNrows() << 
" , " << this->GetOutput()[0].GetNcols() << 
" )\n";
 
 
  922template <
typename Architecture_t>
 
  937   this->WriteMatrixToXML(
layerxml, 
"InputWeights", this->GetWeightsAt(0));
 
  938   this->WriteMatrixToXML(
layerxml, 
"InputStateWeights", this->GetWeightsAt(1));
 
  939   this->WriteMatrixToXML(
layerxml, 
"InputBiases", this->GetBiasesAt(0));
 
  940   this->WriteMatrixToXML(
layerxml, 
"ForgetWeights", this->GetWeightsAt(2));
 
  941   this->WriteMatrixToXML(
layerxml, 
"ForgetStateWeights", this->GetWeightsAt(3));
 
  942   this->WriteMatrixToXML(
layerxml, 
"ForgetBiases", this->GetBiasesAt(1));
 
  943   this->WriteMatrixToXML(
layerxml, 
"CandidateWeights", this->GetWeightsAt(4));
 
  944   this->WriteMatrixToXML(
layerxml, 
"CandidateStateWeights", this->GetWeightsAt(5));
 
  945   this->WriteMatrixToXML(
layerxml, 
"CandidateBiases", this->GetBiasesAt(2));
 
  946   this->WriteMatrixToXML(
layerxml, 
"OuputWeights", this->GetWeightsAt(6));
 
  947   this->WriteMatrixToXML(
layerxml, 
"OutputStateWeights", this->GetWeightsAt(7));
 
  948   this->WriteMatrixToXML(
layerxml, 
"OutputBiases", this->GetBiasesAt(3));
 
 
  952template <
typename Architecture_t>
 
  957   this->ReadMatrixXML(parent, 
"InputWeights", this->GetWeightsAt(0));
 
  958   this->ReadMatrixXML(parent, 
"InputStateWeights", this->GetWeightsAt(1));
 
  959   this->ReadMatrixXML(parent, 
"InputBiases", this->GetBiasesAt(0));
 
  960   this->ReadMatrixXML(parent, 
"ForgetWeights", this->GetWeightsAt(2));
 
  961   this->ReadMatrixXML(parent, 
"ForgetStateWeights", this->GetWeightsAt(3));
 
  962   this->ReadMatrixXML(parent, 
"ForgetBiases", this->GetBiasesAt(1));
 
  963   this->ReadMatrixXML(parent, 
"CandidateWeights", this->GetWeightsAt(4));
 
  964   this->ReadMatrixXML(parent, 
"CandidateStateWeights", this->GetWeightsAt(5));
 
  965   this->ReadMatrixXML(parent, 
"CandidateBiases", this->GetBiasesAt(2));
 
  966   this->ReadMatrixXML(parent, 
"OuputWeights", this->GetWeightsAt(6));
 
  967   this->ReadMatrixXML(parent, 
"OutputStateWeights", this->GetWeightsAt(7));
 
  968   this->ReadMatrixXML(parent, 
"OutputBiases", this->GetBiasesAt(3));
 
 
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
void InputGate(const Matrix_t &input, Matrix_t &di)
Decides the values we'll update (NN with Sigmoid)
const Matrix_t & GetForgetGateTensorAt(size_t i) const
Matrix_t & GetWeightsOutputGateState()
const std::vector< Matrix_t > & GetOutputGateTensor() const
Tensor_t fWeightsTensor
Tensor for all weights.
const std::vector< Matrix_t > & GetInputGateTensor() const
std::vector< Matrix_t > & GetDerivativesOutput()
const Matrix_t & GetWeigthsForgetStateGradients() const
Matrix_t & GetWeightsForgetGate()
typename Architecture_t::Matrix_t Matrix_t
Matrix_t & GetCandidateGateTensorAt(size_t i)
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initialize the hidden state and cell state method.
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
const Matrix_t & GetOutputGateBias() const
Matrix_t & GetWeightsCandidateStateGradients()
Matrix_t & GetWeightsInputGate()
Matrix_t & GetWeightsInputGateState()
const std::vector< Matrix_t > & GetCandidateGateTensor() const
const Matrix_t & GetInputGateTensorAt(size_t i) const
std::vector< Matrix_t > & GetForgetGateTensor()
std::vector< Matrix_t > cell_value
cell value for every time step
Matrix_t & fWeightsOutputGradients
Gradients w.r.t the output gate - input weights.
Matrix_t & GetOutputGateBias()
Matrix_t & fOutputBiasGradients
Gradients w.r.t the output gate - bias weights.
DNN::EActivationFunction fF1
Activation function: sigmoid.
virtual void Initialize()
Initialize the weights according to the given initialization method.
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
Matrix_t & fWeightsOutputGate
Output Gate weights for input, fWeights[3].
Matrix_t & GetForgetGateBias()
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
void Forward(Tensor_t &input, bool isTraining=true)
Computes the next hidden state and next cell state with given input matrix.
const Matrix_t & GetInputGateBias() const
typename Architecture_t::Scalar_t Scalar_t
size_t GetInputSize() const
Getters.
Matrix_t & GetForgetGateTensorAt(size_t i)
const Matrix_t & GetOutputGateTensorAt(size_t i) const
const Matrix_t & GetCellTensorAt(size_t i) const
Tensor_t fX
cached input tensor as T x B x I
DNN::EActivationFunction GetActivationFunctionF2() const
Matrix_t & GetCellTensorAt(size_t i)
Matrix_t & fWeightsInputStateGradients
Gradients w.r.t the input gate - hidden state weights.
void CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues, const Matrix_t &candidateValues, const Matrix_t &outputGateValues)
Forward for a single cell (time unit)
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, Matrix_t &cell_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &precCellActivations, const Matrix_t &input_gate, const Matrix_t &forget_gate, const Matrix_t &candidate_gate, const Matrix_t &output_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t)
Backward for a single time unit at the corresponding call to Forward(...).
const Matrix_t & GetWeightsInputStateGradients() const
std::vector< Matrix_t > fDerivativesOutput
First fDerivatives of the activations output gate.
size_t GetStateSize() const
Matrix_t & fWeightsForgetGateState
Forget Gate weights for prev state, fWeights[5].
Matrix_t & fOutputGateBias
Output Gate bias.
std::vector< Matrix_t > fDerivativesCandidate
First fDerivatives of the activations candidate gate.
const Matrix_t & GetInputDerivativesAt(size_t i) const
Matrix_t & fWeightsForgetGate
Forget Gate weights for input, fWeights[1].
Matrix_t & fWeightsInputGradients
Gradients w.r.t the input gate - input weights.
typename Architecture_t::Tensor_t Tensor_t
const std::vector< Matrix_t > & GetDerivativesInput() const
Matrix_t & GetWeightsCandidate()
Matrix_t & fForgetGateBias
Forget Gate bias.
Matrix_t & GetWeightsInputGradients()
Matrix_t & GetCandidateBiasGradients()
Matrix_t & GetWeightsOutputGradients()
Matrix_t & fCandidateBias
Candidate Gate bias.
Matrix_t fCandidateValue
Computed candidate values.
Tensor_t & GetWeightGradientsTensor()
bool DoesRememberState() const
const Matrix_t & GetWeightsOutputGradients() const
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
const Matrix_t & GetWeightsInputGradients() const
Matrix_t & GetWeightsCandidateState()
Matrix_t & GetInputBiasGradients()
const Matrix_t & GetInputBiasGradients() const
size_t GetTimeSteps() const
DNN::EActivationFunction fF2
Activation function: tanh.
Matrix_t & fInputBiasGradients
Gradients w.r.t the input gate - bias weights.
Matrix_t & GetWeightsOutputStateGradients()
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[6].
void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Matrix_t & GetForgetGateValue()
std::vector< Matrix_t > fDerivativesForget
First fDerivatives of the activations forget gate.
const Tensor_t & GetWeightGradientsTensor() const
Matrix_t & GetForgetDerivativesAt(size_t i)
const Matrix_t & GetWeightsInputGateState() const
Matrix_t & GetWeightsInputStateGradients()
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
Matrix_t & fForgetBiasGradients
Gradients w.r.t the forget gate - bias weights.
const Matrix_t & GetCandidateBias() const
std::vector< Matrix_t > output_gate_value
output gate value for every time step
const std::vector< Matrix_t > & GetDerivativesCandidate() const
size_t fStateSize
Hidden state size for LSTM.
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh)
std::vector< Matrix_t > fDerivativesInput
First fDerivatives of the activations input gate.
const Matrix_t & GetWeightsForgetGateState() const
Matrix_t & GetWeightsForgetGateState()
const Matrix_t & GetWeightsInputGate() const
const Matrix_t & GetInputGateValue() const
void Update(const Scalar_t learningRate)
bool DoesReturnSequence() const
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t & GetOutputGateValue()
TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Matrix_t & GetWeightsForgetStateGradients()
const Matrix_t & GetOutputBiasGradients() const
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
const Matrix_t & GetWeightsOutputStateGradients() const
Matrix_t & fWeightsOutputStateGradients
Gradients w.r.t the output gate - hidden state weights.
bool fReturnSequence
Return in output full sequence or just last element.
Matrix_t & GetWeightsForgetGradients()
Matrix_t & GetWeightsCandidateGradients()
const Matrix_t & GetWeightsForgetGradients() const
Matrix_t fCell
Cell state of LSTM.
std::vector< Matrix_t > & GetDerivativesCandidate()
const Matrix_t & GetForgetBiasGradients() const
std::vector< Matrix_t > & GetOutputGateTensor()
Matrix_t & GetCandidateValue()
const Matrix_t & GetForgetDerivativesAt(size_t i) const
Matrix_t fState
Hidden state of LSTM.
void OutputGate(const Matrix_t &input, Matrix_t &dout)
Computes output values (NN with Sigmoid)
const Matrix_t & GetForgetGateValue() const
std::vector< Matrix_t > candidate_gate_value
candidate gate value for every time step
Matrix_t & GetInputGateValue()
const Matrix_t & GetState() const
const Matrix_t & GetWeightsCandidateState() const
Matrix_t & GetCandidateBias()
const std::vector< Matrix_t > & GetForgetGateTensor() const
const std::vector< Matrix_t > & GetDerivativesOutput() const
const std::vector< Matrix_t > & GetCellTensor() const
const Tensor_t & GetWeightsTensor() const
Matrix_t & fWeightsInputGate
Input Gate weights for input, fWeights[0].
std::vector< Matrix_t > & GetCandidateGateTensor()
const Matrix_t & GetOutputDerivativesAt(size_t i) const
void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
const Matrix_t & GetCell() const
Matrix_t & fWeightsForgetStateGradients
Gradients w.r.t the forget gate - hidden state weights.
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
Matrix_t fOutputValue
Computed output gate values.
size_t fCellSize
Cell state size of LSTM.
Matrix_t & GetOutputDerivativesAt(size_t i)
Matrix_t & GetInputGateTensorAt(size_t i)
std::vector< Matrix_t > & GetDerivativesInput()
Matrix_t & fWeightsOutputGateState
Output Gate weights for prev state, fWeights[7].
const std::vector< Matrix_t > & GetDerivativesForget() const
Matrix_t & GetForgetBiasGradients()
const Matrix_t & GetForgetGateBias() const
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
Matrix_t & GetInputGateBias()
Matrix_t & GetOutputGateTensorAt(size_t i)
size_t fTimeSteps
Timesteps for LSTM.
const Matrix_t & GetCandidateBiasGradients() const
const Matrix_t & GetCandidateValue() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Matrix_t & fInputGateBias
Input Gate bias.
const Matrix_t & GetWeightsForgetGate() const
std::vector< Matrix_t > input_gate_value
input gate value for every time step
const Matrix_t & GetWeightsCandidateStateGradients() const
Tensor_t & GetWeightsTensor()
Matrix_t & fWeightsForgetGradients
Gradients w.r.t the forget gate - input weights.
std::vector< Matrix_t > & GetDerivativesForget()
const Matrix_t & GetWeightsOutputGate() const
void ForgetGate(const Matrix_t &input, Matrix_t &df)
Forgets the past values (NN with Sigmoid)
std::vector< Matrix_t > & GetInputGateTensor()
Matrix_t & GetOutputBiasGradients()
const Matrix_t & GetOutputGateValue() const
const Matrix_t & GetWeightsOutputGateState() const
Matrix_t & GetCandidateDerivativesAt(size_t i)
Matrix_t fInputValue
Computed input gate values.
Matrix_t & GetWeightsOutputGate()
const Matrix_t & GetWeightsCandidate() const
void Print() const
Prints the info about the layer.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
const Matrix_t & GetWeightsCandidateGradients() const
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
Matrix_t & GetInputDerivativesAt(size_t i)
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
DNN::EActivationFunction GetActivationFunctionF1() const
Tensor_t fY
cached output tensor as T x B x S
std::vector< Matrix_t > forget_gate_value
forget gate value for every time step
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[2].
bool fRememberState
Remember state in next pass.
Matrix_t & fWeightsInputGateState
Input Gate weights for prev state, fWeights[4].
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
std::vector< Matrix_t > & GetCellTensor()
size_t GetCellSize() const
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
Matrix_t fForgetValue
Computed forget gate values.
Generic General Layer class.
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
EActivationFunction
Enum that represents layer activation functions.
create variable transformations