1#ifndef TMVA_SOFIE_ROPERATOR_GRU
2#define TMVA_SOFIE_ROPERATOR_GRU
113 if (std::is_same<T, float>::value) {
116 throw std::runtime_error(
117 "TMVA SOFIE Encountered unsupported type parsing a GRU operator");
125 std::vector<ETensorType>
TypeInference(std::vector<ETensorType> )
override;
131 std::vector<std::vector<size_t>>
ShapeInference(std::vector<std::vector<size_t>> )
override;
143 std::string
Generate(std::string )
override;
147 std::vector<std::string>
GetBlasRoutines()
override {
return { std::string(
"Gemm"), std::string(
"Axpy") }; }
162 if (fAttrLayout == 0) {
165 std::vector<std::vector<size_t>>
ret(
171 std::vector<std::vector<size_t>>
ret(
184 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNX +
" is not found in model.");
187 if (fShapeX.size() != 3) {
188 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNX +
" is not of 3 dimensions.");
191 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNW +
" is not found in model.");
194 if (fShapeW.size() != 3) {
195 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNW +
" is not of 3 dimensions.");
198 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNR +
" is not found in model.");
201 if (fShapeR.size() != 3) {
202 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNR +
" is not of 3 dimensions.");
206 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " + fNB +
" is not found in model.");
209 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
210 throw std::runtime_error(
"TMVA SOFIE GRU op input tensor " + fNB +
" is not of 2 or 4 dimensions.");
212 if (fShapeB.size() == 2) {
216 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
217 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
218 if (fType ==
"float") {
222 for (
size_t i = 0; i < 6; i++) {
243 if (!fNSequence_lens.empty()) {
245 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNSequence_lens +
"is not found in model.");
248 if (fShapeSequence_lens.size() != 1) {
249 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNSequence_lens +
" is not of 1 dimension.");
252 if (!fNInitial_h.empty()) {
254 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNInitial_h +
" is not found in model.");
257 if (fShapeInitial_h.size() != 3) {
258 throw std::runtime_error(
"TMVA SOFIE GRU Op input tensor " + fNInitial_h +
" is not of 3 dimensions.");
262 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
267 if (!fNY_h.empty()) {
268 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
278 throw std::runtime_error(
"TMVA SOFIE - Activation function " +
activation +
" not implemented");
281 if (fAttrDirection ==
"reverse")
282 fAttrDirection =
"backward";
283 if (fAttrDirection !=
"forward" && fAttrDirection !=
"backward" && fAttrDirection !=
"reverse" &&
284 fAttrDirection !=
"bidirectional") {
285 throw std::runtime_error(
"TMVA SOFIE - Invalid GRU direction fAttrDirection = " + fAttrDirection);
287 if (3 * fAttrHiddenSize != fShapeW[1]) {
288 throw std::runtime_error(
"TMVA SOFIE - fAttrHiddenSize must be equal to " + std::to_string(fShapeW[1] / 3));
290 if (fAttrLayout > 1) {
291 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " + std::to_string(fAttrLayout) +
292 " must be 0 (timewise) or 1 (batchwise)");
294 if (fAttrLinearBeforeReset > 1) {
295 throw std::runtime_error(
"TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrLinearBeforeReset) +
298 if (fAttrActivations.empty()) {
299 if (fAttrDirection ==
"bidirectional") {
300 fAttrActivations = {
"Sigmoid",
"Tanh",
"Sigmoid",
"Tanh"};
302 fAttrActivations = {
"Sigmoid",
"Tanh"};
309 std::string
opName =
"op_gru_" + fNX;
312 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
313 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
321 if (fAttrLayout != 0) {
341 if (fAttrLayout != 0 || fNY.empty()) {
350 std::stringstream out;
352 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
353 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
357 auto getVec = [&](std::string
const &
name) {
return "tensor_op_gru_" + fNX +
"_" +
name; };
360 if (fAttrLayout == 0) {
361 out << SP << fType <<
" const* " <<
OpName <<
"_input = tensor_" << fNX <<
";\n";
364 out << SP << fType <<
" * " <<
OpName <<
"_input = " <<
getVec(
"input") <<
";\n";
368 out << SP <<
"for(size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
369 out << SP << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
370 out << SP << SP << SP <<
"for(size_t i = 0; i < " <<
input_size <<
"; i++) {\n";
374 out << SP << SP << SP <<
"}\n";
375 out << SP << SP <<
"}\n";
380 if (!fNInitial_h.empty()) {
381 if (fAttrLayout == 0) {
382 out << SP << fType <<
" *" <<
OpName <<
"_initial_hidden_state = " <<
" tensor_" << fNInitial_h <<
";\n";
385 out << SP << fType <<
" * " <<
OpName <<
"_initial_hidden_state = " <<
getVec(
"initial_hidden_state")
388 out << SP << fType <<
" " <<
OpName <<
"_initial_hidden_state["
392 out << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
393 out << SP << SP <<
"for(size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
395 <<
" + batch * " << fAttrHiddenSize <<
" + h] = tensor_" << fNInitial_h <<
"[batch * "
397 out << SP << SP <<
"}\n";
406 out << SP << fType <<
" * " <<
OpName <<
"_f_update_gate = " <<
getVec(
"f_update_gate") <<
";\n";
407 out << SP << fType <<
" * " <<
OpName <<
"_f_reset_gate = " <<
getVec(
"f_reset_gate") <<
";\n";
408 out << SP << fType <<
" * " <<
OpName <<
"_f_hidden_gate = " <<
getVec(
"f_hidden_gate") <<
";\n";
417 out << SP << fType <<
" * " <<
OpName <<
"_update_gate = " <<
getVec(
"update_gate") <<
";\n";
418 out << SP << fType <<
" * " <<
OpName <<
"_reset_gate = " <<
getVec(
"reset_gate") <<
";\n";
419 out << SP << fType <<
" * " <<
OpName <<
"_hidden_gate = " <<
getVec(
"hidden_gate") <<
";\n";
426 if (fAttrLayout == 0 && !fNY.empty()) {
427 out << SP << fType <<
" *" <<
OpName <<
"_hidden_state = tensor_" << fNY <<
";\n";
430 out << SP << fType <<
" * " <<
OpName <<
"_hidden_state = " <<
getVec(
"hidden_state") <<
";\n";
437 out << SP << fType <<
" * " <<
OpName <<
"_feedback = " <<
getVec(
"feedback") <<
";\n";
439 out << SP << fType <<
" " <<
OpName <<
"_feedback[" <<
batch_size * fAttrHiddenSize <<
"] = {0};\n";
442 out << SP <<
"char " <<
OpName <<
"_transA = 'N';\n";
443 out << SP <<
"char " <<
OpName <<
"_transB = 'T';\n";
446 out << SP <<
"int " <<
OpName <<
"_n = " << fAttrHiddenSize <<
";\n";
448 if (fType ==
"float") {
449 out << SP <<
"float " <<
OpName <<
"_alpha = 1.;\n";
450 out << SP <<
"float " <<
OpName <<
"_beta = 0.;\n";
455 out << SP <<
"int " <<
OpName <<
"_incx = 1;\n";
456 out << SP <<
"int " <<
OpName <<
"_incy = 1;\n";
457 out << SP <<
"int " <<
OpName <<
"_feedback_size = " <<
batch_size * fAttrHiddenSize <<
";\n";
461 if (fType ==
"float") {
463 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
466 <<
"_f_update_gate, &" <<
OpName <<
"_n);\n";
469 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
475 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
481 if (fType ==
"float") {
484 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
490 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
496 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
505 if (fType ==
"float") {
507 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
", &"
511 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
516 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
522 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
527 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
530 if (fAttrLinearBeforeReset == 0) {
533 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB
539 if (fType ==
"float") {
542 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
548 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
553 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
558 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
563 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB <<
" + "
566 if (fAttrLinearBeforeReset == 0) {
569 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_" << fNB
578 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
579 out << SP << SP <<
"size_t offset = seq * " <<
batch_size * fAttrHiddenSize <<
";\n";
587 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_update_gate + offset, " <<
OpName <<
"_f_update_gate + offset + "
589 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_reset_gate + offset, " <<
OpName <<
"_f_reset_gate + offset + "
591 out << SP << SP <<
"std::copy(" <<
OpName <<
"_f_hidden_gate + offset, " <<
OpName <<
"_f_hidden_gate + offset + "
595 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
596 if (fAttrDirection ==
"backward" ||
direction == 1) {
597 out << SP << SP <<
"size_t index = " <<
seq_length - 1 <<
" - seq;\n";
599 out << SP << SP <<
"size_t index = seq;\n";
601 out << SP << SP <<
"int m2 = " <<
batch_size <<
";\n";
610 out << SP << SP <<
"if (seq == 0) {\n";
611 if (!fNInitial_h.empty()) {
613 if (fType ==
"float") {
614 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
615 <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &" <<
OpName
616 <<
"_n, " <<
OpName <<
"_initial_hidden_state, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, "
617 <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
618 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
619 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
622 <<
"_alpha, " <<
OpName <<
"_reset_gate + offset, &" <<
OpName <<
"_n);\n";
625 if (fType ==
"float") {
626 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
627 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
630 <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
631 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
632 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
635 <<
"_alpha, " <<
OpName <<
"_reset_gate + offset, &" <<
OpName <<
"_n);\n";
639 out << SP << SP <<
"} else {\n";
642 if (fAttrDirection ==
"backward") {
643 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
646 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
649 if (fType ==
"float") {
650 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
651 <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &" <<
OpName <<
"_n, "
653 <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
654 size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
655 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
657 <<
", &" <<
OpName <<
"_n, " <<
OpName <<
"_hidden_state + previous_offset, &" <<
OpName <<
"_n, &"
661 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
663 if (fType ==
"float") {
664 size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
665 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
667 <<
", &" <<
OpName <<
"_n, " <<
OpName <<
"_hidden_state + previous_offset, &" <<
OpName <<
"_n, &"
668 <<
OpName <<
"_alpha, " <<
OpName <<
"_update_gate + offset, &" <<
OpName <<
"_n);\n";
669 size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
670 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
672 <<
", &" <<
OpName <<
"_n, " <<
OpName <<
"_hidden_state + previous_offset, &" <<
OpName <<
"_n, &"
676 out << SP << SP <<
"}\n";
679 if (fAttrClip > .0) {
680 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
681 if (fType ==
"float") {
682 out << SP << SP << SP <<
"float z = (" <<
OpName <<
"_update_gate[i] > " << -fAttrClip <<
") ? " <<
OpName
683 <<
"_update_gate[i] : " << -fAttrClip <<
";\n";
685 out << SP << SP << SP <<
OpName <<
"_update_gate[i] = (z < " << fAttrClip <<
") ? z : " << fAttrClip <<
";\n";
686 if (fType ==
"float") {
687 out << SP << SP << SP <<
"float r = (" <<
OpName <<
"_reset_gate[i] > " << -fAttrClip <<
") ? " <<
OpName
688 <<
"_reset_gate[i] : " << -fAttrClip <<
";\n";
690 out << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (r < " << fAttrClip <<
") ? r : " << fAttrClip <<
";\n";
691 out << SP << SP <<
"}\n";
695 if (fAttrActivations[
direction * 2] ==
"Relu") {
696 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
697 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
698 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 0.;\n";
699 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
700 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 0.;\n";
701 out << SP << SP <<
"}\n";
702 }
else if (fAttrActivations[
direction * 2] ==
"Tanh") {
703 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
704 if (fType ==
"float") {
705 out << SP << SP << SP <<
"float z = exp(-2 * " <<
OpName <<
"_update_gate[i]);\n";
707 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = (1. - z) / (1. + z);\n";
708 if (fType ==
"float") {
709 out << SP << SP << SP <<
"float r = exp(-2 * " <<
OpName <<
"_reset_gate[i]);\n";
711 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (1. - r) / (1. + r);\n";
712 out << SP << SP <<
"}\n";
713 }
else if (fAttrActivations[
direction * 2] ==
"Sigmoid") {
714 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
715 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 1. / (1. + exp(-" <<
OpName
716 <<
"_update_gate[i]));\n";
717 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 1. / (1. + exp(-" <<
OpName
718 <<
"_reset_gate[i]));\n";
719 out << SP << SP <<
"}\n";
720 }
else if (fAttrActivations[
direction * 2] ==
"Affine") {
721 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
722 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = " << fAttrActivationAlpha[
direction * 2] <<
" * "
723 <<
OpName <<
"_update_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
724 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = " << fAttrActivationAlpha[
direction * 2] <<
" * "
725 <<
OpName <<
"_reset_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
726 out << SP << SP <<
"}\n";
727 }
else if (fAttrActivations[
direction * 2] ==
"ScaledTanh") {
728 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
729 if (fType ==
"float") {
730 out << SP << SP << SP <<
"float z = exp(-2 * " << fAttrActivationBeta[
direction * 2] <<
" * " <<
OpName
731 <<
"_update_gate[i]);\n";
733 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = " << fAttrActivationAlpha[
direction * 2]
734 <<
" * (1. - z) / (1. + z);\n";
735 if (fType ==
"float") {
736 out << SP << SP << SP <<
"float r = exp(-2 * " << fAttrActivationBeta[
direction * 2] <<
" * " <<
OpName
737 <<
"_reset_gate[i]);\n";
739 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = " << fAttrActivationAlpha[
direction * 2]
740 <<
" * (1. - r) / (1. + r);\n";
741 out << SP << SP <<
"}\n";
742 }
else if (fAttrActivations[
direction * 2] ==
"HardSigmoid") {
743 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
744 if (fType ==
"float") {
745 out << SP << SP << SP <<
"float za = " << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName
746 <<
"_update_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
747 out << SP << SP << SP <<
"float zb = (za > 0.) ? za : 0.;\n";
749 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
750 if (fType ==
"float") {
751 out << SP << SP << SP <<
"float ra = " << fAttrActivationAlpha[
direction * 2] <<
" * " <<
OpName
752 <<
"_reset_gate[i] + " << fAttrActivationBeta[
direction * 2] <<
";\n";
753 out << SP << SP << SP <<
"float rb = (ra > 0.) ? ra : 0.;\n";
755 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
756 out << SP << SP <<
"}\n";
757 }
else if (fAttrActivations[
direction * 2] ==
"LeakyRelu") {
758 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
759 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
760 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = " << fAttrActivationAlpha[
direction * 2] <<
" * "
761 <<
OpName <<
"_update_gate[i];\n";
762 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
763 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = " << fAttrActivationAlpha[
direction * 2] <<
" * "
764 <<
OpName <<
"_reset_gate[i];\n";
765 out << SP << SP <<
"}\n";
766 }
else if (fAttrActivations[
direction * 2] ==
"ThresholdRelu") {
767 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
768 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < " << fAttrActivationAlpha[
direction * 2]
770 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = 0.;\n";
771 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < " << fAttrActivationAlpha[
direction * 2]
773 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = 0.;\n";
774 out << SP << SP <<
"}";
775 }
else if (fAttrActivations[
direction * 2] ==
"Elu") {
776 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
777 out << SP << SP << SP <<
"if (" <<
OpName <<
"_update_gate[i] < 0.)\n";
778 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = " << fAttrActivationAlpha[
direction * 2]
779 <<
" * exp(" <<
OpName <<
"_update_gate[i] - 1.);\n";
780 out << SP << SP << SP <<
"if (" <<
OpName <<
"_reset_gate[i] < 0.)\n";
781 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = " << fAttrActivationAlpha[
direction * 2]
782 <<
" * exp(" <<
OpName <<
"_reset_gate[i] - 1.);\n";
783 out << SP << SP <<
"}\n";
784 }
else if (fAttrActivations[
direction * 2] ==
"Softsign") {
785 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
786 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = " <<
OpName <<
"_update_gate[i] / (1. + abs("
787 <<
OpName <<
"_update_gate[i]));\n";
788 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = " <<
OpName <<
"_reset_gate[i] / (1. + abs("
789 <<
OpName <<
"_reset_gate[i]));\n";
790 out << SP << SP <<
"}\n";
792 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
793 out << SP << SP << SP << SP <<
OpName <<
"_update_gate[i] = log(1. + exp(" <<
OpName <<
"_update_gate[i]));\n";
794 out << SP << SP << SP << SP <<
OpName <<
"_reset_gate[i] = log(1. + exp(" <<
OpName <<
"_reset_gate[i]));\n";
795 out << SP << SP <<
"}\n";
798 if (fAttrLinearBeforeReset == 0) {
799 out << SP << SP <<
"if (seq == 0) {\n";
800 if (!fNInitial_h.empty()) {
802 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
803 out << SP << SP << SP << SP <<
OpName <<
"_feedback[i] = " <<
OpName <<
"_reset_gate[i + offset] * "
804 <<
OpName <<
"_initial_hidden_state[i];\n";
805 out << SP << SP << SP <<
"}\n";
807 out << SP << SP <<
"} else {\n";
810 if (fAttrDirection ==
"backward") {
811 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
814 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
818 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
821 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
822 out << SP << SP << SP << SP <<
OpName <<
"_feedback[i] = " <<
OpName <<
"_reset_gate[i + offset] * " <<
OpName
823 <<
"_hidden_state[i + previous_offset];\n";
824 out << SP << SP << SP <<
"}\n";
825 out << SP << SP <<
"}\n";
828 ? 2 * fAttrHiddenSize * fAttrHiddenSize
829 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
830 out << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName <<
"_n, &"
838 ? 2 * fAttrHiddenSize * fAttrHiddenSize
839 : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
840 out << SP << SP <<
"if (seq == 0) {\n";
841 if (!fNInitial_h.empty()) {
843 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
844 <<
"_n, &" <<
OpName <<
"_m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
848 out << SP << SP <<
"} else {\n";
851 if (fAttrDirection ==
"backward") {
852 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
855 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
859 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
862 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &" <<
OpName
863 <<
"_n, &" <<
OpName <<
"_m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
867 out << SP << SP <<
"}\n";
872 out << SP << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_feedback_size, &" <<
OpName <<
"_alpha, tensor_" << fNB
877 out << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
878 out << SP << SP << SP <<
OpName <<
"_feedback[i] *= " <<
OpName <<
"_reset_gate[i + offset];\n";
879 out << SP << SP <<
"}\n";
883 out << SP << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_feedback_size, &" <<
OpName <<
"_alpha, " <<
OpName
884 <<
"_feedback, &" <<
OpName <<
"_incx, " <<
OpName <<
"_hidden_gate + offset, &" <<
OpName <<
"_incy);\n";
887 if (fAttrClip > .0) {
888 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
889 if (fType ==
"float") {
890 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_hidden_gate[i] > " << -fAttrClip <<
") ? " <<
OpName
891 <<
"_hidden_gate[i] : " << -fAttrClip <<
";\n";
893 out << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (x < " << fAttrClip <<
") ? x : " << fAttrClip <<
";\n";
894 out << SP << SP <<
"}\n";
898 if (fAttrActivations[
direction * 2 + 1] ==
"Relu") {
899 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
900 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
901 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 0.;\n";
902 out << SP << SP <<
"}\n";
903 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Tanh") {
904 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
905 if (fType ==
"float") {
906 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_hidden_gate[i]);\n";
908 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
909 out << SP << SP <<
"}\n";
910 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Sigmoid") {
911 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
912 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 1. / (1. + exp(-" <<
OpName
913 <<
"_hidden_gate[i]));\n";
914 out << SP << SP <<
"}\n";
915 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Affine") {
916 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
917 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = " << fAttrActivationAlpha[
direction * 2 + 1]
918 <<
" * " <<
OpName <<
"_hidden_gate[i] + " << fAttrActivationBeta[
direction * 2 + 1] <<
";\n";
919 out << SP << SP <<
"}\n";
920 }
else if (fAttrActivations[
direction * 2 + 1] ==
"ScaledTanh") {
921 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
922 if (fType ==
"float") {
923 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 2 + 1] <<
" * " <<
OpName
924 <<
"_hidden_gate[i]);\n";
926 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = " << fAttrActivationAlpha[
direction * 2 + 1]
927 <<
" * (1. - ex) / (1. + ex);\n";
928 out << SP << SP <<
"}\n";
929 }
else if (fAttrActivations[
direction * 2 + 1] ==
"HardSigmoid") {
930 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
931 if (fType ==
"float") {
932 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 2 + 1] <<
" * " <<
OpName
933 <<
"_hidden_gate[i] + " << fAttrActivationBeta[
direction * 2 + 1] <<
";\n";
934 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
936 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
937 out << SP << SP <<
"}\n";
938 }
else if (fAttrActivations[
direction * 2 + 1] ==
"LeakyRelu") {
939 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
940 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
941 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = " << fAttrActivationAlpha[
direction * 2 + 1]
942 <<
" * " <<
OpName <<
"_hidden_gate[i];\n";
943 out << SP << SP <<
"}\n";
944 }
else if (fAttrActivations[
direction * 2 + 1] ==
"ThresholdRelu") {
945 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
946 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < " << fAttrActivationAlpha[
direction * 2 + 1]
948 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = 0.;\n";
949 out << SP << SP <<
"}";
950 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Elu") {
951 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
952 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_gate[i] < 0.)\n";
953 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = " << fAttrActivationAlpha[
direction * 2 + 1]
954 <<
" * exp(" <<
OpName <<
"_hidden_gate[i] - 1.);\n";
955 out << SP << SP <<
"}\n";
956 }
else if (fAttrActivations[
direction * 2 + 1] ==
"Softsign") {
957 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
958 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = " <<
OpName <<
"_hidden_gate[i] / (1. + abs("
959 <<
OpName <<
"_hidden_gate[i]));\n";
960 out << SP << SP <<
"}\n";
962 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
963 out << SP << SP << SP << SP <<
OpName <<
"_hidden_gate[i] = log(1. + exp(" <<
OpName <<
"_hidden_gate[i]));\n";
964 out << SP << SP <<
"}\n";
968 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
969 out << SP << SP << SP <<
OpName <<
"_hidden_state[i] = ( 1. - " <<
OpName <<
"_update_gate[i]) * " <<
OpName
970 <<
"_hidden_gate[i];\n";
971 out << SP << SP <<
"}\n";
973 out << SP << SP <<
"if (seq == 0) {\n";
974 if (!fNInitial_h.empty()) {
976 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
977 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i + offset] += " <<
OpName
978 <<
"_update_gate[i + offset] * " <<
OpName <<
"_initial_hidden_state[i];\n";
979 out << SP << SP << SP <<
"}\n";
981 out << SP << SP <<
"} else {\n";
984 if (fAttrDirection ==
"backward") {
985 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
988 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
992 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
995 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
996 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i + offset] += " <<
OpName
997 <<
"_update_gate[i + offset] * " <<
OpName <<
"_hidden_state[i + previous_offset];\n";
998 out << SP << SP << SP <<
"}\n";
999 out << SP << SP <<
"}\n";
1005 if (!fNSequence_lens.empty()) {
1006 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
1007 out << SP << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1008 out << SP << SP << SP <<
"if (seq >= tensor_" << fNSequence_lens <<
"[batch]) {\n";
1010 out << SP << SP << SP << SP << SP <<
"for (size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
1011 out << SP << SP << SP << SP << SP << SP <<
OpName <<
"_hidden_state[seq * "
1013 <<
" + batch * " << fAttrHiddenSize <<
" + h] = 0.;\n";
1014 out << SP << SP << SP << SP << SP <<
"}\n";
1016 out << SP << SP << SP <<
"}\n";
1017 out << SP << SP <<
"}\n";
1022 if (fAttrLayout == 0) {
1023 if (!fNY_h.empty()) {
1025 if (fNSequence_lens.empty()) {
1027 if (fAttrDirection ==
"backward") {
1028 out << SP <<
"std::copy(" <<
OpName <<
"_hidden_state, " <<
OpName <<
"_hidden_state + " <<
yh_size
1029 <<
", tensor_" << fNY_h <<
");\n";
1032 out << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + " <<
offset <<
", " <<
OpName
1033 <<
"_hidden_state + " <<
offset <<
" + " <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
1037 <<
"_hidden_state + " << 2 *
yh_size <<
", tensor_" << fNY_h <<
" + " <<
yh_size <<
");\n";
1040 if (fAttrDirection ==
"backward") {
1041 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1042 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
1043 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1044 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + offset);\n";
1047 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1048 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
1050 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1051 out << SP << SP <<
"size_t yh_offset = batch * " << fAttrHiddenSize <<
";\n";
1052 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1053 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
1057 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1058 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize
1060 out << SP << SP <<
"size_t yh_offset = " <<
batch_size * fAttrHiddenSize <<
" + batch * "
1061 << fAttrHiddenSize <<
";\n";
1062 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1063 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
1072 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
1073 out << SP << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1078 out << SP << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1079 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY <<
" + y_offset);\n";
1080 out << SP << SP <<
"}\n";
1084 if (!fNY_h.empty()) {
1086 if (fAttrDirection ==
"backward") {
1087 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1088 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
1089 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
1090 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1091 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
1094 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1095 if (fNSequence_lens.empty()) {
1096 out << SP << SP <<
"size_t seq = " <<
seq_length - 1 <<
";\n";
1098 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
1101 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1102 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
1103 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1104 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
1108 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1109 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize
1111 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
" + "
1112 << fAttrHiddenSize <<
";\n";
1113 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1114 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
std::vector< size_t > GetTensorShape(const std::string &name) const
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
bool CheckIfTensorAlreadyExist(std::string tensor_name)
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
ETensorType GetTensorType(std::string name) const
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Gated Recurrent Unit operator.
std::vector< size_t > fShapeY
Shape of the output.
std::string fNX
Name of the input.
std::string fType
Type of the tensors.
size_t fAttrLayout
Data layout.
std::string fAttrDirection
Direction of processing.
std::string fNR
Name of the recurrence.
float fAttrClip
Clip threshold.
std::vector< float > fAttrActivationBeta
Scaling values used by some activation functions.
std::string fNY
Name of the output.
std::string fNY_h
Name of the last sequence of the output.
std::string fNSequence_lens
Name of the length of the sequences.
std::string fNB
Name of the bias.
std::vector< std::string > fAttrActivations
Activation functions.
void Initialize(RModel &) override
Initialize the model.
ROperator_GRU(std::vector< float > activation_alpha, std::vector< float > activation_beta, std::vector< std::string > activations, float clip, std::string direction, size_t hidden_size, size_t layout, size_t linear_before_reset, std::string nameX, std::string nameW, std::string nameR, std::string nameB, std::string nameSequence_lens, std::string nameInitial_h, std::string nameY, std::string nameY_h)
Constructor of ROperator_GRU from the attributes.
size_t fAttrHiddenSize
Number of the hidden layers.
std::string Generate(std::string) override
Generate the inference code.
std::vector< float > fAttrActivationAlpha
Scaling values used by some activation functions.
std::vector< size_t > fShapeR
Shape of the recurrence.
std::string fNW
Name of the weights.
std::vector< size_t > fShapeX
Shape of the input.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >) override
Infers the shape of the output tensors.
std::vector< size_t > fShapeInitial_h
Shape of the initial value of hidden states.
std::vector< size_t > fShapeSequence_lens
Shape of the length of the sequences.
std::vector< ETensorType > TypeInference(std::vector< ETensorType >) override
Infers the type of the output tensors.
std::vector< size_t > fShapeY_h
Shape of the last sequence of the output.
size_t fAttrLinearBeforeReset
Linear layer before the reset gate.
std::vector< size_t > fShapeB
Shape of the bias.
std::string fNInitial_h
Name of the initial value of the hidden states.
std::vector< size_t > fShapeW
Shape of the weights.
ROperator_GRU()
Default constructor of ROperator_GRU.
std::vector< std::string > GetBlasRoutines() override
Returns the blas routines needed to compile the generated code.
std::vector< std::string_view > fInputTensorNames
std::vector< std::string_view > fOutputTensorNames
ETensorType ConvertStringToType(std::string type)