Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_GRU.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_GRU
2#define TMVA_SOFIE_ROPERATOR_GRU
3
4#include "TMVA/RModel.hxx"
5#include "TMVA/ROperator.hxx"
7
8#include <memory>
9#include <sstream>
10#include <stdexcept>
11#include <string>
12#include <vector>
13
15
16/*! \brief Gated Recurrent Unit operator
17 *
18 * Inference code generation for one-layer GRU. Supports forward, reverse and bidirectional GRU.
19 * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU">ONNX documentation</a>
20 * for details about the supported GRU architectures.
21 */
22template <typename T> class ROperator_GRU final : public ROperator {
23 private:
 // Operator attributes; they are validated (and partly defaulted) in Initialize.
24 std::vector<float> fAttrActivationAlpha; ///< Scaling values used by some activation functions
25 std::vector<float> fAttrActivationBeta; ///< Scaling values used by some activation functions
26 std::vector<std::string> fAttrActivations; ///< Names of the activation functions; defaulted in Initialize when empty
27 float fAttrClip; ///< Clip threshold; clipping code is only generated when it is > 0
28 std::string fAttrDirection; ///< Direction of processing: "forward", "backward" ("reverse" is mapped to it) or "bidirectional"
29 size_t fAttrHiddenSize; ///< Hidden size; Initialize requires 3 * fAttrHiddenSize == fShapeW[1]
30 size_t fAttrLayout; ///< Data layout: 0 (timewise) or 1 (batchwise)
31 size_t fAttrLinearBeforeReset; ///< Linear layer before the reset gate; must be 0 or 1
32
 // Tensor names, stored after being passed through UTILITY::Clean_name.
 // An empty name means the corresponding optional tensor is not used.
33 std::string fNX; ///< Name of the input
34 std::string fNW; ///< Name of the weights
35 std::string fNR; ///< Name of the recurrence
36 std::string fNB; ///< Name of the bias
37 std::string fNSequence_lens; ///< Name of the length of the sequences
38 std::string fNInitial_h; ///< Name of the initial value of the hidden states
39 std::string fNY; ///< Name of the output
40 std::string fNY_h; ///< Name of the last sequence of the output
41
 // Tensor shapes, filled in by Initialize from the model.
42 std::vector<size_t> fShapeX; ///< Shape of the input
43 std::vector<size_t> fShapeW; ///< Shape of the weights
44 std::vector<size_t> fShapeR; ///< Shape of the recurrence
45 std::vector<size_t> fShapeB; ///< Shape of the bias
46 std::vector<size_t> fShapeSequence_lens; ///< Shape of the length of the sequences
47 std::vector<size_t> fShapeInitial_h; ///< Shape of the initial value of hidden states
48 std::vector<size_t> fShapeY; ///< Shape of the output
49 std::vector<size_t> fShapeY_h; ///< Shape of the last sequence of the output
50
51 std::string fType; ///< Type of the tensors; set to "float" by the constructor
52
53 public:
54 /*! Default constructor of ROperator_GRU */
56
57 /*! \brief Constructor of ROperator_GRU from the attributes
58 *
 * Every tensor name is passed through UTILITY::Clean_name before it is stored.
 * The optional names (bias, initial hidden state, Y, Y_h) are only registered as
 * input/output tensor names when they are not empty.
 *
59 * \param activation_alpha scaling values used by some activation functions
60 * \param activation_beta scaling values used by some activation functions
61 * \param activations activation functions
62 * \param clip clip threshold
63 * \param direction direction of processing of the sequences
64 * \param hidden_size hidden size (number of neurons in the hidden layer)
65 * \param layout data layout
66 * \param linear_before_reset Linear layer before the reset gate
67 * \param nameX name of the input tensor
68 * \param nameW name of the weight tensor
69 * \param nameR name of the recurrence tensor
70 * \param nameB name of the bias tensor
71 * \param nameSequence_lens name of the length of the sequences
72 * \param nameInitial_h name of the initial value of the hidden states
73 * \param nameY name of the output
74 * \param nameY_h name of the last sequence of the output
 * \throws std::runtime_error if T is not float
75 */
76 ROperator_GRU(std::vector<float> activation_alpha,
77 std::vector<float> activation_beta,
78 std::vector<std::string> activations, float clip,
79 std::string direction, size_t hidden_size,
80 size_t layout, size_t linear_before_reset,
81 std::string nameX, std::string nameW, std::string nameR,
82 std::string nameB, std::string nameSequence_lens,
83 std::string nameInitial_h, std::string nameY, std::string nameY_h)
 // NOTE(review): the beginning of the member-initializer list (the attribute members)
 // is not visible in this listing; only the tensor-name initializers follow.
88 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
89 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
90 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
91 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
92 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)) {
93
 // Register the optional inputs only when a name was supplied.
95 if (!fNB.empty()){
96 fInputTensorNames.emplace_back(fNB);
97 }
 // NOTE(review): no statement is visible inside the sequence_lens branch in this listing.
98 if (!fNSequence_lens.empty()){
100 }
101 if (!fNInitial_h.empty()){
102 fInputTensorNames.emplace_back(fNInitial_h);
103 }
104
 // Both outputs are optional; start from an empty list and add the named ones.
105 fOutputTensorNames = { };
106 if (!fNY.empty()){
107 fOutputTensorNames.emplace_back(fNY);
108 }
109 if (!fNY_h.empty()){
110 fOutputTensorNames.emplace_back(fNY_h);
111 }
112
 // Only T = float is supported; any other instantiation is rejected here.
113 if (std::is_same<T, float>::value) {
114 fType = "float";
115 } else {
116 throw std::runtime_error(
117 "TMVA SOFIE Encountered unsupported type parsing a GRU operator");
118 }
119 }
120
121 /*! \brief Infers the type of the output tensors
122 *
123 * \param input type of the input tensors
124 */
125 std::vector<ETensorType> TypeInference(std::vector<ETensorType> /*input*/) override;
126
127 /*! \brief Infers the shape of the output tensors
128 *
129 * \param input shape of the input tensors
130 */
131 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> /*input*/) override;
132
133 /*! \brief Initialize the model
134 *
135 * \param model Model
136 */
137 void Initialize(RModel &) override;
138
139 /*! \brief Generate the inference code
140 *
141 * \param OpName name of the operator
142 */
143 std::string Generate(std::string /*OpName*/) override;
144
145 /*! \brief Returns the blas routines needed to compile the generated code
 *
 * \return the routine names "Gemm" and "Axpy"
146 */
147 std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
148};
149
150template <typename T>
151auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input) -> std::vector<ETensorType>
152{
153 ETensorType out = input[0];
154 return {out, out};
155}
156
// Infers the shapes of the output tensors from the input shapes.
// Initialize calls this with {fShapeX, fShapeW}, i.e. input[0] is the shape of X and
// input[1] the shape of W, and uses result [0] for Y and result [1] for Y_h.
// No bounds checks are performed on `input`.
157template <typename T>
158auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input) -> std::vector<std::vector<size_t>>
159{
 // The second dimension of input[1] holds three gate blocks, hence the division by 3.
160 size_t num_directions = input[1][0];
161 size_t hidden_size = input[1][1] / 3;
162 if (fAttrLayout == 0) {
 // Layout 0: input[0] starts with {seq_length, batch_size}.
163 size_t seq_length = input[0][0];
164 size_t batch_size = input[0][1];
 // NOTE(review): the constructor arguments of `ret` are not visible in this listing.
165 std::vector<std::vector<size_t>> ret(
167 return ret;
168 } else {
 // Other layouts: input[0] starts with {batch_size, seq_length}.
169 size_t batch_size = input[0][0];
170 size_t seq_length = input[0][1];
 // NOTE(review): the constructor arguments of `ret` are not visible in this listing.
171 std::vector<std::vector<size_t>> ret(
173 return ret;
174 }
175}
176
// Initialize: validates the tensors and attributes of this GRU operator against `model`
// and registers the tensors needed later by the generated code.
//  - X, W, R must exist in the model and be rank 3; the optional B, sequence_lens and
//    initial_h tensors are checked only when their name is non-empty.
//  - A rank-2 bias is expanded into a buffer of shape
//    {num_directions, 6, seq_length, batch_size, fAttrHiddenSize}.
//  - Y and Y_h are added as intermediate tensors when not already present.
//  - Attributes are validated and default activations are filled in.
//  - Scratch vectors named "op_gru_<fNX>_<name>" are declared as intermediate tensors.
// Throws std::runtime_error on any failed check.
177template <typename T>
// NOTE(review): the line carrying the function signature is not visible in this listing;
// the class declares it as `void Initialize(RModel &) override` and the body refers to
// the parameter as `model`.
179{
180
181 fUseSession = model.UseSession();
182 // Check the input and output tensors
183 if (!model.CheckIfTensorAlreadyExist(fNX)) {
184 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
185 }
186 fShapeX = model.GetTensorShape(fNX);
187 if (fShapeX.size() != 3) {
188 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
189 }
190 if (!model.CheckIfTensorAlreadyExist(fNW)) {
191 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
192 }
193 fShapeW = model.GetTensorShape(fNW);
194 if (fShapeW.size() != 3) {
195 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
196 }
197 if (!model.CheckIfTensorAlreadyExist(fNR)) {
198 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
199 }
200 fShapeR = model.GetTensorShape(fNR);
201 if (fShapeR.size() != 3) {
202 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
203 }
 // Optional bias: accepted with rank 2 (to be broadcast below) or rank 4.
204 if (!fNB.empty()) {
205 if (!model.CheckIfTensorAlreadyExist(fNB)) {
206 throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not found in model.");
207 }
208 fShapeB = model.GetTensorShape(fNB);
209 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
210 throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not of 2 or 4 dimensions.");
211 }
212 if (fShapeB.size() == 2) {
213 // Broadcasting the bias
 // Each of the 6 slices of fAttrHiddenSize values per direction is copied once for
 // every (seq, batch) pair, so the bias can later be added with a single axpy over
 // seq_length * batch_size * fAttrHiddenSize elements.
214 auto original_data = model.GetInitializedTensorData(fNB);
215 size_t num_directions = fShapeW[0];
216 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
217 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
218 if (fType == "float") {
219 float *original_bias = static_cast<float *>(original_data.get());
 // Raw allocation; ownership is handed to the shared_ptr created below.
220 float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
221 for (size_t direction = 0; direction < num_directions; direction++) {
222 for (size_t i = 0; i < 6; i++) {
223 for (size_t seq = 0; seq < seq_length; seq++) {
224 for (size_t batch = 0; batch < batch_size; batch++) {
 // bias_offset: start of slice i of this direction in the original bias.
 // offset: destination of that slice for (direction, i, seq, batch).
 // NOTE(review): the leading '+' in front of `seq` is a stray unary plus; harmless.
225 size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
226 size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
227 i * batch_size * seq_length * fAttrHiddenSize +
228 +seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
229 std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
230 new_bias + offset);
231 }
232 }
233 }
234 }
235
 // NOTE(review): new_bias_shape and new_bias_ptr are not used on any line visible in
 // this listing; presumably the model is updated with them before the shape is
 // re-read below - confirm against the original file.
236 std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
237 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
239 fShapeB = model.GetTensorShape(fNB);
240 }
241 }
242 }
 // Optional sequence lengths: must be rank 1.
243 if (!fNSequence_lens.empty()) {
244 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
 // NOTE(review): this message lacks a space before "is not found", unlike the others.
245 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNSequence_lens + "is not found in model.");
246 }
247 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
248 if (fShapeSequence_lens.size() != 1) {
249 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNSequence_lens + " is not of 1 dimension.");
250 }
251 }
 // Optional initial hidden state: must be rank 3.
