43 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNX +
" is not found in model.");
46 if (fShapeX.size() != 3) {
47 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNX +
" is not of 3 dimensions.");
50 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNW +
" is not found in model.");
53 if (fShapeW.size() != 3) {
54 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNW +
" is not of 3 dimensions.");
57 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNR +
" is not found in model.");
60 if (fShapeR.size() != 3) {
61 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNR +
" is not of 3 dimensions.");
65 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " + fNB +
" is not found in model.");
68 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
69 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " + fNB +
" is not of 2 or 4 dimensions.");
71 if (fShapeB.size() == 2) {
75 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
76 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
77 if (fType ==
"float") {
81 for (
size_t i = 0; i < 6; i++) {
102 if (!fNSequence_lens.empty()) {
104 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
106 "is not found in model.");
109 if (fShapeSequence_lens.size() != 1) {
110 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
112 " is not of 1 dimension.");
115 if (!fNInitial_h.empty()) {
117 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
118 fNInitial_h +
" is not found in model.");
121 if (fShapeInitial_h.size() != 3) {
122 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " +
123 fNInitial_h +
" is not of 3 dimensions.");
127 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
132 if (!fNY_h.empty()) {
133 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
146 throw std::runtime_error(
"TMVA SOFIE - Activation function " +
150 if (fAttrDirection ==
"reverse") fAttrDirection =
"backward";
151 if (fAttrDirection !=
"forward" && fAttrDirection !=
"backward" &&
152 fAttrDirection !=
"reverse" &&
153 fAttrDirection !=
"bidirectional") {
154 throw std::runtime_error(
155 "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
158 if (3 * fAttrHiddenSize != fShapeW[1]) {
159 throw std::runtime_error(
160 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
161 std::to_string(fShapeW[1] / 3));
163 if (fAttrLayout > 1) {
164 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " +
165 std::to_string(fAttrLayout) +
166 " must be 0 (timewise) or 1 (batchwise)");
168 if (fAttrLinearBeforeReset > 1) {
169 throw std::runtime_error(
170 "TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrLinearBeforeReset)
171 +
" must be 0 or 1.");
173 if (fAttrActivations.empty()) {
174 if (fAttrDirection ==
"bidirectional") {
175 fAttrActivations = {
"Sigmoid",
"Tanh",
"Sigmoid",
"Tanh"};
177 fAttrActivations = {
"Sigmoid",
"Tanh"};
184 std::string
opName =
"op_gru_" + fNX;
187 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
188 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
196 if (fAttrLayout != 0) {
216 if (fAttrLayout != 0 || fNY.empty()) {
226 std::stringstream out;
228 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
229 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
233 auto getVec = [&](std::string
const &
name) {
return "tensor_op_gru_" + fNX +
"_" +
name; };
236 if (fAttrLayout == 0) {
237 out << SP << fType <<
" const* " <<
OpName <<
"_input = tensor_" << fNX <<
";\n";
240 out << SP << fType <<
" * " <<
OpName <<
"_input = " <<
getVec(
"input") <<
";\n";
244 out << SP <<
"for(size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
245 out << SP << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
246 out << SP << SP << SP <<
"for(size_t i = 0; i < " <<
input_size <<
"; i++) {\n";
248 <<
" + batch * " <<
input_size <<
" + i] = " <<
"tensor_" << fNX <<
"[batch * "
250 out << SP << SP << SP <<
"}\n";
251 out << SP << SP <<
"}\n";
256 if (!fNInitial_h.empty()) {
257 if (fAttrLayout == 0) {
258 out << SP << fType <<
" *" <<
OpName <<
"_initial_hidden_state = " <<
" tensor_"
259 << fNInitial_h <<
";\n";
262 out << SP << fType <<
" * " <<
OpName <<
"_initial_hidden_state = " <<
getVec(
"initial_hidden_state") <<
";\n";
265 fAttrHiddenSize <<
"];\n";
268 out << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
269 out << SP << SP <<
"for(size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
270 out << SP << SP << SP <<
OpName <<
"_initial_hidden_state["
272 <<
" + h] = tensor_" << fNInitial_h <<
"[batch * " <<
num_directions * fAttrHiddenSize
273 <<
" + " <<
direction * fAttrHiddenSize <<
" + h];\n";
274 out << SP << SP <<
"}\n";
283 out << SP << fType <<
" * " <<
OpName <<
"_f_update_gate = " <<
getVec(
"f_update_gate") <<
";\n";
284 out << SP << fType <<
" * " <<
OpName <<
"_f_reset_gate = " <<
getVec(
"f_reset_gate") <<
";\n";
285 out << SP << fType <<
" * " <<
OpName <<
"_f_hidden_gate = " <<
getVec(
"f_hidden_gate") <<
";\n";
294 out << SP << fType <<
" * " <<
OpName <<
"_update_gate = " <<
getVec(
"update_gate") <<
";\n";
295 out << SP << fType <<
" * " <<
OpName <<
"_reset_gate = " <<
getVec(
"reset_gate") <<
";\n";
296 out << SP << fType <<
" * " <<
OpName <<
"_hidden_gate = " <<
getVec(
"hidden_gate") <<
";\n";
303 if (fAttrLayout == 0 && !fNY.empty()) {
304 out << SP << fType <<
" *" <<
OpName <<
"_hidden_state = tensor_" << fNY <<
";\n";
307 out << SP << fType <<
" * " <<
OpName <<
"_hidden_state = " <<
getVec(
"hidden_state") <<
";\n";
314 out << SP << fType <<
" * " <<
OpName <<
"_feedback = " <<
getVec(
"feedback") <<
";\n";
316 out << SP << fType <<
" " <<
OpName <<
"_feedback[" <<
batch_size * fAttrHiddenSize <<
"] = {0};\n";
319 out << SP <<
"char " <<
OpName <<
"_transA = 'N';\n";
320 out << SP <<
"char " <<
OpName <<
"_transB = 'T';\n";
323 out << SP <<
"int " <<
OpName <<
"_n = " << fAttrHiddenSize <<
";\n";
325 if (fType ==
"float") {
326 out << SP <<
"float " <<
OpName <<
"_alpha = 1.;\n";
327 out << SP <<
"float " <<
OpName <<
"_beta = 0.;\n";
332 out << SP <<
"int " <<
OpName <<
"_incx = 1;\n";
333 out << SP <<
"int " <<
OpName <<
"_incy = 1;\n";
334 out << SP <<
"int " <<
OpName <<
"_feedback_size = " <<
batch_size * fAttrHiddenSize <<
";\n";
338 if (fType ==
"float") {
340 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
346 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
352 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
358 if (fType ==
"float") {
361 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
367 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
373 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
382 if (fType ==
"float") {
384 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
385 << fNB <<
", &" <<
OpName <<
"_incx, " <<
OpName <<
"_f_update_gate, &" <<
OpName <<
"_incy);\n";
388 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
393 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
399 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
404 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
407 if (fAttrLinearBeforeReset == 0) {
410 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
412 <<
"_f_hidden_gate, &" <<
OpName <<
"_incy);\n";
416 if (fType ==
"float") {
419 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
425 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
430 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
435 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
440 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
443 if (fAttrLinearBeforeReset == 0) {
446 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
448 <<
"_f_hidden_gate, &" <<
OpName <<
"_incy);\n";
455 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
456 out << SP << SP <<
"size_t offset = seq * " <<
batch_size * fAttrHiddenSize <<
";\n";
462 <<
" + " <<
batch_size * fAttrHiddenSize <<
";\n";
465 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_update_gate + offset, " <<
OpName
466 <<
"_f_update_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_update_gate + gate_offset);\n";
467 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_reset_gate + offset, " <<
OpName
468 <<
"_f_reset_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_reset_gate + gate_offset);\n";
469 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_hidden_gate + offset, " <<
OpName
470 <<
"_f_hidden_gate + offset + " <<
f_seq_size <<
", " <<
OpName <<
"_hidden_gate + gate_offset);\n";
473 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
474 if (fAttrDirection ==
"backward" ||
direction == 1) {
475 out << SP << SP <<
"size_t index = " <<
seq_length - 1 <<
" - seq;\n";
477 out << SP << SP <<
"size_t index = seq;\n";
479 out << SP << SP <<
"int m2 = " <<
batch_size <<
";\n";
485 <<
" + " <<
batch_size * fAttrHiddenSize <<
";\n";
489 out << SP << SP <<
"if (seq == 0) {\n";
490 if (!fNInitial_h.empty()) {
492 if (fType ==
"float") {
493 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
494 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &"
496 <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
497 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
498 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
499 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
501 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_reset_gate + offset, &" <<
OpName <<
"_n);\n";
504 if (fType ==
"float") {
505 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
506 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
507 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
509 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
510 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
511 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
512 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
514 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_reset_gate + offset, &" <<
OpName <<
"_n);\n";
518 out << SP << SP <<
"} else {\n";
521 if (fAttrDirection ==
"backward") {
522 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
525 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
528 if (fType ==
"float") {
529 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
530 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &"
531 <<
OpName <<
"_n, " <<
OpName <<
"_hidden_state + previous_offset, &" <<
OpName <<
"_n, &"
532 <<
OpName <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
533 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
534 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
535 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
541 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
543 if (fType ==
"float") {
544 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
545 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
546 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
550 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
551 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
552 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
558 out << SP << SP <<
"}\n";
561 if (fAttrClip > .0) {
562 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
563 if (fType ==
"float") {
564 out << SP << SP << SP <<
"float z = (" <<
OpName <<
"_update_gate[i] > " << -fAttrClip
565 <<
") ? " <<
OpName <<
"_update_gate[i] : " << -fAttrClip <<
";\n";
567 out << SP << SP << SP <<
OpName <<
"_update_gate[i] = (z < " << fAttrClip
568 <<
") ? z : " << fAttrClip <<
";\n";
569 if (fType ==
"float") {
570 out << SP << SP << SP <<
"float r = (" <<
OpName <<
"_reset_gate[i] > " << -fAttrClip
571 <<
") ? " <<
OpName <<
"_reset_gate[i] : " << -fAttrClip <<
";\n";
573 out << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (r < " << fAttrClip
574 <<
") ? r : " << fAttrClip <<
";\n";
575 out << SP << SP <<
"}\n";
579 if (fAttrActivations[
direction * 2] ==
"Relu") {
580 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
581 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
582 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 0.;\n";
583 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
584 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 0.;\n";
585 out << SP << SP <<
"}\n";
586 }
else if (fAttrActivations[
direction * 2] ==
"Tanh") {
587 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
588 if (fType ==
"float") {
589 out << SP << SP << SP <<
"float z = exp(-2 * " <<
OpName <<
"_update_gate[i]);\n";
591 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = (1. - z) / (1. + z);\n";
592 if (fType ==
"float") {
593 out << SP << SP << SP <<
"float r = exp(-2 * " <<
OpName <<
"_reset_gate[i]);\n";
595 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (1. - r) / (1. + r);\n";
596 out << SP << SP <<
"}\n";
597 }
else if (fAttrActivations[
direction * 2] ==
"Sigmoid") {
598 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
599 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 1. / (1. + exp(-"
600 <<
OpName <<
"_update_gate[i]));\n";
601 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 1. / (1. + exp(-"
602 <<
OpName <<
"_reset_gate[i]));\n";
603 out << SP << SP <<
"}\n";
604 }
else if (fAttrActivations[
direction * 2] ==
"Affine") {
605 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
606 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
607 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_update_gate[i] + "
608 << fAttrActivationBeta[
direction * 2] <<
";\n";
609 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
610 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_reset_gate[i] + "
611 << fAttrActivationBeta[
direction * 2] <<
";\n";
612 out << SP << SP <<
"}\n";
613 }
else if (fAttrActivations[
direction * 2] ==
"ScaledTanh") {
614 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
615 if (fType ==
"float") {
616 out << SP << SP << SP <<
"float z = exp(-2 * " << fAttrActivationBeta[
direction * 2]
617 <<
" * "<<
OpName <<
"_update_gate[i]);\n";
619 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
620 << fAttrActivationAlpha[
direction * 2] <<
" * (1. - z) / (1. + z);\n";
621 if (fType ==
"float") {
622 out << SP << SP << SP <<
"float r = exp(-2 * " << fAttrActivationBeta[
direction * 2]
623 <<
" * "<<
OpName <<
"_reset_gate[i]);\n";
625 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
626 << fAttrActivationAlpha[
direction * 2] <<
" * (1. - r) / (1. + r);\n";
627 out << SP << SP <<
"}\n";
628 }
else if (fAttrActivations[
direction * 2] ==
"HardSigmoid") {
629 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
630 if (fType ==
"float") {
631 out << SP << SP << SP <<
"float za = " << fAttrActivationAlpha[
direction * 2] <<
" * "
632 <<
OpName <<
"_update_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
633 out << SP << SP << SP <<
"float zb = (za > 0.) ? za : 0.;\n";
635 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
636 if (fType ==
"float") {
637 out << SP << SP << SP <<
"float ra = " << fAttrActivationAlpha[
direction * 2] <<
" * "
638 <<
OpName <<
"_reset_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
639 out << SP << SP << SP <<
"float rb = (ra > 0.) ? ra : 0.;\n";
641 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
642 out << SP << SP <<
"}\n";
643 }
else if (fAttrActivations[
direction * 2] ==
"LeakyRelu") {
644 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
645 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
646 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
647 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_update_gate[i];\n";
648 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
649 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
650 << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName <<
"_reset_gate[i];\n";
651 out << SP << SP <<
"}\n";
652 }
else if (fAttrActivations[
direction * 2] ==
"ThresholdRelu") {
653 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
654 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < "
655 << fAttrActivationAlpha[
direction * 2] <<
")\n";
656 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 0.;\n";
657 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < "
658 << fAttrActivationAlpha[
direction * 2] <<
")\n";
659 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 0.;\n";
660 out << SP << SP <<
"}";
661 }
else if (fAttrActivations[
direction * 2] ==
"Elu") {
662 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
663 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
664 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = "
665 << fAttrActivationAlpha[
direction * 2] <<
" * exp(" <<
OpName <<
"_update_gate[i] - 1.);\n";
666 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
667 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = "
668 << fAttrActivationAlpha[
direction * 2] <<
" * exp(" <<
OpName <<
"_reset_gate[i] - 1.);\n";
669 out << SP << SP <<
"}\n";
670 }
else if (fAttrActivations[
direction * 2] ==
"Softsign") {
671 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
672 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = " <<
OpName
673 <<
"_update_gate[i] / (1. + abs(" <<
OpName <<
"_update_gate[i]));\n";
674 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = " <<
OpName
675 <<
"_reset_gate[i] / (1. + abs(" <<
OpName <<
"_reset_gate[i]));\n";
676 out << SP << SP <<
"}\n";
678 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
679 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = log(1. + exp("
680 <<
OpName <<
"_update_gate[i]));\n";
681 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = log(1. + exp("
682 <<
OpName <<
"_reset_gate[i]));\n";
683 out << SP << SP <<
"}\n";
686 if (fAttrLinearBeforeReset == 0) {
687 out << SP << SP <<
"if (seq == 0) {\n";
688 if (!fNInitial_h.empty()) {
690 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
691 out << SP << SP << SP << SP <<
OpName <<
"_feedback[i] = " <<
OpName
692 <<
"_reset_gate[i + offset] * " <<
OpName <<
"_initial_hidden_state[i];\n";
693 out << SP << SP << SP <<
"}\n";
695 out << SP << SP <<
"} else {\n";
698 if (fAttrDirection ==
"backward") {
699 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
702 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
706 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * " <<
num_directions
709 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
710 out << SP << SP << SP << SP <<
OpName <<
"_feedback[i] = " <<
OpName
711 <<
"_reset_gate[i + offset] * " <<
OpName <<
"_hidden_state[i + previous_offset];\n";
712 out << SP << SP << SP <<
"}\n";
713 out << SP << SP <<
"}\n";
716 2 * fAttrHiddenSize * fAttrHiddenSize : 3 * fAttrHiddenSize * fAttrHiddenSize
717 + 2 * fAttrHiddenSize * fAttrHiddenSize;
718 out << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
726 ? 2 * fAttrHiddenSize * fAttrHiddenSize
727 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
728 out << SP << SP <<
"if (seq == 0) {\n";
729 if (!fNInitial_h.empty()) {
731 out << SP << SP << SP
732 <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
733 <<
OpName <<
"_m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
737 out << SP << SP <<
"} else {\n";
740 if (fAttrDirection ==
"backward") {
741 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
744 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
748 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * " <<
num_directions
751 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
756 out << SP << SP <<
"}\n";
761 out << SP << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_feedback_size, &" <<
OpName
762 <<
"_alpha, tensor_" << fNB <<
" + " <<
rbh_offset <<
", &" <<
OpName <<
"_incx, "
766 out << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
767 out << SP << SP << SP <<
OpName <<
"_feedback[i] *= " <<
OpName <<
"_reset_gate[i + offset];\n";
768 out << SP << SP <<
"}\n";
772 out << SP << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_feedback_size, &" <<
OpName <<
"_alpha, "
773 <<
OpName <<
"_feedback, &" <<
OpName <<
"_incx, " <<
OpName <<
"_hidden_gate + offset, &"
777 if (fAttrClip > .0) {
778 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
779 if (fType ==
"float") {
780 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_hidden_gate[i] > " << -fAttrClip
781 <<
") ? " <<
OpName <<
"_hidden_gate[i] : " << -fAttrClip <<
";\n";
783 out << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (x < " << fAttrClip <<
") ? x : "
784 << fAttrClip <<
";\n";
785 out << SP << SP <<
"}\n";
789 if (fAttrActivations[
direction * 2 + 1] ==
"Relu") {
790 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
791 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
792 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 0.;\n";
793 out << SP << SP <<
"}\n";
794 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Tanh") {
795 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
796 if (fType ==
"float") {
797 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_hidden_gate[i]);\n";
799 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
800 out << SP << SP <<
"}\n";
801 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Sigmoid") {
802 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
803 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 1. / (1. + exp(-" <<
OpName
804 <<
"_hidden_gate[i]));\n";
805 out << SP << SP <<
"}\n";
806 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Affine") {
807 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
808 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
809 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * " <<
OpName <<
"_hidden_gate[i] + "
810 << fAttrActivationBeta[
direction * 2 + 1] <<
";\n";
811 out << SP << SP <<
"}\n";
812 }
else if (fAttrActivations[
direction * 2 + 1] ==
"ScaledTanh") {
813 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
814 if (fType ==
"float") {
815 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 2 + 1]
816 <<
" * "<<
OpName <<
"_hidden_gate[i]);\n";
818 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
819 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * (1. - ex) / (1. + ex);\n";
820 out << SP << SP <<
"}\n";
821 }
else if (fAttrActivations[
direction * 2 + 1] ==
"HardSigmoid") {
822 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
823 if (fType ==
"float") {
824 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 2 + 1] <<
" * "
825 <<
OpName <<
"_hidden_gate[i] + " << fAttrActivationBeta[
direction * 2 + 1] <<
";\n";
826 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
828 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
829 out << SP << SP <<
"}\n";
830 }
else if (fAttrActivations[
direction * 2 + 1] ==
"LeakyRelu") {
831 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
832 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
833 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
834 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * " <<
OpName <<
"_hidden_gate[i];\n";
835 out << SP << SP <<
"}\n";
836 }
else if (fAttrActivations[
direction * 2 + 1] ==
"ThresholdRelu") {
837 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
838 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < "
839 << fAttrActivationAlpha[
direction * 2 + 1] <<
")\n";
840 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 0.;\n";
841 out << SP << SP <<
"}";
842 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Elu") {
843 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
844 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
845 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = "
846 << fAttrActivationAlpha[
direction * 2 + 1] <<
" * exp(" <<
OpName <<
"_hidden_gate[i] - 1.);\n";
847 out << SP << SP <<
"}\n";
848 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Softsign") {
849 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
850 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = " <<
OpName
851 <<
"_hidden_gate[i] / (1. + abs(" <<
OpName <<
"_hidden_gate[i]));\n";
852 out << SP << SP <<
"}\n";
854 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
855 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = log(1. + exp("
856 <<
OpName <<
"_hidden_gate[i]));\n";
857 out << SP << SP <<
"}\n";
861 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
862 out << SP << SP << SP <<
OpName <<
"_hidden_state[i] = ( 1. - " <<
OpName
863 <<
"_update_gate[i]) * " <<
OpName <<
"_hidden_gate[i];\n";
864 out << SP << SP <<
"}\n";
866 out << SP << SP <<
"if (seq == 0) {\n";
867 if (!fNInitial_h.empty()) {
869 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
870 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i + offset] += " <<
OpName
871 <<
"_update_gate[i + offset] * " <<
OpName <<
"_initial_hidden_state[i];\n";
872 out << SP << SP << SP <<
"}\n";
874 out << SP << SP <<
"} else {\n";
877 if (fAttrDirection ==
"backward") {
878 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
881 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
885 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
888 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
889 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i + offset] += " <<
OpName
890 <<
"_update_gate[i + offset] * " <<
OpName <<
"_hidden_state[i + previous_offset];\n";
891 out << SP << SP << SP <<
"}\n";
892 out << SP << SP <<
"}\n";
898 if (!fNSequence_lens.empty()) {
899 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
900 out << SP << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
901 out << SP << SP << SP <<
"if (seq >= tensor_" << fNSequence_lens <<
"[batch]) {\n";
903 out << SP << SP << SP << SP << SP <<
"for (size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
904 out << SP << SP << SP << SP << SP << SP <<
OpName <<
"_hidden_state[seq * "
906 <<
" + batch * " << fAttrHiddenSize <<
" + h] = 0.;\n";
907 out << SP << SP << SP << SP << SP <<
"}\n";
909 out << SP << SP << SP <<
"}\n";
910 out << SP << SP <<
"}\n";
915 if (fAttrLayout == 0) {
916 if (!fNY_h.empty()) {
918 if (fNSequence_lens.empty()) {
920 if (fAttrDirection ==
"backward") {
921 out << SP <<
"std::copy(" <<
OpName <<
"_hidden_state, " <<
OpName <<
"_hidden_state + "
922 <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
926 <<
"_hidden_state + " <<
offset <<
" + " <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
930 <<
"_hidden_state + " << 2 *
yh_size <<
", tensor_" << fNY_h <<
" + " <<
yh_size <<
");\n";
933 if (fAttrDirection ==
"backward") {
934 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
935 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
936 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
937 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + offset);\n";
940 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
941 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
943 <<
" + batch * " << fAttrHiddenSize <<
";\n";
944 out << SP << SP <<
"size_t yh_offset = batch * " << fAttrHiddenSize <<
";\n";
945 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
946 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
950 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
951 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize
952 <<
" + batch * " << fAttrHiddenSize <<
";\n";
953 out << SP << SP <<
"size_t yh_offset = " <<
batch_size * fAttrHiddenSize
954 <<
" + batch * " << fAttrHiddenSize <<
";\n";
955 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
956 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
965 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
966 out << SP << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
968 <<
" + " <<
direction *
batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize <<
";\n";
971 out << SP << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
972 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY <<
" + y_offset);\n";
973 out << SP << SP <<
"}\n";
977 if (!fNY_h.empty()) {
979 if (fAttrDirection ==
"backward") {
980 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
981 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
982 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
983 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
984 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
987 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
988 if (fNSequence_lens.empty()) {
989 out << SP << SP <<
"size_t seq = " <<
seq_length - 1 <<
";\n";
991 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
994 <<
" + batch * " << fAttrHiddenSize <<
";\n";
995 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
996 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
997 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
1001 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1002 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize <<
" + batch * "
1003 << fAttrHiddenSize <<
";\n";
1004 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
" + "
1005 << fAttrHiddenSize <<
";\n";
1006 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1007 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";