17-> std::vector<std::vector<size_t>> {
18 size_t num_directions =
input[1][0];
19 size_t hidden_size =
input[1][1] / 3;
21 size_t seq_length =
input[0][0];
22 size_t batch_size =
input[0][1];
23 std::vector<std::vector<size_t>> ret(
24 {{seq_length, num_directions, batch_size, hidden_size},
25 {num_directions, batch_size, hidden_size}});
28 size_t batch_size =
input[0][0];
29 size_t seq_length =
input[0][1];
30 std::vector<std::vector<size_t>> ret(
31 {{batch_size, seq_length, num_directions, hidden_size},
32 {batch_size, num_directions, hidden_size}});
42 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNX +
" is not found in model.");
46 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNX +
" is not of 3 dimensions.");
49 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNW +
" is not found in model.");
53 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNW +
" is not of 3 dimensions.");
56 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNR +
" is not found in model.");
60 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
fNR +
" is not of 3 dimensions.");
64 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " +
fNB +
" is not found in model.");
68 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " +
fNB +
" is not of 2 or 4 dimensions.");
73 size_t num_directions =
fShapeW[0];
76 if (
fType ==
"float") {
77 float *original_bias =
static_cast<float*
>(original_data.get());
78 float *new_bias =
new float[num_directions * 6 * seq_length * batch_size *
fAttrHiddenSize];
79 for (
size_t direction = 0; direction < num_directions; direction++) {
80 for (
size_t i = 0;
i < 6;
i++) {
81 for (
size_t seq = 0; seq < seq_length; seq++) {
82 for (
size_t batch = 0; batch < batch_size; batch++) {
87 std::copy(original_bias + bias_offset, original_bias + bias_offset +
fAttrHiddenSize,
94 std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size,
fAttrHiddenSize};
95 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<
float[]>());
103 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
105 "is not found in model.");
109 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
111 " is not of 1 dimension.");
116 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
121 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
131 if (!
fNY_h.empty()) {
139 if (activation !=
"Relu" && activation !=
"Tanh" &&
140 activation !=
"Sigmoid" && activation !=
"Affine" &&
141 activation !=
"LeakyRelu" && activation !=
"ThresholdRelu" &&
142 activation !=
"ScaledTanh" && activation !=
"HardSigmoid" &&
143 activation !=
"Elu" && activation !=
"Softsign" &&
144 activation !=
"Softplus") {
145 throw std::runtime_error(
"TMVA SOFIE - Activation function " +
146 activation +
" not implemented");
153 throw std::runtime_error(
154 "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
158 throw std::runtime_error(
159 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
160 std::to_string(
fShapeW[1] / 3));
163 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " +
165 " must be 0 (timewise) or 1 (batchwise)");
168 throw std::runtime_error(
170 +
" must be 0 or 1.");
230 OpName =
"op_" + OpName;
231 std::stringstream out;
235 size_t input_size =
fShapeX[2];
236 size_t num_directions =
fShapeW[0];
240 out <<
SP <<
fType <<
" *" << OpName <<
"_input = tensor_" <<
fNX <<
";\n";
243 out <<
SP <<
fType <<
" * " << OpName <<
"_input = fVec_" << OpName <<
"_input.data();\n";
245 out <<
SP <<
fType <<
" " << OpName <<
"_input[" << seq_length * batch_size * input_size <<
"];\n";
247 out <<
SP <<
"for(size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
248 out <<
SP <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
249 out <<
SP <<
SP <<
SP <<
"for(size_t i = 0; i < " << input_size <<
"; i++) {\n";
250 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input[seq * " << batch_size * input_size
251 <<
" + batch * " << input_size <<
" + i] = " <<
"tensor_" <<
fNX <<
"[batch * "
252 << seq_length * input_size <<
" + seq * " << input_size <<
" + i];\n";
253 out <<
SP <<
SP <<
SP <<
"}\n";
254 out <<
SP <<
SP <<
"}\n";
261 out <<
SP <<
fType <<
" *" << OpName <<
"_initial_hidden_state = " <<
" tensor_"
265 out <<
SP <<
fType <<
" * " << OpName <<
"_initial_hidden_state = fVec_" << OpName
266 <<
"_initial_hidden_state.data();\n";
268 out <<
SP <<
fType <<
" " << OpName <<
"_initial_hidden_state[" << num_directions * batch_size *
271 for (
size_t direction = 0; direction < num_directions; direction++) {
272 out <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
274 out <<
SP <<
SP <<
SP << OpName <<
"_initial_hidden_state["
278 out <<
SP <<
SP <<
"}\n";
287 out <<
SP <<
fType <<
" * " << OpName <<
"_f_update_gate = fVec_" << OpName <<
"_f_update_gate.data();\n";
288 out <<
SP <<
fType <<
" * " << OpName <<
"_f_reset_gate = fVec_" << OpName <<
"_f_reset_gate.data();\n";
289 out <<
SP <<
fType <<
" * " << OpName <<
"_f_hidden_gate = fVec_" << OpName <<
"_f_hidden_gate.data();\n";
291 out <<
SP <<
fType <<
" " << OpName <<
"_f_update_gate[" << feedforward_size <<
"] = {0};\n";
292 out <<
SP <<
fType <<
" " << OpName <<
"_f_reset_gate[" << feedforward_size <<
"] = {0};\n";
293 out <<
SP <<
fType <<
" " << OpName <<
"_f_hidden_gate[" << feedforward_size <<
"] = {0};\n";
296 size_t hidden_state_size = seq_length * num_directions * batch_size *
fAttrHiddenSize;
298 out <<
SP <<
fType <<
" * " << OpName <<
"_update_gate = fVec_" << OpName <<
"_update_gate.data();\n";
299 out <<
SP <<
fType <<
" * " << OpName <<
"_reset_gate = fVec_" << OpName <<
"_reset_gate.data();\n";
300 out <<
SP <<
fType <<
" * " << OpName <<
"_hidden_gate = fVec_" << OpName <<
"_hidden_gate.data();\n";
302 out <<
SP <<
fType <<
" " << OpName <<
"_update_gate[" << hidden_state_size <<
"] = {0};\n";
303 out <<
SP <<
fType <<
" " << OpName <<
"_reset_gate[" << hidden_state_size <<
"] = {0};\n";
304 out <<
SP <<
fType <<
" " << OpName <<
"_hidden_gate[" << hidden_state_size <<
"] = {0};\n";
308 out <<
SP <<
fType <<
" *" << OpName <<
"_hidden_state = tensor_" <<
fNY <<
";\n";
311 out <<
SP <<
fType <<
" * " << OpName <<
"_hidden_state = fVec_" << OpName <<
"_hidden_state.data();\n";
313 out <<
SP <<
fType <<
" " << OpName <<
"_hidden_state[" << hidden_state_size <<
"] = {0};\n";
318 out <<
SP <<
fType <<
" * " << OpName <<
"_feedback = fVec_" << OpName <<
"_feedback.data();\n";
323 out <<
SP <<
"char " << OpName <<
"_transA = 'N';\n";
324 out <<
SP <<
"char " << OpName <<
"_transB = 'T';\n";
325 out <<
SP <<
"int " << OpName <<
"_m = " << seq_length * batch_size <<
";\n";
326 out <<
SP <<
"int " << OpName <<
"_m2 = " << batch_size <<
";\n";
328 out <<
SP <<
"int " << OpName <<
"_k = " << input_size <<
";\n";
329 if (
fType ==
"float") {
330 out <<
SP <<
"float " << OpName <<
"_alpha = 1.;\n";
331 out <<
SP <<
"float " << OpName <<
"_beta = 0.;\n";
334 out <<
SP <<
"int " << OpName <<
"_bias_size = " << seq_length * batch_size *
fAttrHiddenSize <<
";\n";
336 out <<
SP <<
"int " << OpName <<
"_incx = 1;\n";
337 out <<
SP <<
"int " << OpName <<
"_incy = 1;\n";
338 out <<
SP <<
"int " << OpName <<
"_feedback_size = " << batch_size *
fAttrHiddenSize <<
";\n";
340 for (
size_t direction = 0; direction < num_directions; direction++) {
341 if (direction == 0) {
342 if (
fType ==
"float") {
344 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
345 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
346 <<
fNW <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &"
347 << OpName <<
"_beta, " << OpName <<
"_f_update_gate, &" << OpName <<
"_n);\n";
350 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
351 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
352 <<
fNW <<
" + " << wr_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
353 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_f_reset_gate, &" << OpName <<
"_n);\n";
356 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
357 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
358 <<
fNW <<
" + " << wh_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
359 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_f_hidden_gate, &" << OpName <<
"_n);\n";
362 if (
fType ==
"float") {
365 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
366 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
367 <<
fNW <<
" + " << wz_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
368 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_f_update_gate, &" << OpName <<
"_n);\n";
371 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
372 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
373 <<
fNW <<
" + " << wr_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
374 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_f_reset_gate, &" << OpName <<
"_n);\n";
377 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
378 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
379 <<
fNW <<
" + " << wh_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
380 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_f_hidden_gate, &" << OpName <<
"_n);\n";
385 if (direction == 0) {
386 if (
fType ==
"float") {
388 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
389 <<
fNB <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_update_gate, &" << OpName <<
"_incy);\n";
392 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
393 <<
fNB <<
" + " << rbz_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_update_gate, &"
394 << OpName <<
"_incy);\n";
397 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
398 <<
fNB <<
" + " << wbr_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_reset_gate, &"
399 << OpName <<
"_incy);\n";
403 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
404 <<
fNB <<
" + " << rbr_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_reset_gate, &"
405 << OpName <<
"_incy);\n";
408 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
409 <<
fNB <<
" + " << wbh_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_hidden_gate, &"
410 << OpName <<
"_incy);\n";
414 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
415 <<
fNB <<
" + " << rbh_offset <<
", &" << OpName <<
"_incx, " << OpName
416 <<
"_f_hidden_gate, &" << OpName <<
"_incy);\n";
420 if (
fType ==
"float") {
423 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
424 <<
fNB <<
" + " << wbz_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_update_gate, &"
425 << OpName <<
"_incy);\n";
429 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
430 <<
fNB <<
" + " << rbz_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_update_gate, &"
431 << OpName <<
"_incy);\n";
434 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
435 <<
fNB <<
" + " << wbr_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_reset_gate, &"
436 << OpName <<
"_incy);\n";
439 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
440 <<
fNB <<
" + " << rbr_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_reset_gate, &"
441 << OpName <<
"_incy);\n";
444 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
445 <<
fNB <<
" + " << wbh_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_f_hidden_gate, &"
446 << OpName <<
"_incy);\n";
450 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
451 <<
fNB <<
" + " << rbh_offset <<
", &" << OpName <<
"_incx, " << OpName
452 <<
"_f_hidden_gate, &" << OpName <<
"_incy);\n";
459 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
461 if (direction == 0) {
462 out <<
SP <<
SP <<
"size_t gate_offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
465 out <<
SP <<
SP <<
"size_t gate_offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
469 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_f_update_gate + offset, " << OpName
470 <<
"_f_update_gate + offset + " << f_seq_size <<
", " << OpName <<
"_update_gate + gate_offset);\n";
471 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_f_reset_gate + offset, " << OpName
472 <<
"_f_reset_gate + offset + " << f_seq_size <<
", " << OpName <<
"_reset_gate + gate_offset);\n";
473 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_f_hidden_gate + offset, " << OpName
474 <<
"_f_hidden_gate + offset + " << f_seq_size <<
", " << OpName <<
"_hidden_gate + gate_offset);\n";
477 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
479 out <<
SP <<
SP <<
"size_t index = " << seq_length - 1 <<
" - seq;\n";
481 out <<
SP <<
SP <<
"size_t index = seq;\n";
483 out <<
SP <<
SP <<
"int m2 = " << batch_size <<
";\n";
484 if (direction == 0) {
485 out <<
SP <<
SP <<
"size_t offset = index * " << num_directions * batch_size *
fAttrHiddenSize
488 out <<
SP <<
SP <<
"size_t offset = index * " << num_directions * batch_size *
fAttrHiddenSize
493 out <<
SP <<
SP <<
"if (seq == 0) {\n";
495 if (direction == 0) {
496 if (
fType ==
"float") {
497 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
498 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
", &"
499 << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
500 <<
"_alpha, " << OpName <<
"_update_gate + offset, &" << OpName <<
"_n);\n";
502 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
503 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
504 << rr_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
505 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_reset_gate + offset, &" << OpName <<
"_n);\n";
508 if (
fType ==
"float") {
510 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
511 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
512 << rz_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
513 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_update_gate + offset, &" << OpName <<
"_n);\n";
515 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
516 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
517 << rr_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
518 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_reset_gate + offset, &" << OpName <<
"_n);\n";
522 out <<
SP <<
SP <<
"} else {\n";
524 if (direction == 0) {
526 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
529 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
532 if (
fType ==
"float") {
533 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
534 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
", &"
535 << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
536 << OpName <<
"_alpha, " << OpName <<
"_update_gate + offset, &" << OpName <<
"_n);\n";
538 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
539 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
540 << rr_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
541 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_reset_gate + offset, &"
542 << OpName <<
"_n);\n";
545 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
547 if (
fType ==
"float") {
549 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
550 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
551 << rz_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
552 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_update_gate + offset, &"
553 << OpName <<
"_n);\n";
555 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
556 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
557 << rr_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
558 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_reset_gate + offset, &"
559 << OpName <<
"_n);\n";
562 out <<
SP <<
SP <<
"}\n";
566 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
567 if (
fType ==
"float") {
568 out <<
SP <<
SP <<
SP <<
"float z = (" << OpName <<
"_update_gate[i] > " << -
fAttrClip
569 <<
") ? " << OpName <<
"_update_gate[i] : " << -
fAttrClip <<
";\n";
571 out <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = (z < " <<
fAttrClip
573 if (
fType ==
"float") {
574 out <<
SP <<
SP <<
SP <<
"float r = (" << OpName <<
"_reset_gate[i] > " << -
fAttrClip
575 <<
") ? " << OpName <<
"_reset_gate[i] : " << -
fAttrClip <<
";\n";
577 out <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = (r < " <<
fAttrClip
579 out <<
SP <<
SP <<
"}\n";
584 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
585 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_update_gate[i] < 0.)\n";
586 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = 0.;\n";
587 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_reset_gate[i] < 0.)\n";
588 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = 0.;\n";
589 out <<
SP <<
SP <<
"}\n";
591 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
592 if (
fType ==
"float") {
593 out <<
SP <<
SP <<
SP <<
"float z = exp(-2 * " << OpName <<
"_update_gate[i]);\n";
595 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = (1. - z) / (1. + z);\n";
596 if (
fType ==
"float") {
597 out <<
SP <<
SP <<
SP <<
"float r = exp(-2 * " << OpName <<
"_reset_gate[i]);\n";
599 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = (1. - r) / (1. + r);\n";
600 out <<
SP <<
SP <<
"}\n";
602 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
603 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = 1. / (1. + exp(-"
604 << OpName <<
"_update_gate[i]));\n";
605 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = 1. / (1. + exp(-"
606 << OpName <<
"_reset_gate[i]));\n";
607 out <<
SP <<
SP <<
"}\n";
609 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
610 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = "
613 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = "
616 out <<
SP <<
SP <<
"}\n";
618 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
619 if (
fType ==
"float") {
621 <<
" * "<< OpName <<
"_update_gate[i]);\n";
623 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = "
625 if (
fType ==
"float") {
627 <<
" * "<< OpName <<
"_reset_gate[i]);\n";
629 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = "
631 out <<
SP <<
SP <<
"}\n";
633 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
634 if (
fType ==
"float") {
637 out <<
SP <<
SP <<
SP <<
"float zb = (za > 0.) ? za : 0.;\n";
639 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
640 if (
fType ==
"float") {
643 out <<
SP <<
SP <<
SP <<
"float rb = (ra > 0.) ? ra : 0.;\n";
645 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
646 out <<
SP <<
SP <<
"}\n";
648 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
649 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_update_gate[i] < 0.)\n";
650 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = "
652 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_reset_gate[i] < 0.)\n";
653 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = "
655 out <<
SP <<
SP <<
"}\n";
657 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
658 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_update_gate[i] < "
660 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = 0.;\n";
661 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_reset_gate[i] < "
663 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = 0.;\n";
// Close the generated per-element loop for the update/reset gates.
// Emit a trailing newline like every other generated "}" so the next
// generated statement does not run onto the same line.
664 out << SP << SP << "}\n";
666 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
667 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_update_gate[i] < 0.)\n";
668 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = "
670 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_reset_gate[i] < 0.)\n";
671 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = "
673 out <<
SP <<
SP <<
"}\n";
675 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
676 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = " << OpName
677 <<
"_update_gate[i] / (1. + abs(" << OpName <<
"_update_gate[i]));\n";
678 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = " << OpName
679 <<
"_reset_gate[i] / (1. + abs(" << OpName <<
"_reset_gate[i]));\n";
680 out <<
SP <<
SP <<
"}\n";
682 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
683 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_update_gate[i] = log(1. + exp("
684 << OpName <<
"_update_gate[i]));\n";
685 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_reset_gate[i] = log(1. + exp("
686 << OpName <<
"_reset_gate[i]));\n";
687 out <<
SP <<
SP <<
"}\n";
691 out <<
SP <<
SP <<
"if (seq == 0) {\n";
694 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
695 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_feedback[i] = " << OpName
696 <<
"_reset_gate[i + offset] * " << OpName <<
"_initial_hidden_state[i];\n";
697 out <<
SP <<
SP <<
SP <<
"}\n";
699 out <<
SP <<
SP <<
"} else {\n";
701 if (direction == 0) {
703 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
706 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
710 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * " << num_directions
713 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
714 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_feedback[i] = " << OpName
715 <<
"_reset_gate[i + offset] * " << OpName <<
"_hidden_state[i + previous_offset];\n";
716 out <<
SP <<
SP <<
SP <<
"}\n";
717 out <<
SP <<
SP <<
"}\n";
719 size_t rh_offset = (direction == 0) ?
722 out <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
723 << OpName <<
"_n, &" << OpName <<
"_m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_"
724 <<
fNR <<
" + " << rh_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_feedback, &" << OpName
725 <<
"_n, &" << OpName <<
"_beta, " << OpName <<
"_feedback, &" << OpName <<
"_n);\n";
729 size_t rh_offset = (direction == 0)
732 out <<
SP <<
SP <<
"if (seq == 0) {\n";
736 <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &" << OpName <<
"_n, &"
737 << OpName <<
"_m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
738 << rh_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &"
739 << OpName <<
"_n, &" << OpName <<
"_beta, " << OpName <<
"_feedback, &" << OpName <<
"_n);\n";
741 out <<
SP <<
SP <<
"} else {\n";
743 if (direction == 0) {
745 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
748 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
752 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * " << num_directions
755 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
756 << OpName <<
"_n, &" << OpName <<
"_m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR
757 <<
" + " << rh_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
758 << OpName <<
"_n, &" << OpName <<
"_beta, " << OpName <<
"_feedback, &" << OpName <<
"_n);\n";
760 out <<
SP <<
SP <<
"}\n";
763 size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length *
fAttrHiddenSize
765 out <<
SP <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_feedback_size, &" << OpName
766 <<
"_alpha, tensor_" <<
fNB <<
" + " << rbh_offset <<
", &" << OpName <<
"_incx, "
767 << OpName <<
"_feedback, &" << OpName <<
"_incy);\n";
770 out <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
771 out <<
SP <<
SP <<
SP << OpName <<
"_feedback[i] *= " << OpName <<
"_reset_gate[i + offset];\n";
772 out <<
SP <<
SP <<
"}\n";
776 out <<
SP <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_feedback_size, &" << OpName <<
"_alpha, "
777 << OpName <<
"_feedback, &" << OpName <<
"_incx, " << OpName <<
"_hidden_gate + offset, &"
778 << OpName <<
"_incy);\n";
782 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
783 if (
fType ==
"float") {
784 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_hidden_gate[i] > " << -
fAttrClip
785 <<
") ? " << OpName <<
"_hidden_gate[i] : " << -
fAttrClip <<
";\n";
787 out <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = (x < " <<
fAttrClip <<
") ? x : "
789 out <<
SP <<
SP <<
"}\n";
794 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
795 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_hidden_gate[i] < 0.)\n";
796 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = 0.;\n";
797 out <<
SP <<
SP <<
"}\n";
799 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
800 if (
fType ==
"float") {
801 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_hidden_gate[i]);\n";
803 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
804 out <<
SP <<
SP <<
"}\n";
806 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
807 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = 1. / (1. + exp(-" << OpName
808 <<
"_hidden_gate[i]));\n";
809 out <<
SP <<
SP <<
"}\n";
811 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
812 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = "
815 out <<
SP <<
SP <<
"}\n";
817 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
818 if (
fType ==
"float") {
820 <<
" * "<< OpName <<
"_hidden_gate[i]);\n";
822 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = "
824 out <<
SP <<
SP <<
"}\n";
826 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
827 if (
fType ==
"float") {
830 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
832 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
833 out <<
SP <<
SP <<
"}\n";
835 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
836 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_hidden_gate[i] < 0.)\n";
837 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = "
839 out <<
SP <<
SP <<
"}\n";
841 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
842 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_hidden_gate[i] < "
844 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = 0.;\n";
// Close the generated per-element loop for the hidden gate.
// Emit a trailing newline like every other generated "}" so the next
// generated statement does not run onto the same line.
845 out << SP << SP << "}\n";
847 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
848 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_hidden_gate[i] < 0.)\n";
// Elu on the hidden gate (applied only where the value is < 0, per the
// preceding generated `if`): f(x) = alpha * (exp(x) - 1).
// The previously generated expression alpha * exp(x - 1.) is not Elu.
849 out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
850 << fAttrActivationAlpha[direction * 2 + 1] << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
851 out <<
SP <<
SP <<
"}\n";
853 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
854 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = " << OpName
855 <<
"_hidden_gate[i] / (1. + abs(" << OpName <<
"_hidden_gate[i]));\n";
856 out <<
SP <<
SP <<
"}\n";
858 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
859 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_gate[i] = log(1. + exp("
860 << OpName <<
"_hidden_gate[i]));\n";
861 out <<
SP <<
SP <<
"}\n";
865 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
866 out <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[i] = ( 1. - " << OpName
867 <<
"_update_gate[i]) * " << OpName <<
"_hidden_gate[i];\n";
868 out <<
SP <<
SP <<
"}\n";
870 out <<
SP <<
SP <<
"if (seq == 0) {\n";
873 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
874 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[i + offset] += " << OpName
875 <<
"_update_gate[i + offset] * " << OpName <<
"_initial_hidden_state[i];\n";
876 out <<
SP <<
SP <<
SP <<
"}\n";
878 out <<
SP <<
SP <<
"} else {\n";
880 if (direction == 0) {
882 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
885 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
889 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
892 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
893 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[i + offset] += " << OpName
894 <<
"_update_gate[i + offset] * " << OpName <<
"_hidden_state[i + previous_offset];\n";
895 out <<
SP <<
SP <<
SP <<
"}\n";
896 out <<
SP <<
SP <<
"}\n";
903 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
904 out <<
SP <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
906 for (
size_t direction = 0; direction < num_directions; direction++) {
908 out <<
SP <<
SP <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[seq * "
913 out <<
SP <<
SP <<
SP <<
"}\n";
914 out <<
SP <<
SP <<
"}\n";
920 if (!
fNY_h.empty()) {
925 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state, " << OpName <<
"_hidden_state + "
926 << yh_size <<
", tensor_" <<
fNY_h <<
");\n";
929 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + " <<
offset <<
", " << OpName
930 <<
"_hidden_state + " <<
offset <<
" + " << yh_size <<
", tensor_" <<
fNY_h <<
");\n";
932 if (num_directions == 2) {
933 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + " << yh_size <<
", " << OpName
934 <<
"_hidden_state + " << 2 * yh_size <<
", tensor_" <<
fNY_h <<
" + " << yh_size <<
");\n";
938 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
940 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
944 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
946 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
949 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
950 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
953 if (num_directions == 2) {
954 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
959 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
960 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
968 for (
size_t direction = 0; direction < num_directions; direction++) {
969 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
970 out <<
SP <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
973 out <<
SP <<
SP <<
SP <<
"size_t y_offset = batch * " << seq_length * num_directions *
fAttrHiddenSize
975 out <<
SP <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
976 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY <<
" + y_offset);\n";
977 out <<
SP <<
SP <<
"}\n";
981 if (!
fNY_h.empty()) {
984 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
986 out <<
SP <<
SP <<
"size_t yh_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
987 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
988 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
991 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
993 out <<
SP <<
SP <<
"size_t seq = " << seq_length - 1 <<
";\n";
997 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
999 out <<
SP <<
SP <<
"size_t yh_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1000 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1001 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";
1004 if (num_directions == 2) {
1005 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1008 out <<
SP <<
SP <<
"size_t yh_offset = batch * " << num_directions *
fAttrHiddenSize <<
" + "
1010 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1011 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + yh_offset);\n";