252 if (!fNInitial_h.empty()) {
253 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
254 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNInitial_h + " is not found in model.");
255 }
256 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
257 if (fShapeInitial_h.size() != 3) {
258 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNInitial_h + " is not of 3 dimensions.");
259 }
260 }
 // Outputs: infer the shape from X and W and register the tensor with the type of X
 // unless the model already knows it.
261 if (!fNY.empty()) {
262 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
263 if (!model.CheckIfTensorAlreadyExist(fNY)) {
264 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
265 }
266 }
267 if (!fNY_h.empty()) {
268 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
269 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
270 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
271 }
272 }
273 // Check the attributes
 // Only explicitly supplied activations are checked here; the defaults are filled in
 // further below, after this loop.
274 for (auto &activation : fAttrActivations) {
275 if (activation != "Relu" && activation != "Tanh" && activation != "Sigmoid" && activation != "Affine" &&
276 activation != "LeakyRelu" && activation != "ThresholdRelu" && activation != "ScaledTanh" &&
277 activation != "HardSigmoid" && activation != "Elu" && activation != "Softsign" && activation != "Softplus") {
278 throw std::runtime_error("TMVA SOFIE - Activation function " + activation + " not implemented");
279 }
280 }
 // "reverse" is normalised to "backward", the spelling tested by Generate.
