30#ifndef TMVA_DNN_GRU_LAYER 
   31#define TMVA_DNN_GRU_LAYER 
   55template<
typename Architecture_t>
 
   61   using Matrix_t = 
typename Architecture_t::Matrix_t;
 
   62   using Scalar_t = 
typename Architecture_t::Scalar_t;
 
   63   using Tensor_t = 
typename Architecture_t::Tensor_t;
 
  191   void Print() 
const override;
 
 
  307template <
typename Architecture_t>
 
  319     fCandidateValue(batchSize, 
stateSize), fState(batchSize, 
stateSize), fWeightsResetGate(
this->GetWeightsAt(0)),
 
  320     fWeightsResetGateState(
this->GetWeightsAt(3)), fResetGateBias(
this->GetBiasesAt(0)),
 
  321     fWeightsUpdateGate(
this->GetWeightsAt(1)), fWeightsUpdateGateState(
this->GetWeightsAt(4)),
 
  322     fUpdateGateBias(
this->GetBiasesAt(1)), fWeightsCandidate(
this->GetWeightsAt(2)),
 
  323     fWeightsCandidateState(
this->GetWeightsAt(5)), fCandidateBias(
this->GetBiasesAt(2)),
 
  324     fWeightsResetGradients(
this->GetWeightGradientsAt(0)), fWeightsResetStateGradients(
this->GetWeightGradientsAt(3)),
 
  325     fResetBiasGradients(
this->GetBiasGradientsAt(0)), fWeightsUpdateGradients(
this->GetWeightGradientsAt(1)),
 
  326     fWeightsUpdateStateGradients(
this->GetWeightGradientsAt(4)), fUpdateBiasGradients(
this->GetBiasGradientsAt(1)),
 
  327     fWeightsCandidateGradients(
this->GetWeightGradientsAt(2)),
 
  328     fWeightsCandidateStateGradients(
this->GetWeightGradientsAt(5)),
 
  329     fCandidateBiasGradients(
this->GetBiasGradientsAt(2))
 
  339   Architecture_t::InitializeGRUTensors(
this);
 
 
  343template <
