Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_GRU.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_GRU
2#define TMVA_SOFIE_ROPERATOR_GRU
3
4#include "TMVA/RModel.hxx"
5#include "TMVA/ROperator.hxx"
7
8#include <memory>
9#include <sstream>
10#include <stdexcept>
11#include <string>
12#include <vector>
13
14namespace TMVA {
15namespace Experimental {
16namespace SOFIE {
17
18/*! \brief Gated Recurrent Unit operator
19 *
20 * Inference code generation for one-layer GRU. Supports forward, reverse and bidirectional GRU.
21 * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU">ONNX documentation</a>
22 * for details about the supported GRU architectures.
 *
 * Only `T = float` is accepted: the attribute constructor throws std::runtime_error
 * for any other element type.
 *
 * The member functions are only declared in this class body; the implementation is
 * included from TMVA/ROperator_GRU.icc at the end of this header.
23 */
24template <typename T> class ROperator_GRU final : public ROperator {
25 private:
26 std::vector<float> fAttrActivationAlpha; ///< Scaling values used by some activation functions
27 std::vector<float> fAttrActivationBeta; ///< Scaling values used by some activation functions
28 std::vector<std::string> fAttrActivations; ///< Activation functions
29 float fAttrClip; ///< Clip threshold
30 std::string fAttrDirection; ///< Direction of processing
31 size_t fAttrHiddenSize; ///< Size of the hidden state (ONNX `hidden_size`: number of neurons in the hidden layer)
32 size_t fAttrLayout; ///< Data layout
33 size_t fAttrLinearBeforeReset; ///< Linear layer before the reset gate
34
35 std::string fNX; ///< Name of the input
36 std::string fNW; ///< Name of the weights
37 std::string fNR; ///< Name of the recurrence
38 std::string fNB; ///< Name of the bias
39 std::string fNSequence_lens; ///< Name of the length of the sequences
40 std::string fNInitial_h; ///< Name of the initial value of the hidden states
41 std::string fNY; ///< Name of the output
42 std::string fNY_h; ///< Name of the last sequence of the output
43
44 std::vector<size_t> fShapeX; ///< Shape of the input
45 std::vector<size_t> fShapeW; ///< Shape of the weights
46 std::vector<size_t> fShapeR; ///< Shape of the recurrence
47 std::vector<size_t> fShapeB; ///< Shape of the bias
48 std::vector<size_t> fShapeSequence_lens; ///< Shape of the length of the sequences
49 std::vector<size_t> fShapeInitial_h; ///< Shape of the initial value of hidden states
50 std::vector<size_t> fShapeY; ///< Shape of the output
51 std::vector<size_t> fShapeY_h; ///< Shape of the last sequence of the output
52
53 std::string fType; ///< Type of the tensors
54
55 public:
56 /*! Default constructor of ROperator_GRU */
58
59 /*! \brief Constructor of ROperator_GRU from the attributes
60 *
 * All tensor names are passed through UTILITY::Clean_name before being stored.
 * Optional tensors (bias, sequence lengths, initial hidden state, both outputs)
 * are identified by a non-empty name.
 *
61 * \param activation_alpha scaling values used by some activation functions
62 * \param activation_beta scaling values used by some activation functions
63 * \param activations activation functions
64 * \param clip clip threshold
65 * \param direction direction of processing of the sequences
66 * \param hidden_size size of the hidden state (number of neurons in the hidden layer)
67 * \param layout data layout
68 * \param linear_before_reset Linear layer before the reset gate
69 * \param nameX name of the input tensor
70 * \param nameW name of the weight tensor
71 * \param nameR name of the recurrence tensor
72 * \param nameB name of the bias tensor
73 * \param nameSequence_lens name of the length of the sequences
74 * \param nameInitial_h name of the initial value of the hidden states
75 * \param nameY name of the output
76 * \param nameY_h name of the last sequence of the output
 *
 * \throws std::runtime_error if T is not float
77 */
78 ROperator_GRU(std::vector<float> activation_alpha,
79 std::vector<float> activation_beta,
80 std::vector<std::string> activations, float clip,
81 std::string direction, size_t hidden_size,
82 size_t layout, size_t linear_before_reset,
83 std::string nameX, std::string nameW, std::string nameR,
84 std::string nameB, std::string nameSequence_lens,
85 std::string nameInitial_h, std::string nameY, std::string nameY_h)
// NOTE(review): listing lines 86-89 are not visible here; presumably they hold the ':' and the
// fAttr* member initializers that start this initializer list. Confirm against the original header.
90 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
91 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
92 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
93 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
94 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)) {
95
// NOTE(review): the mandatory inputs (fNX, fNW, fNR) are not registered in fInputTensorNames
// anywhere in the visible lines; presumably that happens on the missing listing line 96.
// Optional inputs are registered only when a name was supplied.
97 if (!fNB.empty()){
98 fInputTensorNames.emplace_back(fNB);
99 }
100 if (!fNSequence_lens.empty()){
// NOTE(review): this branch is empty in the listing (line 101 is missing); presumably it
// appends fNSequence_lens to fInputTensorNames like the neighbouring branches. Confirm.
102 }
103 if (!fNInitial_h.empty()){
104 fInputTensorNames.emplace_back(fNInitial_h);
105 }
106
// Both outputs are optional: start from an empty list and add only the named ones.
107 fOutputTensorNames = { };
108 if (!fNY.empty()){
109 fOutputTensorNames.emplace_back(fNY);
110 }
111 if (!fNY_h.empty()){
112 fOutputTensorNames.emplace_back(fNY_h);
113 }
114
// Only float tensors are supported; any other T is rejected when the operator is constructed.
// NOTE(review): std::is_same needs <type_traits>, which this header does not include directly.
115 if (std::is_same<T, float>::value) {
116 fType = "float";
117 } else {
118 throw std::runtime_error(
119 "TMVA SOFIE Encountered unsupported type parsing a GRU operator");
120 }
121 }
122
123 /*! \brief Infers the type of the output tensors
124 *
125 * \param input type of the input tensors
126 */
127 std::vector<ETensorType> TypeInference(std::vector<ETensorType> /*input*/);
128
129 /*! \brief Infers the shape of the output tensors
130 *
131 * \param input shape of the input tensors
132 */
133 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> /*input*/);
134
135 /*! \brief Initialize the model
136 *
137 * \param model Model
138 */
139 void Initialize(RModel &);
140
141 /*! \brief Generate the inference code
142 *
143 * \param OpName name of the operator
144 */
145 std::string Generate(std::string /*OpName*/);
146
147 /*! \brief Generate the code for the Session internal data vectors
148 *
149 * \param opName name of the operator
150 */
151 std::string GenerateSessionMembersCode(std::string opName);
152
153 /*! \brief Returns the blas routines needed to compile the generated code
 *
 * \return the routine names "Gemm" and "Axpy"
154 */
155 std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
156};
157
158} // namespace SOFIE
159} // namespace Experimental
160} // namespace TMVA
161
162// Implementation of the ROperator_GRU class
163#include "TMVA/ROperator_GRU.icc"
164
165#endif
Gated Recurrent Unit operator.
std::string GenerateSessionMembersCode(std::string opName)
Generate the code for the Session internal data vectors.
std::vector< size_t > fShapeY
Shape of the output.
std::string fNX
Name of the input.
std::string fType
Type of the tensors.
std::string fAttrDirection
Direction of processing.
std::string fNR
Name of the recurrence.
std::vector< float > fAttrActivationBeta
Scaling values used by some activation functions.
std::vector< std::string > GetBlasRoutines()
Returns the blas routines needed to compile the generated code.
std::string fNY
Name of the output.
std::string fNY_h
Name of the last sequence of the output.
std::string fNSequence_lens
Name of the length of the sequences.
std::vector< std::string > fAttrActivations
Activation functions.
ROperator_GRU(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t layout, size_t linear_before_reset, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameY, std::string nameY_h)
Constructor of ROperator_GRU from the attributes.
size_t fAttrHiddenSize
Size of the hidden state (number of neurons in the hidden layer).
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >)
Infers the shape of the output tensors.
std::vector< float > fAttrActivationAlpha
Scaling values used by some activation functions.
std::vector< size_t > fShapeR
Shape of the recurrence.
void Initialize(RModel &)
Initialize the model.
std::string fNW
Name of the weights.
std::vector< size_t > fShapeX
Shape of the input.
std::string Generate(std::string)
Generate the inference code.
std::vector< size_t > fShapeInitial_h
Shape of the initial value of hidden states.
std::vector< size_t > fShapeSequence_lens
Shape of the length of the sequences.
std::vector< size_t > fShapeY_h
Shape of the last sequence of the output.
size_t fAttrLinearBeforeReset
Linear layer before the reset gate.
std::vector< size_t > fShapeB
Shape of the bias.
std::vector< ETensorType > TypeInference(std::vector< ETensorType >)
Infers the type of the output tensors.
std::string fNInitial_h
Name of the initial value of the hidden states.
std::vector< size_t > fShapeW
Shape of the weights.
ROperator_GRU()
Default constructor of ROperator_GRU.
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:46
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:47
create variable transformations