281 if (fAttrDirection == "reverse")
282 fAttrDirection = "backward";
 // NOTE(review): after the normalisation above the `!= "reverse"` test is always true
 // and therefore redundant.
283 if (fAttrDirection != "forward" && fAttrDirection != "backward" && fAttrDirection != "reverse" &&
284 fAttrDirection != "bidirectional") {
285 throw std::runtime_error("TMVA SOFIE - Invalid GRU direction fAttrDirection = " + fAttrDirection);
286 }
 // The second dimension of W must hold exactly three blocks of fAttrHiddenSize.
287 if (3 * fAttrHiddenSize != fShapeW[1]) {
288 throw std::runtime_error("TMVA SOFIE - fAttrHiddenSize must be equal to " + std::to_string(fShapeW[1] / 3));
289 }
290 if (fAttrLayout > 1) {
291 throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " + std::to_string(fAttrLayout) +
292 " must be 0 (timewise) or 1 (batchwise)");
293 }
 // NOTE(review): the message below names fAttrInputForget although the attribute being
 // checked is fAttrLinearBeforeReset.
294 if (fAttrLinearBeforeReset > 1) {
295 throw std::runtime_error("TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrLinearBeforeReset) +
296 " must be 0 or 1.");
297 }
 // Default activations: one {"Sigmoid", "Tanh"} pair per direction.
298 if (fAttrActivations.empty()) {
299 if (fAttrDirection == "bidirectional") {
300 fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
301 } else {
302 fAttrActivations = {"Sigmoid", "Tanh"};
303 }
304 }
305
306 // To get unique intermediate tensor names, we add the name of the input
307 // tensor. One might also consider using the index of the operator in the
308 // RModel, but this information is not available in the current scope.
309 std::string opName = "op_gru_" + fNX;
310
311 size_t num_directions = fShapeW[0];
312 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
313 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
314 size_t input_size = fShapeX[2];
315
 // Helper: registers a 1-D intermediate tensor of n elements named "<opName>_<name>"
 // with the operator's element type.
316 auto declareVector = [&](std::string const &name, std::size_t n) {
317 std::string fullName = opName + "_" + name;
318 model.AddIntermediateTensor(fullName, ConvertStringToType(fType), std::vector<std::size_t>{n});
319 };
320
 // Buffers that are only needed when the layout is not 0.
 // NOTE(review): "initial_cell_state" is not referenced in the part of Generate visible
 // in this listing - confirm whether it is needed for a GRU.
321 if (fAttrLayout != 0) {
323 declareVector("initial_hidden_state", num_directions * batch_size * fAttrHiddenSize);
324 declareVector("initial_cell_state", num_directions * batch_size * fAttrHiddenSize);
325 }
326 // Set the feedforward
327 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
328 declareVector("f_update_gate", ff_size);
329 declareVector("f_reset_gate", ff_size);
330 declareVector("f_hidden_gate", ff_size);
331 // gate results
332 size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
333 declareVector("update_gate", hs_size);
334 declareVector("reset_gate", hs_size);
335 declareVector("hidden_gate", hs_size);
336
337 // feedback
338 declareVector("feedback", batch_size * fAttrHiddenSize);
339
340 // hidden state
 // A separate buffer is only needed when Generate cannot alias the hidden state to
 // the Y tensor, i.e. when the layout is not 0 or Y is not requested.