typename Architecture_t>
 
  346      fStateSize(
layer.fStateSize),
 
  347      fTimeSteps(
layer.fTimeSteps),
 
  348      fRememberState(
layer.fRememberState),
 
  349      fReturnSequence(
layer.fReturnSequence),
 
  350      fResetGateAfter(
layer.fResetGateAfter),
 
  351      fF1(
layer.GetActivationFunctionF1()),
 
  352      fF2(
layer.GetActivationFunctionF2()),
 
  353      fResetValue(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  354      fUpdateValue(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  355      fCandidateValue(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  356      fState(
layer.GetBatchSize(), 
layer.GetStateSize()),
 
  357      fWeightsResetGate(
this->GetWeightsAt(0)),
 
  358      fWeightsResetGateState(
this->GetWeightsAt(3)),
 
  359      fResetGateBias(
this->GetBiasesAt(0)),
 
  360      fWeightsUpdateGate(
this->GetWeightsAt(1)),
 
  361      fWeightsUpdateGateState(
this->GetWeightsAt(4)),
 
  362      fUpdateGateBias(
this->GetBiasesAt(1)),
 
  363      fWeightsCandidate(
this->GetWeightsAt(2)),
 
  364      fWeightsCandidateState(
this->GetWeightsAt(5)),
 
  365      fCandidateBias(
this->GetBiasesAt(2)),
 
  366      fWeightsResetGradients(
this->GetWeightGradientsAt(0)),
 
  367      fWeightsResetStateGradients(
this->GetWeightGradientsAt(3)),
 
  368      fResetBiasGradients(
this->GetBiasGradientsAt(0)),
 
  369      fWeightsUpdateGradients(
this->GetWeightGradientsAt(1)),
 
  370      fWeightsUpdateStateGradients(
this->GetWeightGradientsAt(4)),
 
  371      fUpdateBiasGradients(
this->GetBiasGradientsAt(1)),
 
  372      fWeightsCandidateGradients(
this->GetWeightGradientsAt(2)),
 
  373      fWeightsCandidateStateGradients(
this->GetWeightGradientsAt(5)),
 
  374      fCandidateBiasGradients(
this->GetBiasGradientsAt(2))
 
  404   Architecture_t::InitializeGRUTensors(
this);
 
 
  408template <
typename Architecture_t>
 
  413   Architecture_t::InitializeGRUDescriptors(fDescriptors, 
this);
 
  414   Architecture_t::InitializeGRUWorkspace(fWorkspace, fDescriptors, 
this);
 
  417   if (Architecture_t::IsCudnn())
 
  418      fResetGateAfter = 
true;
 
 
  422template <
typename Architecture_t>
 
  431   Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsResetGateState);
 
  432   Architecture_t::MultiplyTranspose(fResetValue, 
input, fWeightsResetGate);
 
  433   Architecture_t::ScaleAdd(fResetValue, 
tmpState);
 
  434   Architecture_t::AddRowWise(fResetValue, fResetGateBias);
 
  435   DNN::evaluateDerivativeMatrix<Architecture_t>(
dr, 
fRst, fResetValue);
 
  436   DNN::evaluateMatrix<Architecture_t>(fResetValue, 
fRst);
 
 
  440template <
typename Architecture_t>
 
  449   Architecture_t::MultiplyTranspose(
tmpState, fState, fWeightsUpdateGateState);
 
  450   Architecture_t::MultiplyTranspose(fUpdateValue, 
input, fWeightsUpdateGate);
 
  451   Architecture_t::ScaleAdd(fUpdateValue, 
tmpState);
 
  452   Architecture_t::AddRowWise(fUpdateValue, fUpdateGateBias);
 
  453   DNN::evaluateDerivativeMatrix<Architecture_t>(
du, 
fUpd, fUpdateValue);
 
  454   DNN::evaluateMatrix<Architecture_t>(fUpdateValue, 
fUpd);
 
 
  458template <
typename Architecture_t>
 
  475   Matrix_t tmp(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
 
  476   if (!fResetGateAfter) {
 
  478      Architecture_t::Hadamard(
tmpState, fState);
 
  479      Architecture_t::MultiplyTranspose(tmp, 
tmpState, fWeightsCandidateState);
 
  482      Architecture_t::MultiplyTranspose(tmp, fState, fWeightsCandidateState);
 
  483      Architecture_t::Hadamard(tmp, fResetValue);
 
  485   Architecture_t::MultiplyTranspose(fCandidateValue, 
input, fWeightsCandidate);
 
  486   Architecture_t::ScaleAdd(fCandidateValue, tmp);
 
  487   Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
 
  488   DNN::evaluateDerivativeMatrix<Architecture_t>(
dc, fCan, fCandidateValue);
 
  489   DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
 
 
  493template <
typename Architecture_t>
 
  498   if (Architecture_t::IsCudnn()) {
 
  505      Architecture_t::Rearrange(
x, 
input);
 
  508      const auto &weights = this->GetWeightsTensor();
 
  510      auto &
hx = this->fState;
 
  511      auto &
cx = this->fCell;
 
  513      auto &
hy = this->fState;
 
  514      auto &
cy = this->fCell;
 
  521      if (fReturnSequence) {
 
  522         Architecture_t::Rearrange(this->GetOutput(), 
y); 
 
  525         Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1, 
y.GetShape()[2]});
 
  526         Architecture_t::Copy(this->GetOutput(), tmp);
 
  537   Tensor_t arrInput ( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
 
  548   if (!this->fRememberState) {
 
  554   for (
size_t t = 0; t < fTimeSteps; ++t) {
 
  556      ResetGate(
arrInput[t], fDerivativesReset[t]);
 
  557      Architecture_t::Copy(this->GetResetGateTensorAt(t), fResetValue);
 
  558      UpdateGate(
arrInput[t], fDerivativesUpdate[t]);
 
  559      Architecture_t::Copy(this->GetUpdateGateTensorAt(t), fUpdateValue);
 
  561      CandidateValue(
arrInput[t], fDerivativesCandidate[t]);
 
  562      Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
 
  565      CellForward(fUpdateValue, fCandidateValue);
 
  574      Architecture_t::Rearrange(this->GetOutput(), 
arrOutput); 
 
  580      tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
 
  581      assert(tmp.GetSize() == 
this->GetOutput().GetSize());
 
  582      assert(tmp.GetShape()[0] == 
this->GetOutput().GetShape()[2]); 
 
  583      Architecture_t::Rearrange(this->GetOutput(), tmp);
 
 
  590template <
typename Architecture_t>
 
  598   for (
size_t j = 0; 
j < (size_t) tmp.GetNcols(); 
j++) {
 
  599      for (
size_t i = 0; i < (size_t) tmp.GetNrows(); i++) {
 
  600         tmp(i,
j) = 1 - tmp(i,
j);
 
 
  610template <
typename Architecture_t>
 
  616   if (Architecture_t::IsCudnn()) {
 
  629      if (!fReturnSequence) {
 
  632         Architecture_t::InitializeZero(
dy);
 
  638         Architecture_t::Copy(
tmp2, this->GetActivationGradients());
 
  640         Architecture_t::Rearrange(
y, this->GetOutput());
 
  641         Architecture_t::Rearrange(
dy, this->GetActivationGradients());
 
  647      const auto &weights = this->GetWeightsTensor();
 
  655      auto &
hx = this->GetState();
 
  656      auto &
cx = this->GetCell();
 
  666      Architecture_t::RNNBackward(
x, 
hx, 
cx, 
y, 
dy, 
dhy, 
dcy, weights, 
dx, 
dhx, 
dcx, 
weightGradients, 
rnnDesc, 
rnnWork);
 
  707   if (fReturnSequence) {
 
  708      Architecture_t::Rearrange(
arr_output, this->GetOutput());
 
  709      Architecture_t::Rearrange(
arr_actgradients, this->GetActivationGradients());
 
  718             this->GetActivationGradients().GetShape()[2]); 
 
  720      Architecture_t::Rearrange(
tmp_grad, this->GetActivationGradients());
 
  727   fWeightsResetGradients.Zero();
 
  728   fWeightsResetStateGradients.Zero();
 
  729   fResetBiasGradients.Zero();
 
  732   fWeightsUpdateGradients.Zero();
 
  733   fWeightsUpdateStateGradients.Zero();
 
  734   fUpdateBiasGradients.Zero();
 
  737   fWeightsCandidateGradients.Zero();
 
  738   fWeightsCandidateStateGradients.Zero();
 
  739   fCandidateBiasGradients.Zero();
 
  742   for (
size_t t = fTimeSteps; t > 0; t--) {
 
  750                      this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
 
  751                      this->GetCandidateGateTensorAt(t-1),
 
  753                      fDerivativesReset[t-1], fDerivativesUpdate[t-1],
 
  754                      fDerivativesCandidate[t-1]);
 
  759                      this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
 
  760                      this->GetCandidateGateTensorAt(t-1),
 
  762                      fDerivativesReset[t-1], fDerivativesUpdate[t-1],
 
  763                      fDerivativesCandidate[t-1]);
 
 
  775template <
typename Architecture_t>
 
  787                                           fWeightsResetGradients, fWeightsUpdateGradients, fWeightsCandidateGradients,
 
  788                                           fWeightsResetStateGradients, fWeightsUpdateStateGradients,
 
  789                                           fWeightsCandidateStateGradients, fResetBiasGradients, fUpdateBiasGradients,
 
  790                                           fCandidateBiasGradients, 
dr, 
du, 
dc,
 
  793                                           fWeightsResetGate, fWeightsUpdateGate, fWeightsCandidate,
 
  794                                           fWeightsResetGateState, fWeightsUpdateGateState, fWeightsCandidateState,
 
 
  800template <
typename Architecture_t>
 
  808template<
typename Architecture_t>
 
  812   std::cout << 
" GRU Layer: \t ";
 
  813   std::cout << 
" (NInput = " << this->GetInputSize();  
 
  814   std::cout << 
", NState = " << this->GetStateSize();  
 
  815   std::cout << 
", NTime  = " << this->GetTimeSteps() << 
" )";  
 
  816   std::cout << 
"\tOutput = ( " << this->GetOutput().GetFirstSize() << 
" , " << this->GetOutput()[0].GetNrows() << 
" , " << this->GetOutput()[0].GetNcols() << 
" )\n";
 
 
  820template <
typename Architecture_t>
 
  835   this->WriteMatrixToXML(
layerxml, 
"ResetWeights", this->GetWeightsAt(0));
 
  836   this->WriteMatrixToXML(
layerxml, 
"ResetStateWeights", this->GetWeightsAt(1));
 
  837   this->WriteMatrixToXML(
layerxml, 
"ResetBiases", this->GetBiasesAt(0));
 
  838   this->WriteMatrixToXML(
layerxml, 
"UpdateWeights", this->GetWeightsAt(2));
 
  839   this->WriteMatrixToXML(
layerxml, 
"UpdateStateWeights", this->GetWeightsAt(3));
 
  840   this->WriteMatrixToXML(
layerxml, 
"UpdateBiases", this->GetBiasesAt(1));
 
  841   this->WriteMatrixToXML(
layerxml, 
"CandidateWeights", this->GetWeightsAt(4));
 
  842   this->WriteMatrixToXML(
layerxml, 
"CandidateStateWeights", this->GetWeightsAt(5));
 
  843   this->WriteMatrixToXML(
layerxml, 
"CandidateBiases", this->GetBiasesAt(2));
 
 
  847template <
typename Architecture_t>
 
  852   this->ReadMatrixXML(parent, 
"ResetWeights", this->GetWeightsAt(0));
 
  853   this->ReadMatrixXML(parent, 
"ResetStateWeights", this->GetWeightsAt(1));
 
  854   this->ReadMatrixXML(parent, 
"ResetBiases", this->GetBiasesAt(0));
 
  855   this->ReadMatrixXML(parent, 
"UpdateWeights", this->GetWeightsAt(2));
 
  856   this->ReadMatrixXML(parent, 
"UpdateStateWeights", this->GetWeightsAt(3));
 
  857   this->ReadMatrixXML(parent, 
"UpdateBiases", this->GetBiasesAt(1));
 
  858   this->ReadMatrixXML(parent, 
"CandidateWeights", this->GetWeightsAt(4));
 
  859   this->ReadMatrixXML(parent, 
"CandidateStateWeights", this->GetWeightsAt(5));
 
  860   this->ReadMatrixXML(parent, 
"CandidateBiases", this->GetBiasesAt(2));
 
 
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
 
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
 
const Matrix_t & GetWeightsCandidate() const
 
Matrix_t & GetWeightsCandidateStateGradients()
 
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
 
void Forward(Tensor_t &input, bool isTraining=true) override
Computes the next hidden state and next cell state with given input matrix.
 
Matrix_t & GetWeightsResetGate()
 
Matrix_t & fResetBiasGradients
Gradients w.r.t the reset gate - bias weights.
 
std::vector< Matrix_t > & GetUpdateGateTensor()
 
typename Architecture_t::Tensor_t Tensor_t
 
std::vector< Matrix_t > reset_gate_value
Reset gate value for every time step.
 
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &reset_gate, const Matrix_t &update_gate, const Matrix_t &candidate_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dr, Matrix_t &du, Matrix_t &dc)
Backward pass for a single time unit, matching the corresponding call to Forward(...).
 
size_t fStateSize
Hidden state size for GRU.
 
const Matrix_t & GetWeightsResetGradients() const
 
const Matrix_t & GetUpdateBiasGradients() const
 
bool fReturnSequence
Return in output full sequence or just last element.
 
const Matrix_t & GetWeightsResetStateGradients() const
 
std::vector< Matrix_t > fDerivativesReset
First fDerivatives of the activations reset gate.
 
const Tensor_t & GetWeightsTensor() const
 
std::vector< Matrix_t > & GetResetGateTensor()
 
Matrix_t & GetWeightsUpdateGateState()
 
const std::vector< Matrix_t > & GetCandidateGateTensor() const
 
const Matrix_t & GetUpdateDerivativesAt(size_t i) const
 
Matrix_t & GetWeightsUpdateStateGradients()
 
void Print() const override
Prints the info about the layer.
 
size_t GetInputSize() const
Getters.
 
Matrix_t fState
Hidden state of GRU.
 
Matrix_t & GetWeightsResetGradients()
 
Tensor_t & GetWeightGradientsTensor()
 
const Matrix_t & GetCandidateBias() const
 
std::vector< Matrix_t > update_gate_value
Update gate value for every time step.
 
Tensor_t & GetWeightsTensor()
 
Tensor_t fX
cached input tensor as T x B x I
 
Matrix_t & GetCandidateGateTensorAt(size_t i)
 
Matrix_t & GetResetBiasGradients()
 
Matrix_t & GetCandidateValue()
 
void AddWeightsXMLTo(void *parent) override
Writes the information and the weights about the layer in an XML node.
 
Matrix_t & GetWeightsResetGateState()
 
DNN::EActivationFunction fF1
Activation function: sigmoid.
 
const Matrix_t & GetWeightsUpdateGate() const
 
const std::vector< Matrix_t > & GetDerivativesReset() const
 
const Matrix_t & GetUpdateGateBias() const
 
Matrix_t & fWeightsResetGradients
Gradients w.r.t the reset gate - input weights.
 
std::vector< Matrix_t > & GetDerivativesUpdate()
 
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
 
Matrix_t & fCandidateBias
Candidate Gate bias.
 
Matrix_t & GetUpdateGateTensorAt(size_t i)
 
DNN::EActivationFunction fF2
Activation function: tanh.
 
const Matrix_t & GetWeightsUpdateGradients() const
 
Matrix_t & GetWeightsCandidateGradients()
 
Matrix_t & fWeightsUpdateStateGradients
Gradients w.r.t the update gate - hidden state weights.
 
Matrix_t & GetWeightsCandidate()
 
Matrix_t & fWeightsUpdateGradients
Gradients w.r.t the update gate - input weights.
 
Matrix_t & GetUpdateGateValue()
 
size_t fTimeSteps
Timesteps for GRU.
 
std::vector< Matrix_t > fDerivativesCandidate
First fDerivatives of the activations candidate gate.
 
const Tensor_t & GetWeightGradientsTensor() const
 
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
 
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
 
Matrix_t & fUpdateBiasGradients
Gradients w.r.t the update gate - bias weights.
 
Matrix_t & GetWeightsResetStateGradients()
 
std::vector< Matrix_t > & GetCandidateGateTensor()
 
Matrix_t & fWeightsResetGate
Reset Gate weights for input, fWeights[0].
 
const Matrix_t & GetResetDerivativesAt(size_t i) const
 
Matrix_t & GetWeightsUpdateGate()
 
typename Architecture_t::Matrix_t Matrix_t
 
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
 
Matrix_t & GetWeightsCandidateState()
 
const Matrix_t & GetCandidateBiasGradients() const
 
Matrix_t & GetResetGateTensorAt(size_t i)
 
Matrix_t & fResetGateBias
Reset Gate bias.
 
const std::vector< Matrix_t > & GetResetGateTensor() const
 
Matrix_t fCell
Empty matrix for GRU.
 
std::vector< Matrix_t > candidate_gate_value
Candidate gate value for every time step.
 
typename Architecture_t::Scalar_t Scalar_t
 
const Matrix_t & GetWeigthsUpdateStateGradients() const
 
const Matrix_t & GetCandidateValue() const
 
Matrix_t & GetCandidateBiasGradients()
 
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
 
const std::vector< Matrix_t > & GetDerivativesUpdate() const
 
const Matrix_t & GetCell() const
 
void Initialize() override
Initialize the weights according to the given initialization method.
 
void UpdateGate(const Matrix_t &input, Matrix_t &df)
Computes the update gate values, which control how much of the past state is kept (NN with Sigmoid)
 
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
 
Matrix_t fResetValue
Computed reset gate values.
 
DNN::EActivationFunction GetActivationFunctionF2() const
 
Matrix_t & GetResetGateBias()
 
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
 
Matrix_t fUpdateValue
Computed update gate values.
 
const Matrix_t & GetResetBiasGradients() const
 
bool fResetGateAfter
GRU variant to Apply the reset gate multiplication afterwards (used by cuDNN)
 
const Matrix_t & GetWeightsCandidateGradients() const
 
DNN::EActivationFunction GetActivationFunctionF1() const
 
bool DoesReturnSequence() const
 
Matrix_t & GetUpdateBiasGradients()
 
const Matrix_t & GetUpdateGateTensorAt(size_t i) const
 
Matrix_t & fWeightsResetGateState
Reset Gate weights for prev state, fWeights[3].
 
Matrix_t & fWeightsUpdateGateState
Update Gate weights for prev state, fWeights[4].
 
const std::vector< Matrix_t > & GetDerivativesCandidate() const
 
Tensor_t fWeightsTensor
Tensor for all weights.
 
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
 
const Matrix_t & GetResetGateBias() const
 
Matrix_t & GetResetDerivativesAt(size_t i)
 
const Matrix_t & GetUpdateGateValue() const
 
const Matrix_t & GetResetGateTensorAt(size_t i) const
 
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
 
void CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues)
Forward for a single cell (time unit)
 
Matrix_t & GetWeightsUpdateGradients()
 
Matrix_t & fWeightsResetStateGradients
Gradients w.r.t the reset gate - hidden state weights.
 
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[5].
 
size_t GetStateSize() const
 
std::vector< Matrix_t > & GetDerivativesReset()
 
Matrix_t & fUpdateGateBias
Update Gate bias.
 
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward) override
Backpropagates the error.
 
const Matrix_t & GetWeightsCandidateStateGradients() const
 
void ResetGate(const Matrix_t &input, Matrix_t &di)
Computes the reset gate values, which control how much of the past state enters the candidate (NN with Sigmoid)
 
const Matrix_t & GetWeightsResetGate() const
 
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
 
Matrix_t & GetCandidateBias()
 
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
 
bool fRememberState
Remember state in next pass.
 
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[2].
 
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
 
const Matrix_t & GetWeightsCandidateState() const
 
void ReadWeightsFromXML(void *parent) override
Read the information and the weights about the layer from XML node.
 
const std::vector< Matrix_t > & GetUpdateGateTensor() const
 
const Matrix_t & GetResetGateValue() const
 
void Update(const Scalar_t learningRate)
 
Tensor_t fY
cached output tensor as T x B x S
 
Matrix_t fCandidateValue
Computed candidate values.
 
const Matrix_t & GetState() const
 
Matrix_t & GetUpdateGateBias()
 
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initialize the hidden state and cell state method.
 
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
 
Matrix_t & GetCandidateDerivativesAt(size_t i)
 
std::vector< Matrix_t > fDerivativesUpdate
First fDerivatives of the activations update gate.
 
size_t GetTimeSteps() const
 
const Matrix_t & GetWeightsUpdateGateState() const
 
std::vector< Matrix_t > & GetDerivativesCandidate()
 
bool DoesRememberState() const
 
const Matrix_t & GetWeightsResetGateState() const
 
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh)
 
TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, bool resetGateAfter=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
 
Matrix_t & GetUpdateDerivativesAt(size_t i)
 
Matrix_t & fWeightsUpdateGate
Update Gate weights for input, fWeights[1].
 
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
 
Matrix_t & GetResetGateValue()
 
Generic General Layer class.
 
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
 
size_t GetInputWidth() const
 
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
 
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
 
EActivationFunction
Enum that represents layer activation functions.
 
create variable transformations