#ifndef TMVA_DNN_RNN_LAYER
#define TMVA_DNN_RNN_LAYER

#include <cmath>
#include <iostream>
#include <vector>
#include <cassert>

#include "TMVA/DNN/Functions.h"
#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/Tools.h"

namespace TMVA {
namespace DNN {
namespace RNN {

//______________________________________________________________________________
//
// Basic RNN Layer
//______________________________________________________________________________

template <typename Architecture_t>
class TBasicRNNLayer : public VGeneralLayer<Architecture_t> {

public:
   using Tensor_t = typename Architecture_t::Tensor_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;

   /** Constructor */
   TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
                  bool rememberState = false, bool returnSequence = false,
                  DNN::EActivationFunction f = DNN::EActivationFunction::kTanh,
                  bool training = true, DNN::EInitialization fA = DNN::EInitialization::kZero);

   /** Copy constructor */
   TBasicRNNLayer(const TBasicRNNLayer &layer);

   /** Destructor */
   virtual ~TBasicRNNLayer();

   /** Initialize the weights according to the given initialization method. */
   virtual void Initialize();

   /** Initialize the state method. */
   void InitState(DNN::EInitialization m = DNN::EInitialization::kZero);

   /** Compute and return the next state with given input matrix. */
   void Forward(Tensor_t &input, bool isTraining = true);

   /** Forward for a single cell (time unit). */
   void CellForward(const Matrix_t &input, Matrix_t &dF);

   /** Backpropagates the error. */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /** Backward for a single time unit at the corresponding call to Forward(...). */
   Matrix_t &CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations,
                          const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dF);

   /** Prints the info about the layer. */
   void Print() const;

   /** Writes the information and the weights about the layer in an XML node. */
   virtual void AddWeightsXMLTo(void *parent);

   /** Read the information and the weights about the layer from XML node. */
   virtual void ReadWeightsFromXML(void *parent);

   // ... getters, setters, additional type aliases (RNNDescriptors_t,
   // RNNWorkspace_t, ...) and the data members (fState, fWeightsInput,
   // fWeightsState, fBiases, their gradient counterparts, the cached cuDNN
   // tensors fX, fY, fDx, fDy, and the descriptor/workspace pointers) are
   // elided in this excerpt.
};
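// Usage sketch (illustrative, not part of the original header): how the layer
// is built and run standalone, assuming the reference CPU backend
// TMVA::DNN::TCpu<float> from "TMVA/DNN/Architectures/Cpu.h" and a
// shape-taking Tensor_t constructor. The input is a B x T x I tensor; the
// output is B x T x S when returnSequence is true, otherwise B x 1 x S.
//
//    using Arch_t = TMVA::DNN::TCpu<float>;
//    TMVA::DNN::RNN::TBasicRNNLayer<Arch_t> rnn(/*batchSize=*/16, /*stateSize=*/10,
//                                               /*inputSize=*/8, /*timeSteps=*/5);
//    rnn.Initialize();
//    Arch_t::Tensor_t x({16, 5, 8});          // B x T x I
//    rnn.Forward(x, /*isTraining=*/true);
//    const auto &out = rnn.GetOutput();       // B x 1 x S (returnSequence = false)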
//______________________________________________________________________________
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
                                               bool rememberState, bool returnSequence, DNN::EActivationFunction f,
                                               bool /*training*/, DNN::EInitialization fA)
   : VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1,
                                   stateSize, 2, {stateSize, stateSize}, {inputSize, stateSize}, 1, {stateSize},
                                   {1}, batchSize, (returnSequence) ? timeSteps : 1, stateSize, fA),
     fTimeSteps(timeSteps), fStateSize(stateSize), fRememberState(rememberState), fReturnSequence(returnSequence),
     fF(f), fState(batchSize, stateSize), fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)),
     fBiases(this->GetBiasesAt(0)), fDerivatives(timeSteps, batchSize, stateSize),
     fWeightInputGradients(this->GetWeightGradientsAt(0)), fWeightStateGradients(this->GetWeightGradientsAt(1)),
     fBiasGradients(this->GetBiasGradientsAt(0)), fWeightsTensor({0}), fWeightGradientsTensor({0})
{
}
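// Note on the base-class arguments: the VGeneralLayer is configured with two
// weight matrices and one bias vector. Weight 0 is the input weight matrix W_x
// of shape stateSize x inputSize, weight 1 is the recurrent weight matrix W_s
// of shape stateSize x stateSize, and the bias has stateSize rows;
// fWeightsInput, fWeightsState and fBiases are references into that
// base-class storage.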
//______________________________________________________________________________
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(const TBasicRNNLayer &layer)
   : VGeneralLayer<Architecture_t>(layer), fTimeSteps(layer.fTimeSteps), fStateSize(layer.fStateSize),
     fRememberState(layer.fRememberState), fReturnSequence(layer.fReturnSequence), fF(layer.GetActivationFunction()),
     fState(layer.GetBatchSize(), layer.GetStateSize()),
     fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)), fBiases(this->GetBiasesAt(0)),
     fDerivatives(layer.GetDerivatives().GetShape()), fWeightInputGradients(this->GetWeightGradientsAt(0)),
     fWeightStateGradients(this->GetWeightGradientsAt(1)), fBiasGradients(this->GetBiasGradientsAt(0)),
     fWeightsTensor({0}), fWeightGradientsTensor({0})
{
   // copy the derivatives and the hidden state
   Architecture_t::Copy(fDerivatives, layer.GetDerivatives());
   Architecture_t::Copy(fState, layer.GetState());
}
//______________________________________________________________________________
template <typename Architecture_t>
TBasicRNNLayer<Architecture_t>::~TBasicRNNLayer()
{
   // release the RNN descriptors and workspace, if any were allocated
   if (fDescriptors) {
      Architecture_t::ReleaseRNNDescriptors(fDescriptors);
      delete fDescriptors;
   }
   if (fWorkspace) {
      Architecture_t::FreeRNNWorkspace(fWorkspace);
      delete fWorkspace;
   }
}
//______________________________________________________________________________
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::Initialize()
{
   VGeneralLayer<Architecture_t>::Initialize();

   Architecture_t::InitializeRNNDescriptors(fDescriptors, this);
   Architecture_t::InitializeRNNWorkspace(fWorkspace, fDescriptors, this);
}
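// The RNN descriptors and workspace are backend bookkeeping: the
// Architecture_t::IsCudnn() branches in Forward() and Backward() below consume
// them, while CPU-style backends can treat these initialization hooks as
// no-ops.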
//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::InitState(DNN::EInitialization /* m */) -> void
{
   // reset the hidden state and (re)create the architecture-side tensors and
   // the activation-function descriptor used by CellForward/CellBackward
   DNN::initialize<Architecture_t>(fState, DNN::EInitialization::kZero);

   Architecture_t::InitializeRNNTensors(this);
   Architecture_t::InitializeActivationDescriptor(fActivationDesc, this->GetActivationFunction());
}
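// InitState is invoked from Forward() whenever fRememberState is false, so by
// default the hidden state starts from zero for every new batch; constructing
// the layer with rememberState = true carries the state across calls instead.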
//______________________________________________________________________________
template <typename Architecture_t>
auto TBasicRNNLayer<Architecture_t>::Print() const -> void
{
   // prints the layer geometry, e.g.
   //    RECURRENT Layer:   (NInput = 8, NState = 10, NTime = 5 )   Output = ( 16 , 1 , 10 )
   std::cout << " RECURRENT Layer: \t ";
   std::cout << " (NInput = " << this->GetInputSize();
   std::cout << ", NState = " << this->GetStateSize();
   std::cout << ", NTime = " << this->GetTimeSteps() << " )";
   std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize() << " , "
             << this->GetOutput().GetWSize() << " )\n";
}
//______________________________________________________________________________
template <typename Architecture_t>
auto debugMatrix(const typename Architecture_t::Matrix_t &A, const std::string name = "matrix") -> void
{
   std::cout << name << "\n";
   for (size_t i = 0; i < A.GetNrows(); ++i) {
      for (size_t j = 0; j < A.GetNcols(); ++j) {
         std::cout << A(i, j) << " ";
      }
      std::cout << "\n";
   }
   std::cout << "********\n";
}
// handy while debugging Forward/Backward, e.g.
// debugMatrix<Architecture_t>(fState, "state");
//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining) -> void
{
   // for the cuDNN backend, run the layer through the low-level RNN functions
   if (Architecture_t::IsCudnn()) {

      // the cuDNN tensors are in the T x B x I format
      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Architecture_t::Rearrange(x, input); // B x T x I --> T x B x I

      const auto &weights = this->GetWeightsTensor();

      auto &hx = this->GetState();
      auto &cx = this->GetCell();
      auto &hy = this->GetState();
      auto &cy = this->GetCell();

      auto &rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto &rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);

      if (fReturnSequence) {
         Architecture_t::Rearrange(this->GetOutput(), y); // swap B and T from y to the output
      } else {
         // keep only the last time step: reshape to B x 1 x S
         Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
         Architecture_t::Copy(this->GetOutput(), tmp);
      }
      return;
   }

   // standard (CPU) implementation: loop over the time steps
   Tensor_t arrInput(fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
   Architecture_t::Rearrange(arrInput, input); // B x T x D --> T x B x D
   Tensor_t arrOutput(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (!this->fRememberState)
      InitState(DNN::EInitialization::kZero);

   for (size_t t = 0; t < fTimeSteps; ++t) {
      Matrix_t arrInput_m = arrInput.At(t).GetMatrix();
      Matrix_t df_m = fDerivatives.At(t).GetMatrix();
      CellForward(arrInput_m, df_m);
      Matrix_t arrOutput_m = arrOutput.At(t).GetMatrix();
      Architecture_t::Copy(arrOutput_m, fState);
   }

   if (fReturnSequence) {
      Architecture_t::Rearrange(this->GetOutput(), arrOutput); // T x B x D --> B x T x D
   } else {
      // take the last time step only
      Tensor_t tmp = arrOutput.At(fTimeSteps - 1); // B x D
      tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1}); // B x D x 1
      assert(tmp.GetSize() == this->GetOutput().GetSize());
      assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]); // B is last dim in output, first in tmp
      Architecture_t::Rearrange(this->GetOutput(), tmp);
   }
}
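// The two branches differ mainly in the tensor layout they work in: cuDNN wants
// time-major T x B x I buffers (the cached fX/fY tensors), while the layer's
// public input and output stay batch-major, hence the Rearrange calls on entry
// and exit of both paths.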
//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::CellForward(const Matrix_t &input, Matrix_t &dF) -> void
{
   // update the hidden state: state = act(input . W_input^T + state . W_state^T + bias)
   const DNN::EActivationFunction fAF = this->GetActivationFunction();
   Matrix_t tmpState(fState.GetNrows(), fState.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsState);
   Architecture_t::MultiplyTranspose(fState, input, fWeightsInput);
   Architecture_t::ScaleAdd(fState, tmpState);
   Architecture_t::AddRowWise(fState, fBiases);

   // apply the activation function, keeping a copy of its input for the backward pass
   Tensor_t inputActivFunc(dF);
   Tensor_t tState(fState);
   Architecture_t::Copy(inputActivFunc, tState);
   Architecture_t::ActivationFunctionForward(tState, fAF, fActivationDesc);
}
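// In matrix form, with x_t the input batch at time t and s_{t-1} the previous
// hidden state, CellForward computes
//
//    s_t = f( x_t W_x^T + s_{t-1} W_s^T + b )
//
// where W_x = fWeightsInput (S x I), W_s = fWeightsState (S x S), b = fBiases,
// and f is the layer activation function (tanh by default).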
//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,        // B x T x D
                                                     const Tensor_t &activations_backward) -> void // B x T x D
{
   if (Architecture_t::IsCudnn()) {

      // the cuDNN tensors are in the T x B x I format
      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Tensor_t &dx = this->fDx;
      Tensor_t &dy = this->fDy;

      assert(activations_backward.GetStrides()[1] == this->GetInputSize());

      Architecture_t::Rearrange(x, activations_backward);

      if (!fReturnSequence) {
         // only the last time step carries activation gradients
         Architecture_t::InitializeZero(dy);
         Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
         Architecture_t::Copy(tmp2, this->GetActivationGradients());
      } else {
         Architecture_t::Rearrange(y, this->GetOutput());
         Architecture_t::Rearrange(dy, this->GetActivationGradients());
      }

      auto &weights = this->GetWeightsTensor();
      auto &weightGradients = this->GetWeightGradientsTensor();
      // cuDNN accumulates the weight gradients, so they must be reset first
      Architecture_t::InitializeZero(weightGradients);

      // hx is fState; reuse the same matrices for the in/out state gradients
      auto &hx = this->GetState();
      auto &cx = this->GetCell();
      auto &dhy = hx;
      auto &dcy = cx;
      auto &dhx = hx;
      auto &dcx = cx;

      auto &rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto &rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc,
                                  rnnWork);

      if (gradients_backward.GetSize() != 0)
         Architecture_t::Rearrange(gradients_backward, dx);

      return;
   }

   // standard (CPU) implementation of backpropagation through time.
   // An empty gradients_backward tensor signals that this is the first layer,
   // so no input gradient needs to be propagated further back.
   bool dummy = (gradients_backward.GetSize() == 0);

   Tensor_t arr_gradients_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());

   Tensor_t arr_activations_backward(fTimeSteps, this->GetBatchSize(), this->GetInputSize());
   Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D --> T x B x D

   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero);

   Matrix_t initState(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero);

   Tensor_t arr_output(fTimeSteps, this->GetBatchSize(), fStateSize);
   Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);

   if (fReturnSequence) {
      Architecture_t::Rearrange(arr_output, this->GetOutput());
      Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
   } else {
      // only the last time step has a gradient coming from the next layer
      Architecture_t::InitializeZero(arr_actgradients);
      Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
      assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
      assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]); // B is last dim in act. gradients

      Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
   }

   // reset the weight and bias gradients before accumulating over the time steps
   fWeightInputGradients.Zero();
   fWeightStateGradients.Zero();
   fBiasGradients.Zero();

   for (size_t t = fTimeSteps; t > 0; t--) {
      // add the gradient coming from the output at this time step
      Matrix_t actgrad_m = arr_actgradients.At(t - 1).GetMatrix();
      Architecture_t::ScaleAdd(state_gradients_backward, actgrad_m);

      Matrix_t actbw_m = arr_activations_backward.At(t - 1).GetMatrix();
      Matrix_t gradbw_m = arr_gradients_backward.At(t - 1).GetMatrix();

      // compute the activation-function derivative at this time step
      Tensor_t df = fDerivatives.At(t - 1);
      Tensor_t dy = Tensor_t(state_gradients_backward);
      Tensor_t y = arr_output.At(t - 1);
      Architecture_t::ActivationFunctionBackward(df, y, dy, df, // in place
                                                 this->GetActivationFunction(), fActivationDesc);
      Matrix_t df_m = df.GetMatrix();

      if (t > 1) {
         Matrix_t precStateActivations = arr_output.At(t - 2).GetMatrix();
         CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
      } else {
         // first time step: the previous state is the (zero) initial state
         const Matrix_t &precStateActivations = initState;
         CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
      }
   }

   if (!dummy) {
      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward); // T x B x D --> B x T x D
   }
}
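// The time loop implements backpropagation through time: walking from t = T
// down to t = 1, the running state gradient dL/ds_t picks up the output
// gradient of step t (ScaleAdd) plus the contribution propagated from step
// t+1 through the recurrence (computed by CellBackward), while the weight and
// bias gradients are accumulated across all steps.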
//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicRNNLayer<Architecture_t>::CellBackward(Matrix_t &state_gradients_backward,
                                                         const Matrix_t &precStateActivations, const Matrix_t &input,
                                                         Matrix_t &input_gradient, Matrix_t &dF) -> Matrix_t &
{
   return Architecture_t::RecurrentLayerBackward(state_gradients_backward, fWeightInputGradients,
                                                 fWeightStateGradients, fBiasGradients, dF, precStateActivations,
                                                 fWeightsInput, fWeightsState, input, input_gradient);
}
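// Writing dF for the state gradient already multiplied by the activation
// derivative, RecurrentLayerBackward effectively accumulates (a sketch of the
// usual RNN cell gradients, not a quote of the backend code):
//
//    dW_x += dF^T . x_t            (input weight gradients)
//    dW_s += dF^T . s_{t-1}        (state weight gradients)
//    db   += column-sum(dF)        (bias gradients)
//    dx_t  = dF . W_x              (input gradient)
//    ds_{t-1} = dF . W_s           (state gradient passed to the previous step)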
//______________________________________________________________________________
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "RNNLayer");

   // write the layer configuration
   gTools().xmlengine().NewAttr(layerxml, nullptr, "StateSize", gTools().StringFromInt(this->GetStateSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "InputSize", gTools().StringFromInt(this->GetInputSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "RememberState", gTools().StringFromInt(this->DoesRememberState()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "ReturnSequence", gTools().StringFromInt(this->DoesReturnSequence()));

   // write weights and bias matrices
   this->WriteMatrixToXML(layerxml, "InputWeights", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "StateWeights", this->GetWeightsAt(1));
   this->WriteMatrixToXML(layerxml, "Biases", this->GetBiasesAt(0));
}
//______________________________________________________________________________
template <typename Architecture_t>
void TBasicRNNLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   this->ReadMatrixXML(parent, "InputWeights", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "StateWeights", this->GetWeightsAt(1));
   this->ReadMatrixXML(parent, "Biases", this->GetBiasesAt(0));
}

} // namespace RNN
} // namespace DNN
} // namespace TMVA

#endif // TMVA_DNN_RNN_LAYER