341 if (fAttrLayout != 0 || fNY.empty()) {
342 declareVector("hidden_state", hs_size);
343 }
344}
345
346template <typename T>
347auto ROperator_GRU<T>::Generate(std::string OpName) -> std::string
348{
349 OpName = "op_" + OpName;
350 std::stringstream out;
351
352 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
353 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
354 size_t input_size = fShapeX[2];
355 size_t num_directions = fShapeW[0];
356
357 auto getVec = [&](std::string const &name) { return "tensor_op_gru_" + fNX + "_" + name; };
358
359 // set the input
360 if (fAttrLayout == 0) {
361 out << SP << fType << " const* " << OpName << "_input = tensor_" << fNX << ";\n";
362 } else {
363 if (fUseSession) {
364 out << SP << fType << " * " << OpName << "_input = " << getVec("input") << ";\n";
365 } else {
366 out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
367 }
368 out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
369 out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
370 out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
371 out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size << " + batch * " << input_size
372 << " + i] = " << "tensor_" << fNX << "[batch * " << seq_length * input_size << " + seq * " << input_size
373 << " + i];\n";
374 out << SP << SP << SP << "}\n";
375 out << SP << SP << "}\n";
376 out << SP << "}\n";
377 }
378
379 // Set the initial hidden state
380 if (!fNInitial_h.empty()) {
381 if (fAttrLayout == 0) {
382 out << SP << fType << " *" << OpName << "_initial_hidden_state = " << " tensor_" << fNInitial_h << ";\n";
383 } else {
384 if (fUseSession) {
385 out << SP << fType << " * " << OpName << "_initial_hidden_state = " << getVec("initial_hidden_state")
386 << ";\n";
387 } else {
388 out << SP << fType << " " << OpName << "_initial_hidden_state["
389 << num_directions * batch_size * fAttrHiddenSize << "];\n";
390 }
391 for (size_t direction = 0; direction < num_directions; direction++) {
392 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
393 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
394 out << SP << SP << SP << OpName << "_initial_hidden_state[" << direction * batch_size * fAttrHiddenSize
395 << " + batch * " << fAttrHiddenSize << " + h] = tensor_" << fNInitial_h << "[batch * "
396 << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << " + h];\n";
397 out << SP << SP << "}\n";
398 out << SP << "}\n";
399 }
400 }
401 }
402
403 // Set the feedforward
404 size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
405 if (fUseSession) {
406 out << SP << fType << " * " << OpName << "_f_update_gate = " << getVec("f_update_gate") << ";\n";
407 out << SP << fType << " * " << OpName << "_f_reset_gate = " << getVec("f_reset_gate") << ";\n";
408 out << SP << fType << " * " << OpName << "_f_hidden_gate = " << getVec("f_hidden_gate") << ";\n";
409 } else {
410 out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
411 out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
412 out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
413 }
414 // Set the gates
415 size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
416 if (fUseSession) {
417 out << SP << fType << " * " << OpName << "_update_gate = " << getVec("update_gate") << ";\n";
418 out << SP << fType << " * " << OpName << "_reset_gate = " << getVec("reset_gate") << ";\n";
419 out << SP << fType << " * " << OpName << "_hidden_gate = " << getVec("hidden_gate") << ";\n";
420 } else {
421 out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
422 out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
423 out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
424 }
425 // Set the hidden state
426 if (fAttrLayout == 0 && !fNY.empty()) {
427 out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
428 } else {
429 if (fUseSession) {
430 out << SP << fType << " * " << OpName << "_hidden_state = " << getVec("hidden_state") << ";\n";
431 } else {
432 out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
433 }
434 }
435
436 if (fUseSession) {
437 out << SP << fType << " * " << OpName << "_feedback = " << getVec("feedback") << ";\n";
438 } else {
439 out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
440 }
441
442 out << SP << "char " << OpName << "_transA = 'N';\n";
443 out << SP << "char " << OpName << "_transB = 'T';\n";
444 out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
445 out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
446 out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
447 out << SP << "int " << OpName << "_k = " << input_size << ";\n";
448 if (fType == "float") {
449 out << SP << "float " << OpName << "_alpha = 1.;\n";
450 out << SP << "float " << OpName << "_beta = 0.;\n";
451 }
452 if (!fNB.empty()) {
453 out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
454 }
455 out << SP << "int " << OpName << "_incx = 1;\n";
456 out << SP << "int " << OpName << "_incy = 1;\n";
457 out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";
458
459 for (size_t direction = 0; direction < num_directions; direction++) {
460 if (direction == 0) {
461 if (fType == "float") {
462 // f_update_gate = input * weight_z^T
463 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
464 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << ", &" << OpName
465 << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, " << OpName
466 << "_f_update_gate, &" << OpName << "_n);\n";
467 // f_reset_gate = input * weight_r^T
468 size_t wr_offset = fAttrHiddenSize * input_size;
469 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
470 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wr_offset
471 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
472 << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
473 // f_hidden_gate = input * weight_h^T
474 size_t wh_offset = 2 * fAttrHiddenSize * input_size;
475 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
476 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wh_offset
477 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
478 << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
479 }
480 } else {
481 if (fType == "float") {
482 // f_update_gate = input * weight_z^T
483 size_t wz_offset = 3 * fAttrHiddenSize * input_size;
484 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
485 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wz_offset
486 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
487 << OpName << "_f_update_gate, &" << OpName << "_n);\n";
488 // f_reset_gate = input * weight_r^T
489 size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
490 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
491 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wr_offset
492 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
493 << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
494 // f_hidden_gate = input * weight_h^T
495 size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
496 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
497 << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_" << fNW << " + " << wh_offset
498 << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &" << OpName << "_beta, "
499 << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
500 }
501 }
502
503 if (!fNB.empty()) {
504 if (direction == 0) {
505 if (fType == "float") {
506 // Add the bias of the weight to f_update_gate
507 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << ", &"
508 << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";
509 // Add the bias of the recurrence to f_update_gate
510 size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
511 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
512 << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName
513 << "_incy);\n";
514 // Add the bias of the weight to f_reset_gate
515 size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
516 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
517 << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &" << OpName
518 << "_incy);\n";
519 // Add the bias of the recurrence to f_reset_gate
520 // size_t rbr_offset = fAttrHiddenSize * fAttrHiddenSize + 3 * batch_size * fAttrHiddenSize;
521 size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
522 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
523 << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &" << OpName
524 << "_incy);\n";
525 // Add the bias of the weight to f_hidden_gate
526 size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
527 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
528 << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &" << OpName
529 << "_incy);\n";
530 if (fAttrLinearBeforeReset == 0) {
531 // Add the bias of the recurrence to f_hidden_gate
532 size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
533 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB
534 << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &" << OpName
535 << "_incy);\n";
536 }
537 }
538 } else {
539 if (fType == "float") {
540 // Add the bias of the weight to f_update_gate
541 size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
542 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
543 << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName
544 << "_incy);\n";
545 // Add the bias of the recurrence to f_update_gate
546 // size_t rbz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize + 3 * batch_size * fAttrHiddenSize;
547 size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
548 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
549 << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName
550 << "_incy);\n";
551 // Add the bias of the weight to f_reset_gate
552 size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
553 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
554 << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &" << OpName
555 << "_incy);\n";
556 // Add the bias of the recurrence to f_reset_gate
557 size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
558 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
559 << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &" << OpName
560 << "_incy);\n";
561 // Add the bias of the weight to f_hidden_gate
562 size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
563 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB << " + "
564 << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &" << OpName
565 << "_incy);\n";
566 if (fAttrLinearBeforeReset == 0) {
567 // Add the bias of the recurrence to f_hidden_gate
568 size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
569 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_" << fNB
570 << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &" << OpName
571 << "_incy);\n";
572 }
573 }
574 }
575 }
576
577 // Copy the feedforward into the gates
578 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
579 out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
580 if (direction == 0) {
581 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize << ";\n";
582 } else {
583 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize << " + "
584 << batch_size * fAttrHiddenSize << ";\n";
585 }
586 size_t f_seq_size = batch_size * fAttrHiddenSize;
587 out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName << "_f_update_gate + offset + "
588 << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
589 out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName << "_f_reset_gate + offset + "
590 << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
591 out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName << "_f_hidden_gate + offset + "
592 << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
593 out << SP << "}\n";
594
595 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
596 if (fAttrDirection == "backward" || direction == 1) {
597 out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
598 } else {
599 out << SP << SP << "size_t index = seq;\n";
600 }
601 out << SP << SP << "int m2 = " << batch_size << ";\n";
602 if (direction == 0) {
603 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize << ";\n";
604 } else {
605 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize << " + "
606 << batch_size * fAttrHiddenSize << ";\n";
607 }
608 size_t size = batch_size * fAttrHiddenSize;
609 // gate = gate + initial_hidden_state * Recurrence^T
610 out << SP << SP << "if (seq == 0) {\n";
611 if (!fNInitial_h.empty()) {
612 if (direction == 0) {
613 if (fType == "float") {
614 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
615 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &" << OpName
616 << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName << "_alpha, "
617 << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
618 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
619 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
620 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rr_offset
621 << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
622 << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
623 }
624 } else { // direction=1
625 if (fType == "float") {
626 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
627 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
628 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rz_offset
629 << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
630 << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
631 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
632 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
633 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rr_offset
634 << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
635 << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
636 }
637 }
638 }
639 out << SP << SP << "} else {\n";
640 // gate = gate + previous_hidden_state * Recurrence^T
641 if (direction == 0) {
642 if (fAttrDirection == "backward") {
643 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
644 << num_directions * batch_size * fAttrHiddenSize << ";\n";
645 } else {
646 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
647 << num_directions * batch_size * fAttrHiddenSize << ";\n";
648 }
649 if (fType == "float") {
650 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
651 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &" << OpName << "_n, "
652 << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &" << OpName << "_alpha, " << OpName
653 << "_update_gate + offset, &" << OpName << "_n);\n";
654 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
655 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
656 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rr_offset
657 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
658 << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
659 }
660 } else {
661 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
662 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
663 if (fType == "float") {
664 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
665 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
666 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rz_offset
667 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
668 << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
669 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
670 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
671 << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rr_offset
672 << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
673 << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
674 }
675 }
676 out << SP << SP << "}\n";
677
678 // Clip the elements of the update gate and the reset gate into the range [-fClip, fClip]
679 if (fAttrClip > .0) {
680 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
681 if (fType == "float") {
682 out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip << ") ? " << OpName
683 << "_update_gate[i] : " << -fAttrClip << ";\n";
684 }
685 out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip << ") ? z : " << fAttrClip << ";\n";
686 if (fType == "float") {
687 out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip << ") ? " << OpName
688 << "_reset_gate[i] : " << -fAttrClip << ";\n";
689 }
690 out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip << ") ? r : " << fAttrClip << ";\n";
691 out << SP << SP << "}\n";
692 }
693
694 // Apply the activation function to the update gate and the reset gate
695 if (fAttrActivations[direction * 2] == "Relu") {
696 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
697 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
698 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
699 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
700 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
701 out << SP << SP << "}\n";
702 } else if (fAttrActivations[direction * 2] == "Tanh") {
703 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
704 if (fType == "float") {
705 out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
706 }
707 out << SP << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
708 if (fType == "float") {
709 out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
710 }
711 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
712 out << SP << SP << "}\n";
713 } else if (fAttrActivations[direction * 2] == "Sigmoid") {
714 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
715 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-" << OpName
716 << "_update_gate[i]));\n";
717 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-" << OpName
718 << "_reset_gate[i]));\n";
719 out << SP << SP << "}\n";
720 } else if (fAttrActivations[direction * 2] == "Affine") {
721 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
722 out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << fAttrActivationAlpha[direction * 2] << " * "
723 << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
724 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << fAttrActivationAlpha[direction * 2] << " * "
725 << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
726 out << SP << SP << "}\n";
727 } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
728 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
729 if (fType == "float") {
730 out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2] << " * " << OpName
731 << "_update_gate[i]);\n";
732 }
733 out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << fAttrActivationAlpha[direction * 2]
734 << " * (1. - z) / (1. + z);\n";
735 if (fType == "float") {
736 out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2] << " * " << OpName
737 << "_reset_gate[i]);\n";
738 }
739 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << fAttrActivationAlpha[direction * 2]
740 << " * (1. - r) / (1. + r);\n";
741 out << SP << SP << "}\n";
742 } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
743 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
744 if (fType == "float") {
745 out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * " << OpName
746 << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
747 out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
748 }
749 out << SP << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
750 if (fType == "float") {
751 out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * " << OpName
752 << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
753 out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
754 }
755 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
756 out << SP << SP << "}\n";
757 } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
758 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
759 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
760 out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << fAttrActivationAlpha[direction * 2] << " * "
761 << OpName << "_update_gate[i];\n";
762 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
763 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << fAttrActivationAlpha[direction * 2] << " * "
764 << OpName << "_reset_gate[i];\n";
765 out << SP << SP << "}\n";
766 } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
767 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
768 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < " << fAttrActivationAlpha[direction * 2]
769 << ")\n";
770 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
771 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < " << fAttrActivationAlpha[direction * 2]
772 << ")\n";
773 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
774 out << SP << SP << "}";
775 } else if (fAttrActivations[direction * 2] == "Elu") {
776 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
777 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
778 out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << fAttrActivationAlpha[direction * 2]
779 << " * exp(" << OpName << "_update_gate[i] - 1.);\n";
780 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
781 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << fAttrActivationAlpha[direction * 2]
782 << " * exp(" << OpName << "_reset_gate[i] - 1.);\n";
783 out << SP << SP << "}\n";
784 } else if (fAttrActivations[direction * 2] == "Softsign") {
785 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
786 out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << OpName << "_update_gate[i] / (1. + abs("
787 << OpName << "_update_gate[i]));\n";
788 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName << "_reset_gate[i] / (1. + abs("
789 << OpName << "_reset_gate[i]));\n";
790 out << SP << SP << "}\n";
791 } else { // fAttrActivations[direction * 2] = Softplus
792 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
793 out << SP << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp(" << OpName << "_update_gate[i]));\n";
794 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp(" << OpName << "_reset_gate[i]));\n";
795 out << SP << SP << "}\n";
796 }
797
798 if (fAttrLinearBeforeReset == 0) {
799 out << SP << SP << "if (seq == 0) {\n";
800 if (!fNInitial_h.empty()) {
801 // feedback = reset_gate o initial_hidden_state
802 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
803 out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName << "_reset_gate[i + offset] * "
804 << OpName << "_initial_hidden_state[i];\n";
805 out << SP << SP << SP << "}\n";
806 }
807 out << SP << SP << "} else {\n";
808 // feedback = reset_gate o previous_hidden_state
809 if (direction == 0) {
810 if (fAttrDirection == "backward") {
811 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
812 << num_directions * batch_size * fAttrHiddenSize << ";\n";
813 } else {
814 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
815 << num_directions * batch_size * fAttrHiddenSize << ";\n";
816 }
817 } else {
818 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
819 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
820 }
821 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
822 out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName << "_reset_gate[i + offset] * " << OpName
823 << "_hidden_state[i + previous_offset];\n";
824 out << SP << SP << SP << "}\n";
825 out << SP << SP << "}\n";
826 // feedback = feedback * R_h^T
827 size_t rh_offset = (direction == 0)
828 ? 2 * fAttrHiddenSize * fAttrHiddenSize
829 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
830 out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
831 << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + " << rh_offset
832 << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName << "_n, &" << OpName << "_beta, "
833 << OpName << "_feedback, &" << OpName << "_n);\n";
834 } else { // fAttrLinearBeforeReset=1
835 // feedback = previous_hidden_state * R_h^T
836 // LM fixes
837 size_t rh_offset = (direction == 0)
838 ? 2 * fAttrHiddenSize * fAttrHiddenSize
839 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
840 out << SP << SP << "if (seq == 0) {\n";
841 if (!fNInitial_h.empty()) {
842 // feedback = initial_hidden_state * R_h^T (the recurrence bias is added after this seq == 0 branch)
843 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
844 << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
845 << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &"
846 << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
847 }
848 out << SP << SP << "} else {\n";
849 // case for seq > 0
850 if (direction == 0) {
851 if (fAttrDirection == "backward") {
852 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
853 << num_directions * batch_size * fAttrHiddenSize << ";\n";
854 } else {
855 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
856 << num_directions * batch_size * fAttrHiddenSize << ";\n";
857 }
858 } else {
859 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
860 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
861 }
862 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName
863 << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
864 << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName
865 << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
866 // endif on seq 0 or not
867 out << SP << SP << "}\n";
868 // Add the bias of the recurrence to feedback
869 if (!fNB.empty()) {
870 size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
871 : 11 * batch_size * seq_length * fAttrHiddenSize;
872 out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, tensor_" << fNB
873 << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName << "_feedback, &" << OpName
874 << "_incy);\n";
875 }
876 // feedback = reset_gate o feedback
877 out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
878 out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
879 out << SP << SP << "}\n";
880 }
881
882 // hidden_gate = hidden_gate + feedback
883 out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, " << OpName
884 << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &" << OpName << "_incy);\n";
885
886 // Clip the elements of the hidden gate into the range [-fClip, fClip]
887 if (fAttrClip > .0) {
888 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
889 if (fType == "float") {
890 out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip << ") ? " << OpName
891 << "_hidden_gate[i] : " << -fAttrClip << ";\n";
892 }
893 out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : " << fAttrClip << ";\n";
894 out << SP << SP << "}\n";
895 }
896
897 // Apply the activation function to the hidden gate
898 if (fAttrActivations[direction * 2 + 1] == "Relu") {
899 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
900 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
901 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
902 out << SP << SP << "}\n";
903 } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
904 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
905 if (fType == "float") {
906 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
907 }
908 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
909 out << SP << SP << "}\n";
910 } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
911 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
912 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
913 << "_hidden_gate[i]));\n";
914 out << SP << SP << "}\n";
915 } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
916 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
917 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << fAttrActivationAlpha[direction * 2 + 1]
918 << " * " << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
919 out << SP << SP << "}\n";
920 } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
921 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
922 if (fType == "float") {
923 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1] << " * " << OpName
924 << "_hidden_gate[i]);\n";
925 }
926 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << fAttrActivationAlpha[direction * 2 + 1]
927 << " * (1. - ex) / (1. + ex);\n";
928 out << SP << SP << "}\n";
929 } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
930 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
931 if (fType == "float") {
932 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName
933 << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
934 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
935 }
936 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
937 out << SP << SP << "}\n";
938 } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
939 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
940 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
941 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << fAttrActivationAlpha[direction * 2 + 1]
942 << " * " << OpName << "_hidden_gate[i];\n";
943 out << SP << SP << "}\n";
944 } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
945 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
946 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < " << fAttrActivationAlpha[direction * 2 + 1]
947 << ")\n";
948 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
949 out << SP << SP << "}";
950 } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
951 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
952 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
953 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << fAttrActivationAlpha[direction * 2 + 1]
954 << " * exp(" << OpName << "_hidden_gate[i] - 1.);\n";
955 out << SP << SP << "}\n";
956 } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
957 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
958 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName << "_hidden_gate[i] / (1. + abs("
959 << OpName << "_hidden_gate[i]));\n";
960 out << SP << SP << "}\n";
961 } else { // fAttrActivations[direction * 2 + 1] = Softplus
962 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
963 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp(" << OpName << "_hidden_gate[i]));\n";
964 out << SP << SP << "}\n";
965 }
966
967 // hidden_state = (1 - update_gate) o hidden_gate
968 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
969 out << SP << SP << SP << OpName << "_hidden_state[i] = ( 1. - " << OpName << "_update_gate[i]) * " << OpName
970 << "_hidden_gate[i];\n";
971 out << SP << SP << "}\n";
972
973 out << SP << SP << "if (seq == 0) {\n";
974 if (!fNInitial_h.empty()) {
975 // hidden_state += update_gate o initial_hidden_state
976 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
977 out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
978 << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
979 out << SP << SP << SP << "}\n";
980 }
981 out << SP << SP << "} else {\n";
982 // hidden_state += update_gate o previous_hidden_state
983 if (direction == 0) {
984 if (fAttrDirection == "backward") {
985 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
986 << num_directions * batch_size * fAttrHiddenSize << ";\n";
987 } else {
988 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
989 << num_directions * batch_size * fAttrHiddenSize << ";\n";
990 }
991 } else {
992 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
993 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
994 }
995 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
996 out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
997 << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
998 out << SP << SP << SP << "}\n";
999 out << SP << SP << "}\n";
1000
1001 out << SP << "}\n";
1002 }
1003
1004 // Padding the hidden state for GRU with different sequence lengths
1005 if (!fNSequence_lens.empty()) {
1006 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
1007 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1008 out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
1009 for (size_t direction = 0; direction < num_directions; direction++) {
1010 out << SP << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
1011 out << SP << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
1012 << num_directions * batch_size * fAttrHiddenSize + direction * batch_size * fAttrHiddenSize
1013 << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
1014 out << SP << SP << SP << SP << SP << "}\n";
1015 }
1016 out << SP << SP << SP << "}\n";
1017 out << SP << SP << "}\n";
1018 out << SP << "}\n";
1019 }
1020
1021 // Copy the hidden state into y and y_h
1022 if (fAttrLayout == 0) {
1023 if (!fNY_h.empty()) {
1024 // Copy hidden_state into Y_h
1025 if (fNSequence_lens.empty()) {
1026 size_t yh_size = batch_size * fAttrHiddenSize;
1027 if (fAttrDirection == "backward") {
1028 out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + " << yh_size
1029 << ", tensor_" << fNY_h << ");\n";
1030 } else {
1031 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
1032 out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
1033 << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
1034 }
1035 if (num_directions == 2) {
1036 out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
1037 << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
1038 }
1039 } else { // GRU with different sequence lengths
1040 if (fAttrDirection == "backward") {
1041 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1042 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1043 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1044 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
1045 out << SP << "}\n";
1046 } else {
1047 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1048 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1049 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1050 << " + batch * " << fAttrHiddenSize << ";\n";
1051 out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
1052 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1053 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
1054 out << SP << "}\n";
1055 }
1056 if (num_directions == 2) {
1057 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1058 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1059 << ";\n";
1060 out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize << " + batch * "
1061 << fAttrHiddenSize << ";\n";
1062 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1063 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
1064 out << SP << "}\n";
1065 }
1066 }
1067 }
1068 } else { // fAttrLayout=1
1069 if (!fNY.empty()) {
1070 // Copy hidden_state into Y
1071 for (size_t direction = 0; direction < num_directions; direction++) {
1072 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
1073 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1074 out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize << " + "
1075 << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
1076 out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
1077 << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
1078 out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1079 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
1080 out << SP << SP << "}\n";
1081 out << SP << "}\n";
1082 }
1083 }
1084 if (!fNY_h.empty()) {
1085 // Copy the hidden_state into Y_h
1086 if (fAttrDirection == "backward") {
1087 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1088 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1089 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1090 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1091 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
1092 out << SP << "}\n";
1093 } else {
1094 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1095 if (fNSequence_lens.empty()) {
1096 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
1097 } else {
1098 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1099 }
1100 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1101 << " + batch * " << fAttrHiddenSize << ";\n";
1102 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1103 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1104 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
1105 out << SP << "}\n";
1106 }
1107 if (num_directions == 2) {
1108 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1109 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
1110 << ";\n";
1111 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1112 << fAttrHiddenSize << ";\n";
1113 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1114 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
1115 out << SP << "}\n";
1116 }
1117 }
1118 }
1119
1120 return out.str();
1121}
1122
1123} // namespace TMVA::Experimental::SOFIE
1124
1125#endif
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
char name[80]
Definition TGX11.cxx:110
std::vector< size_t > GetTensorShape(const std::string &name) const
Definition RModel.cxx:29
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:262
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:327
ETensorType GetTensorType(std::string name) const
Definition RModel.cxx:90
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:318
Gated Recurrent Unit operator.
std::vector< size_t > fShapeY
Shape of the output.
std::string fNX
Name of the input.
std::string fType
Type of the tensors.
std::string fAttrDirection
Direction of processing.
std::string fNR
Name of the recurrence.
std::vector< float > fAttrActivationBeta
Scaling values used by some activation functions.
std::string fNY
Name of the output.
std::string fNY_h
Name of the last sequence of the output.
std::string fNSequence_lens
Name of the length of the sequences.
std::vector< std::string > fAttrActivations
Activation functions.
void Initialize(RModel &) override
Initialize the model.
ROperator_GRU(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t layout, size_t linear_before_reset, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameY, std::string nameY_h)
Constructor of ROperator_GRU from the attributes.
size_t fAttrHiddenSize
Number of neurons in the hidden layer (size of the hidden state).
std::string Generate(std::string) override
Generate the inference code.
std::vector< float > fAttrActivationAlpha
Scaling values used by some activation functions.
std::vector< size_t > fShapeR
Shape of the recurrence.
std::string fNW
Name of the weights.
std::vector< size_t > fShapeX
Shape of the input.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >) override
Infers the shape of the output tensors.
std::vector< size_t > fShapeInitial_h
Shape of the initial value of hidden states.
std::vector< size_t > fShapeSequence_lens
Shape of the length of the sequences.
std::vector< ETensorType > TypeInference(std::vector< ETensorType >) override
Infers the type of the output tensors.
std::vector< size_t > fShapeY_h
Shape of the last sequence of the output.
size_t fAttrLinearBeforeReset
Linear layer before the reset gate.
std::vector< size_t > fShapeB
Shape of the bias.
std::string fNInitial_h
Name of the initial value of the hidden states.
std::vector< size_t > fShapeW
Shape of the weights.
ROperator_GRU()
Default constructor of ROperator_GRU.
std::vector< std::string > GetBlasRoutines() override
Returns the blas routines needed to compile the generated code.
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:49
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:50
const Int_t n
Definition legend1.C:16
ETensorType ConvertStringToType(std::string type)