ROOT Reference Guide
ROperator_GRU.icc
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_GRU_I
2#define TMVA_SOFIE_ROPERATOR_GRU_I
3
4namespace TMVA {
5namespace Experimental {
6namespace SOFIE {
7
8template <typename T>
9auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input)
10-> std::vector<ETensorType> {
11 ETensorType out = input[0];
12 return {out, out};
13}
14
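// Shape conventions (ONNX GRU): the weight tensor W has shape {num_directions, 3*hidden_size,
// input_size} and the recurrence weight tensor R has shape {num_directions, 3*hidden_size,
// hidden_size}, so hidden_size is recovered below as input[1][1] / 3. The outputs are Y
// (all hidden states) and Y_h (the last hidden state); their shapes depend on fAttrLayout.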
15template<typename T>
16auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input)
17-> std::vector<std::vector<size_t>> {
18 size_t num_directions = input[1][0];
19 size_t hidden_size = input[1][1] / 3;
20 if (fAttrLayout == 0) {
21 size_t seq_length = input[0][0];
22 size_t batch_size = input[0][1];
23 std::vector<std::vector<size_t>> ret(
24 {{seq_length, num_directions, batch_size, hidden_size},
25 {num_directions, batch_size, hidden_size}});
26 return ret;
27 } else {
28 size_t batch_size = input[0][0];
29 size_t seq_length = input[0][1];
30 std::vector<std::vector<size_t>> ret(
31 {{batch_size, seq_length, num_directions, hidden_size},
32 {batch_size, num_directions, hidden_size}});
33 return ret;
34 }
35}
36
37template<typename T>
38auto ROperator_GRU<T>::Initialize(RModel& model)
39-> void {
40 fUseSession = model.UseSession();
41 // Check the input and output tensors
42 if (!model.CheckIfTensorAlreadyExist(fNX)) {
43 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
44 }
45 fShapeX = model.GetTensorShape(fNX);
46 if (fShapeX.size() != 3) {
47 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
48 }
49 if (!model.CheckIfTensorAlreadyExist(fNW)) {
50 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
51 }
52 fShapeW = model.GetTensorShape(fNW);
53 if (fShapeW.size() != 3) {
54 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
55 }
56 if (!model.CheckIfTensorAlreadyExist(fNR)) {
57 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
58 }
59 fShapeR = model.GetTensorShape(fNR);
60 if (fShapeR.size() != 3) {
61 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
62 }
63 if (!fNB.empty()) {
64 if (!model.CheckIfTensorAlreadyExist(fNB)) {
65 throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not found in model.");
66 }
67 fShapeB = model.GetTensorShape(fNB);
68 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
69 throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not of 2 or 4 dimensions.");
70 }
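// The ONNX GRU bias has shape {num_directions, 6*hidden_size}, concatenating
// Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h. It is broadcast here to
// {num_directions, 6, seq_length, batch_size, hidden_size} so that in Generate each bias
// slice can be added to the matching gate buffer with a single BLAS saxpy call.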
71 if (fShapeB.size() == 2) {
72 // Broadcasting the bias
73 auto original_data = model.GetInitializedTensorData(fNB);
74 size_t num_directions = fShapeW[0];
75 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
76 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
77 if (fType == "float") {
78 float *original_bias = static_cast<float*>(original_data.get());
79 float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
80 for (size_t direction = 0; direction < num_directions; direction++) {
81 for (size_t i = 0; i < 6; i++) {
82 for (size_t seq = 0; seq < seq_length; seq++) {
83 for (size_t batch = 0; batch < batch_size; batch++) {
84 size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
85 size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
86 i * batch_size * seq_length * fAttrHiddenSize +
87 seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
88 std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
89 new_bias + offset);
90 }
91 }
92 }
93 }
94
95 std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
96 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
97 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
98 fShapeB = model.GetTensorShape(fNB);
99 }
100 }
101 }
102 if (!fNSequence_lens.empty()) {
103 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
104 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
105 fNSequence_lens +
106 "is not found in model.");
107 }
108 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
109 if (fShapeSequence_lens.size() != 1) {
110 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
111 fNSequence_lens +
112 " is not of 1 dimension.");
113 }
114 }
115 if (!fNInitial_h.empty()) {
116 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
117 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
118 fNInitial_h + " is not found in model.");
119 }
120 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
121 if (fShapeInitial_h.size() != 3) {
122 throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
123 fNInitial_h + " is not of 3 dimensions.");
124 }
125 }
126 if (!fNY.empty()) {
127 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
128 if (!model.CheckIfTensorAlreadyExist(fNY)) {
129 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
130 }
131 }
132 if (!fNY_h.empty()) {
133 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
134 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
135 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
136 }
137 }
138 // Check the attributes
139 for (auto &activation : fAttrActivations) {
140 if (activation != "Relu" && activation != "Tanh" &&
141 activation != "Sigmoid" && activation != "Affine" &&
142 activation != "LeakyRelu" && activation != "ThresholdRelu" &&
143 activation != "ScaledTanh" && activation != "HardSigmoid" &&
144 activation != "Elu" && activation != "Softsign" &&
145 activation != "Softplus") {
146 throw std::runtime_error("TMVA SOFIE - Activation function " +
147 activation + " not implemented");
148 }
149 }
150 if (fAttrDirection == "reverse") fAttrDirection = "backward";
151 if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
152 fAttrDirection != "reverse" &&
153 fAttrDirection != "bidirectional") {
154 throw std::runtime_error(
155 "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
156 fAttrDirection);
157 }
158 if (3 * fAttrHiddenSize != fShapeW[1]) {
159 throw std::runtime_error(
160 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
161 std::to_string(fShapeW[1] / 3));
162 }
163 if (fAttrLayout > 1) {
164 throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
165 std::to_string(fAttrLayout) +
166 " must be 0 (timewise) or 1 (batchwise)");
167 }
168 if (fAttrLinearBeforeReset > 1) {
169 throw std::runtime_error(
170 "TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrLinearBeforeReset)
171 + " must be 0 or 1.");
172 }
173 if (fAttrActivations.empty()) {
174 if (fAttrDirection == "bidirectional") {
175 fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
176 } else {
177 fAttrActivations = {"Sigmoid", "Tanh"};
178 }
179 }
180
181 // To get unique intermediate tensor names, we add the name of the input
182 // tensor. One might also consider using the index of the operator in the
183 // RModel, but this information is not available in the current scope.
184 std::string opName = "op_gru_" + fNX;
185
186 size_t num_directions = fShapeW[0];
187 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
188 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
189 size_t input_size = fShapeX[2];
190
191 auto declareVector = [&](std::string const &name, std::size_t n){
192 std::string fullName = opName + "_" + name;
193 model.AddIntermediateTensor(fullName, ConvertStringToType(fType), std::vector<std::size_t>{n});
194 };
195
196 if (fAttrLayout != 0) {
197 declareVector("input", seq_length * batch_size * input_size);
198 declareVector("initial_hidden_state", num_directions * batch_size * fAttrHiddenSize);
199 declareVector("initial_cell_state", num_directions * batch_size * fAttrHiddenSize);
200 }
201 // Set the feedforward
202 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
203 declareVector("f_update_gate", ff_size);
204 declareVector("f_reset_gate", ff_size);
205 declareVector("f_hidden_gate", ff_size);
206 // gate results
207 size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
208 declareVector("update_gate", hs_size);
209 declareVector("reset_gate", hs_size);
210 declareVector("hidden_gate", hs_size);
211
212 // feedback
213 declareVector("feedback", batch_size * fAttrHiddenSize);
214
215 // hidden state
216 if (fAttrLayout != 0 || fNY.empty()) {
217 declareVector("hidden_state", hs_size);
218 }
219}
220
221
222template<typename T>
223auto ROperator_GRU<T>::Generate(std::string OpName)
224-> std::string {
225 OpName = "op_" + OpName;
226 std::stringstream out;
227
228 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
229 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
230 size_t input_size = fShapeX[2];
231 size_t num_directions = fShapeW[0];
232
233 auto getVec = [&](std::string const &name) { return "tensor_op_gru_" + fNX + "_" + name; };
234
235 // set the input
236 if (fAttrLayout == 0) {
237 out << SP << fType << " const* " << OpName << "_input = tensor_" << fNX << ";\n";
238 } else {
239 if (fUseSession) {
240 out << SP << fType << " * " << OpName << "_input = " << getVec("input") << ";\n";
241 } else {
242 out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
243 }
244 out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
245 out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
246 out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
247 out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
248 << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
249 << seq_length * input_size << " + seq * " << input_size << " + i];\n";
250 out << SP << SP << SP << "}\n";
251 out << SP << SP << "}\n";
252 out << SP << "}\n";
253 }
254
255 // Set the initial hidden state
256 if (!fNInitial_h.empty()) {
257 if (fAttrLayout == 0) {
258 out << SP << fType << " *" << OpName << "_initial_hidden_state = " << " tensor_"
259 << fNInitial_h << ";\n";
260 } else {
261 if (fUseSession) {
262 out << SP << fType << " * " << OpName << "_initial_hidden_state = " << getVec("initial_hidden_state") << ";\n";
263 } else {
264 out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
265 fAttrHiddenSize << "];\n";
266 }
267 for (size_t direction = 0; direction < num_directions; direction++) {
268 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
269 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
270 out << SP << SP << SP << OpName << "_initial_hidden_state["
271 << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
272 << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
273 << " + " << direction * fAttrHiddenSize << " + h];\n";
274 out << SP << SP << "}\n";
275 out << SP << "}\n";
276 }
277 }
278 }
279
280 // Set the feedforward
281 size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
282 if (fUseSession) {
283 out << SP << fType << " * " << OpName << "_f_update_gate = " << getVec("f_update_gate") << ";\n";
284 out << SP << fType << " * " << OpName << "_f_reset_gate = " << getVec("f_reset_gate") << ";\n";
285 out << SP << fType << " * " << OpName << "_f_hidden_gate = " << getVec("f_hidden_gate") << ";\n";
286 } else {
287 out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
288 out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
289 out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
290 }
291 // Set the gates
292 size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
293 if (fUseSession) {
294 out << SP << fType << " * " << OpName << "_update_gate = " << getVec("update_gate") << ";\n";
295 out << SP << fType << " * " << OpName << "_reset_gate = " << getVec("reset_gate") << ";\n";
296 out << SP << fType << " * " << OpName << "_hidden_gate = " << getVec("hidden_gate") << ";\n";
297 } else {
298 out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
299 out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
300 out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
301 }
302 // Set the hidden state
303 if (fAttrLayout == 0 && !fNY.empty()) {
304 out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
305 } else {
306 if (fUseSession) {
307 out << SP << fType << " * " << OpName << "_hidden_state = " << getVec("hidden_state") << ";\n";
308 } else {
309 out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
310 }
311 }
312
313 if (fUseSession) {
314 out << SP << fType << " * " << OpName << "_feedback = " << getVec("feedback") << ";\n";
315 } else {
316 out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
317 }
318
319 out << SP << "char " << OpName << "_transA = 'N';\n";
320 out << SP << "char " << OpName << "_transB = 'T';\n";
321 out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
322 out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
323 out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
324 out << SP << "int " << OpName << "_k = " << input_size << ";\n";
325 if (fType == "float") {
326 out << SP << "float " << OpName << "_alpha = 1.;\n";
327 out << SP << "float " << OpName << "_beta = 0.;\n";
328 }
329 if (!fNB.empty()) {
330 out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
331 }
332 out << SP << "int " << OpName << "_incx = 1;\n";
333 out << SP << "int " << OpName << "_incy = 1;\n";
334 out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";
335
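// The loop below emits, per direction, the ONNX GRU recurrence (f, g are the two
// activation functions, o is the element-wise product, as in the comments below):
//   z_t = f(X_t W_z^T + H_{t-1} R_z^T + Wb_z + Rb_z)              (update gate)
//   r_t = f(X_t W_r^T + H_{t-1} R_r^T + Wb_r + Rb_r)              (reset gate)
//   h_t = g(X_t W_h^T + (r_t o H_{t-1}) R_h^T + Wb_h + Rb_h)      (linear_before_reset = 0)
//   h_t = g(X_t W_h^T + r_t o (H_{t-1} R_h^T + Rb_h) + Wb_h)      (linear_before_reset = 1)
//   H_t = (1 - z_t) o h_t + z_t o H_{t-1}
// The X_t W^T products are computed for all time steps at once (the f_* "feedforward"
// buffers); the H_{t-1} R^T terms are added step by step inside the sequence loop.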
336 for (size_t direction = 0; direction < num_directions; direction++) {
337 if (direction == 0) {
338 if (fType == "float") {
339 // f_update_gate = input * weight_z^T
340 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
341 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
342 << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
343 << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
344 // f_reset_gate = input * weight_r^T
345 size_t wr_offset = fAttrHiddenSize * input_size;
346 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
347 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
348 << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
349 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
350 // f_hidden_gate = input * weight_h^T
351 size_t wh_offset = 2 * fAttrHiddenSize * input_size;
352 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
353 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
354 << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
355 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
356 }
357 } else {
358 if (fType == "float") {
359 // f_update_gate = input * weight_z^T
360 size_t wz_offset = 3 * fAttrHiddenSize * input_size;
361 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
362 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
363 << fNW << " + " << wz_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
364 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
365 // f_reset_gate = input * weight_r^T
366 size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
367 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
368 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
369 << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
370 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
371 // f_hidden_gate = input * weight_h^T
372 size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
373 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
374 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
375 << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
376 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
377 }
378 }
379
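// The bias tensor was broadcast in Initialize to {num_directions, 6, seq_length, batch_size,
// hidden_size}, with the six slices ordered as Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h. The slice
// for component i of direction d therefore starts at (d * 6 + i) * seq_length * batch_size *
// hidden_size, which is where the constant offsets below (0..5 for direction 0, 6..11 for
// direction 1) come from.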
380 if (!fNB.empty()) {
381 if (direction == 0) {
382 if (fType == "float") {
383 // Add the bias of the weight to f_update_gate
384 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
385 << fNB << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";
386 // Add the bias of the recurrence to f_update_gate
387 size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
388 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
389 << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
390 << OpName << "_incy);\n";
391 // Add the bias of the weight to f_reset_gate
392 size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
393 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
394 << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
395 << OpName << "_incy);\n";
396 // Add the bias of the recurrence to f_reset_gate
397 //size_t rbr_offset = fAttrHiddenSize * fAttrHiddenSize + 3 * batch_size * fAttrHiddenSize;
398 size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
399 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
400 << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
401 << OpName << "_incy);\n";
402 // Add the bias of the weight to f_hidden_gate
403 size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
404 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
405 << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
406 << OpName << "_incy);\n";
407 if (fAttrLinearBeforeReset == 0) {
408 // Add the bias of the recurrence to f_hidden_gate
409 size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
410 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
411 << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
412 << "_f_hidden_gate, &" << OpName << "_incy);\n";
413 }
414 }
415 } else {
416 if (fType == "float") {
417 // Add the bias of the weight to f_update_gate
418 size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
419 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
420 << fNB << " + " << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
421 << OpName << "_incy);\n";
422 // Add the bias of the recurrence to f_update_gate
423 // size_t rbz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize + 3 * batch_size * fAttrHiddenSize;
424 size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
425 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
426 << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
427 << OpName << "_incy);\n";
428 // Add the bias of the weight to f_reset_gate
429 size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
430 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
431 << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
432 << OpName << "_incy);\n";
433 // Add the bias of the recurrence to f_reset_gate
434 size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
435 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
436 << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
437 << OpName << "_incy);\n";
438 // Add the bias of the weight to f_hidden_gate
439 size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
440 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
441 << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
442 << OpName << "_incy);\n";
443 if (fAttrLinearBeforeReset == 0) {
444 // Add the bias of the recurrence to f_hidden_gate
445 size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
446 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
447 << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
448 << "_f_hidden_gate, &" << OpName << "_incy);\n";
449 }
450 }
451 }
452 }
453
454 // Copy the feedforward into the gates
455 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
456 out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
457 if (direction == 0) {
458 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
459 << ";\n";
460 } else {
461 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
462 << " + " << batch_size * fAttrHiddenSize << ";\n";
463 }
464 size_t f_seq_size = batch_size * fAttrHiddenSize;
465 out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName
466 << "_f_update_gate + offset + " << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
467 out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName
468 << "_f_reset_gate + offset + " << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
469 out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName
470 << "_f_hidden_gate + offset + " << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
471 out << SP << "}\n";
472
473 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
474 if (fAttrDirection == "backward" || direction == 1) {
475 out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
476 } else {
477 out << SP << SP << "size_t index = seq;\n";
478 }
479 out << SP << SP << "int m2 = " << batch_size << ";\n";
480 if (direction == 0) {
481 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
482 << ";\n";
483 } else {
484 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
485 << " + " << batch_size * fAttrHiddenSize << ";\n";
486 }
487 size_t size = batch_size * fAttrHiddenSize;
488 // gate = gate + initial_hidden_state * Recurrence^T
489 out << SP << SP << "if (seq == 0) {\n";
490 if (!fNInitial_h.empty()) {
491 if (direction == 0) {
492 if (fType == "float") {
493 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
494 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
495 << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
496 << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
497 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
498 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
499 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
500 << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
501 << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
502 }
503 } else { // direction=1
504 if (fType == "float") {
505 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
506 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
507 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
508 << rz_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
509 << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
510 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
511 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
512 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
513 << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
514 << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
515 }
516 }
517 }
518 out << SP << SP << "} else {\n";
519 // gate = gate + previous_hidden_state * Recurrence^T
520 if (direction == 0) {
521 if (fAttrDirection == "backward") {
522 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
523 << num_directions * batch_size * fAttrHiddenSize << ";\n";
524 } else {
525 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
526 << num_directions * batch_size * fAttrHiddenSize << ";\n";
527 }
528 if (fType == "float") {
529 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
530 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
531 << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
532 << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
533 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
534 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
535 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
536 << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
537 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
538 << OpName << "_n);\n";
539 }
540 } else {
541 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
542 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
543 if (fType == "float") {
544 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
545 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
546 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
547 << rz_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
548 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &"
549 << OpName << "_n);\n";
550 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
551 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
552 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
553 << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
554 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
555 << OpName << "_n);\n";
556 }
557 }
558 out << SP << SP << "}\n";
559
560 // Clip the elements of the update gate and the reset gate into the range [-fClip, fClip]
561 if (fAttrClip > .0) {
562 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
563 if (fType == "float") {
564 out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip
565 << ") ? " << OpName << "_update_gate[i] : " << -fAttrClip << ";\n";
566 }
567 out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip
568 << ") ? z : " << fAttrClip << ";\n";
569 if (fType == "float") {
570 out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip
571 << ") ? " << OpName << "_reset_gate[i] : " << -fAttrClip << ";\n";
572 }
573 out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip
574 << ") ? r : " << fAttrClip << ";\n";
575 out << SP << SP << "}\n";
576 }
577
578 // Apply the activation function to the update gate and the reset gate
579 if (fAttrActivations[direction * 2] == "Relu") {
580 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
581 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
582 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
583 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
584 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
585 out << SP << SP << "}\n";
586 } else if (fAttrActivations[direction * 2] == "Tanh") {
587 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
588 if (fType == "float") {
589 out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
590 }
591 out << SP << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
592 if (fType == "float") {
593 out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
594 }
595 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
596 out << SP << SP << "}\n";
597 } else if (fAttrActivations[direction * 2] == "Sigmoid") {
598 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
599 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-"
600 << OpName << "_update_gate[i]));\n";
601 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-"
602 << OpName << "_reset_gate[i]));\n";
603 out << SP << SP << "}\n";
604 } else if (fAttrActivations[direction * 2] == "Affine") {
605 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
606 out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
607 << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i] + "
608 << fAttrActivationBeta[direction * 2] << ";\n";
609 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
610 << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i] + "
611 << fAttrActivationBeta[direction * 2] << ";\n";
612 out << SP << SP << "}\n";
613 } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
614 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
615 if (fType == "float") {
616 out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2]
617 << " * "<< OpName << "_update_gate[i]);\n";
618 }
619 out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
620 << fAttrActivationAlpha[direction * 2] << " * (1. - z) / (1. + z);\n";
621 if (fType == "float") {
622 out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2]
623 << " * "<< OpName << "_reset_gate[i]);\n";
624 }
625 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
626 << fAttrActivationAlpha[direction * 2] << " * (1. - r) / (1. + r);\n";
627 out << SP << SP << "}\n";
628 } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
629 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
630 if (fType == "float") {
631 out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * "
632 << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
633 out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
634 }
635 out << SP << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
636 if (fType == "float") {
637 out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * "
638 << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
639 out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
640 }
641 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
642 out << SP << SP << "}\n";
643 } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
644 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
645 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
646 out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
647 << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i];\n";
648 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
649 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
650 << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i];\n";
651 out << SP << SP << "}\n";
652 } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
653 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
654 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < "
655 << fAttrActivationAlpha[direction * 2] << ")\n";
656 out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
657 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < "
658 << fAttrActivationAlpha[direction * 2] << ")\n";
659 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
660 out << SP << SP << "}\n";
661 } else if (fAttrActivations[direction * 2] == "Elu") {
662 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
663 out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
664 out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
665 << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_update_gate[i]) - 1.);\n";
666 out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
667 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
668 << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_reset_gate[i]) - 1.);\n";
669 out << SP << SP << "}\n";
670 } else if (fAttrActivations[direction * 2] == "Softsign") {
671 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
672 out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << OpName
673 << "_update_gate[i] / (1. + abs(" << OpName << "_update_gate[i]));\n";
674 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName
675 << "_reset_gate[i] / (1. + abs(" << OpName << "_reset_gate[i]));\n";
676 out << SP << SP << "}\n";
677 } else { // fAttrActivations[direction * 2] = Softplus
678 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
679 out << SP << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp("
680 << OpName << "_update_gate[i]));\n";
681 out << SP << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp("
682 << OpName << "_reset_gate[i]));\n";
683 out << SP << SP << "}\n";
684 }
685
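// Recurrent part of the hidden (candidate) gate: with linear_before_reset = 0 the reset gate
// is applied to the previous hidden state before multiplying by R_h^T, i.e.
// (r_t o H_{t-1}) R_h^T; with linear_before_reset = 1 the product H_{t-1} R_h^T (+ Rb_h) is
// computed first and the reset gate is applied to the result.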
686 if (fAttrLinearBeforeReset == 0) {
687 out << SP << SP << "if (seq == 0) {\n";
688 if (!fNInitial_h.empty()) {
689 // feedback = reset_gate o initial_hidden_state
690 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
691 out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
692 << "_reset_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
693 out << SP << SP << SP << "}\n";
694 }
695 out << SP << SP << "} else {\n";
696 // feedback = reset_gate o previous_hidden_state
697 if (direction == 0) {
698 if (fAttrDirection == "backward") {
699 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
700 << num_directions * batch_size * fAttrHiddenSize << ";\n";
701 } else {
702 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
703 << num_directions * batch_size * fAttrHiddenSize << ";\n";
704 }
705 } else {
706 out << SP << SP << SP << "size_t previous_offset = (index + 1) * " << num_directions
707 * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
708 }
709 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
710 out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
711 << "_reset_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
712 out << SP << SP << SP << "}\n";
713 out << SP << SP << "}\n";
714 // feedback = feedback * R_h^T
715 size_t rh_offset = (direction == 0) ?
716 2 * fAttrHiddenSize * fAttrHiddenSize : 3 * fAttrHiddenSize * fAttrHiddenSize
717 + 2 * fAttrHiddenSize * fAttrHiddenSize;
718 out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
719 << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_"
720 << fNR << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName
721 << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
722 } else { // fAttrLinearBeforeReset=1
723 // feedback = previous_hidden_state * R_h^T
724 //LM fixes
725 size_t rh_offset = (direction == 0)
726 ? 2 * fAttrHiddenSize * fAttrHiddenSize
727 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
728 out << SP << SP << "if (seq == 0) {\n";
729 if (!fNInitial_h.empty()) {
730 // feedback = W * initial_hidden_state + bias
731 out << SP << SP << SP
732 << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
733 << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
734 << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &"
735 << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
736 }
737 out << SP << SP << "} else {\n";
738 // case for seq > 0
739 if (direction == 0) {
740 if (fAttrDirection == "backward") {
741 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
742 << num_directions * batch_size * fAttrHiddenSize << ";\n";
743 } else {
744 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
745 << num_directions * batch_size * fAttrHiddenSize << ";\n";
746 }
747 } else {
748 out << SP << SP << SP << "size_t previous_offset = (index + 1) * " << num_directions
749 * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
750 }
751 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
752 << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
753 << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
754 << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
755 // endif on seq 0 or not
756 out << SP << SP << "}\n";
757 // Add the bias of the recurrence to feedback
758 if (!fNB.empty()) {
759 size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
760 : 11 * batch_size * seq_length * fAttrHiddenSize;
761 out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName
762 << "_alpha, tensor_" << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, "
763 << OpName << "_feedback, &" << OpName << "_incy);\n";
764 }
765 // feedback = reset_gate o feedback
766 out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
767 out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
768 out << SP << SP << "}\n";
769 }
770
771 // hidden_gate = hidden_gate + feedback
772 out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, "
773 << OpName << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &"
774 << OpName << "_incy);\n";
775
776 // Clip the elements of the hidden gate into the range [-fClip, fClip]
777 if (fAttrClip > .0) {
778 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
779 if (fType == "float") {
780 out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip
781 << ") ? " << OpName << "_hidden_gate[i] : " << -fAttrClip << ";\n";
782 }
783 out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : "
784 << fAttrClip << ";\n";
785 out << SP << SP << "}\n";
786 }
787
788 // Apply the activation function to the hidden gate
789 if (fAttrActivations[direction * 2 + 1] == "Relu") {
790 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
791 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
792 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
793 out << SP << SP << "}\n";
794 } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
795 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
796 if (fType == "float") {
797 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
798 }
799 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
800 out << SP << SP << "}\n";
801 } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
802 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
803 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
804 << "_hidden_gate[i]));\n";
805 out << SP << SP << "}\n";
806 } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
807 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
808 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
809 << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i] + "
810 << fAttrActivationBeta[direction * 2 + 1] << ";\n";
811 out << SP << SP << "}\n";
812 } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
813 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
814 if (fType == "float") {
815 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1]
816 << " * "<< OpName << "_hidden_gate[i]);\n";
817 }
818 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
819 << fAttrActivationAlpha[direction * 2 + 1] << " * (1. - ex) / (1. + ex);\n";
820 out << SP << SP << "}\n";
821 } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
822 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
823 if (fType == "float") {
824 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * "
825 << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
826 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
827 }
828 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
829 out << SP << SP << "}\n";
830 } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
831 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
832 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
833 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
834 << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i];\n";
835 out << SP << SP << "}\n";
836 } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
837 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
838 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < "
839 << fAttrActivationAlpha[direction * 2 + 1] << ")\n";
840 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
841 out << SP << SP << "}\n";
842 } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
843 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
844 out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
845 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
846 << fAttrActivationAlpha[direction * 2 + 1] << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
847 out << SP << SP << "}\n";
848 } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
849 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
850 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName
851 << "_hidden_gate[i] / (1. + abs(" << OpName << "_hidden_gate[i]));\n";
852 out << SP << SP << "}\n";
853 } else { // fAttrActivations[direction * 2 + 1] = Softplus
854 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
855 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp("
856 << OpName << "_hidden_gate[i]));\n";
857 out << SP << SP << "}\n";
858 }
859
860 // hidden_state = (1 - update_gate) o hidden_gate
861 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
862 out << SP << SP << SP << OpName << "_hidden_state[i] = ( 1. - " << OpName
863 << "_update_gate[i]) * " << OpName << "_hidden_gate[i];\n";
864 out << SP << SP << "}\n";
865
866 out << SP << SP << "if (seq == 0) {\n";
867 if (!fNInitial_h.empty()) {
868 // hidden_state += update_gate o initial_hidden_state
869 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
870 out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
871 << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
872 out << SP << SP << SP << "}\n";
873 }
874 out << SP << SP << "} else {\n";
875 // hidden_state += update_gate o previous_hidden_state
876 if (direction == 0) {
877 if (fAttrDirection == "backward") {
878 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
879 << num_directions * batch_size * fAttrHiddenSize << ";\n";
880 } else {
881 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
882 << num_directions * batch_size * fAttrHiddenSize << ";\n";
883 }
884 } else {
885 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
886 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
887 }
888 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
889 out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
890 << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
891 out << SP << SP << SP << "}\n";
892 out << SP << SP << "}\n";
893
894 out << SP << "}\n";
895 }
896
897 // Padding the hidden state for GRU with different sequence lengths
898 if (!fNSequence_lens.empty()) {
899 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
900 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
901 out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
902 for (size_t direction = 0; direction < num_directions; direction++) {
903 out << SP << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
904 out << SP << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
905 << num_directions * batch_size * fAttrHiddenSize + direction * batch_size * fAttrHiddenSize
906 << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
907 out << SP << SP << SP << SP << SP << "}\n";
908 }
909 out << SP << SP << SP << "}\n";
910 out << SP << SP << "}\n";
911 out << SP << "}\n";
912 }
913
914 // Copy the hidden state into y and y_h
915 if (fAttrLayout == 0) {
916 if (!fNY_h.empty()) {
917 // Copy hidden_state into Y_h
918 if (fNSequence_lens.empty()) {
919 size_t yh_size = batch_size * fAttrHiddenSize;
920 if (fAttrDirection == "backward") {
921 out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
922 << yh_size << ", tensor_" << fNY_h << ");\n";
923 } else {
924 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
925 out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
926 << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
927 }
928 if (num_directions == 2) {
929 out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
930 << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
931 }
932 } else { // GRU with different sequence lengths
933 if (fAttrDirection == "backward") {
934 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
935 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
936 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
937 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
938 out << SP << "}\n";
939 } else {
940 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
941 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
942 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
943 << " + batch * " << fAttrHiddenSize << ";\n";
944 out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
945 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
946 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
947 out << SP << "}\n";
948 }
949 if (num_directions == 2) {
950 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
951 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
952 << " + batch * " << fAttrHiddenSize << ";\n";
953 out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
954 << " + batch * " << fAttrHiddenSize << ";\n";
955 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
956 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
957 out << SP << "}\n";
958 }
959 }
960 }
961 } else { // fAttrLayout=1
962 if (!fNY.empty()) {
963 // Copy hidden_state into Y
964 for (size_t direction = 0; direction < num_directions; direction++) {
965 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
966 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
967 out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
968 << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
969 out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
970 << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
971 out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
972 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
973 out << SP << SP << "}\n";
974 out << SP << "}\n";
975 }
976 }
977 if (!fNY_h.empty()) {
978 // Copy the hidden_state into Y_h
979 if (fAttrDirection == "backward") {
980 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
981 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
982 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
983 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
984 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
985 out << SP << "}\n";
986 } else {
987 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
988 if (fNSequence_lens.empty()) {
989 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
990 } else {
991 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
992 }
993 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
994 << " + batch * " << fAttrHiddenSize << ";\n";
995 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
996 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
997 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
998 out << SP << "}\n";
999 }
1000 if (num_directions == 2) {
1001 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1002 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
1003 << fAttrHiddenSize << ";\n";
1004 out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1005 << fAttrHiddenSize << ";\n";
1006 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1007 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
1008 out << SP << "}\n";
1009 }
1010 }
1011 }
1012
1013 return out.str();
1014}
1015
1016} // namespace SOFIE
1017} // namespace Experimental
1018} // namespace TMVA
1019
1020#endif
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:200
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:95
const ETensorType & GetTensorType(std::string name) const
Definition RModel.cxx:67
const std::vector< size_t > & GetTensorShape(std::string name) const
Definition RModel.cxx:29
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:261
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:252
void Initialize(RModel &) override
Initialize the model.
std::string Generate(std::string) override
Generate the inference code.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >) override
Infers the shape of the output tensors.
std::vector< ETensorType > TypeInference(std::vector< ETensorType >) override
Infers the type of the output tensors.
ETensorType ConvertStringToType(std::string type)
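For context, this operator is normally exercised through the SOFIE code-generation workflow rather than called directly: an ONNX model containing a GRU node is parsed into an RModel, and RModel::Generate() invokes ROperator_GRU<T>::Initialize() and Generate() to emit the inference code. A minimal sketch of that workflow is shown below; the file names are placeholders and the snippet assumes the standard SOFIE ONNX parser API (RModelParser_ONNX::Parse, RModel::Generate, RModel::OutputGenerated).

    #include "TMVA/RModelParser_ONNX.hxx"

    using namespace TMVA::Experimental;

    int main() {
       // Parse an ONNX file that contains a GRU operator (file name is a placeholder).
       SOFIE::RModelParser_ONNX parser;
       SOFIE::RModel model = parser.Parse("GRUModel.onnx");

       // Generate the inference code; the GRU node is emitted by ROperator_GRU<T>::Generate().
       model.Generate();

       // Write the generated header (and the weight file) to disk.
       model.OutputGenerated("GRUModel.hxx");
       return 0;
    }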