17-> std::vector<std::vector<size_t>> {
18 size_t num_directions =
input[1][0];
19 size_t hidden_size =
input[1][1] / 4;
21 size_t seq_length =
input[0][0];
22 size_t batch_size =
input[0][1];
23 std::vector<std::vector<size_t>> ret(
24 {{seq_length, num_directions, batch_size, hidden_size},
25 {num_directions, batch_size, hidden_size},
26 {num_directions, batch_size, hidden_size}});
29 size_t batch_size =
input[0][0];
30 size_t seq_length =
input[0][1];
31 std::vector<std::vector<size_t>> ret(
32 {{batch_size, seq_length, num_directions, hidden_size},
33 {batch_size, num_directions, hidden_size},
34 {batch_size, num_directions, hidden_size}});
44 if (!model.CheckIfTensorAlreadyExist(
fNX)) {
45 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNX +
" is not found in model.");
49 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNX +
" is not of 3 dimensions.");
51 if (!model.CheckIfTensorAlreadyExist(
fNW)) {
52 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNW +
" is not found in model.");
56 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNW +
" is not of 3 dimensions.");
58 if (!model.CheckIfTensorAlreadyExist(
fNR)) {
59 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNR +
" is not found in model.");
63 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
fNR +
" is not of 3 dimensions.");
66 if (!model.CheckIfTensorAlreadyExist(
fNB)) {
67 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " +
fNB +
" is not found in model.");
71 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " +
fNB +
" is not of 2 or 5 dimensions.");
75 auto original_data = model.GetInitializedTensorData(
fNB);
76 size_t num_directions =
fShapeW[0];
79 if (
fType ==
"float") {
80 float *original_bias =
static_cast<float*
>(original_data.get());
81 float *new_bias =
new float[4 * num_directions * seq_length * batch_size *
fAttrHiddenSize];
82 for (
size_t gate = 0; gate < 4; gate++) {
84 for (
size_t direction = 0; direction < num_directions; direction++) {
89 for (
size_t seq = 0; seq < seq_length; seq++) {
90 for (
size_t batch = 0; batch < batch_size; batch++) {
91 size_t bias_offset = gate * num_directions * seq_length * batch_size *
fAttrHiddenSize
99 std::vector<size_t> new_bias_shape = {4, num_directions, seq_length, batch_size,
fAttrHiddenSize};
100 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<
float[]>());
101 model.UpdateInitializedTensor(
fNB, model.GetTensorType(
fNB), new_bias_shape, new_bias_ptr);
108 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
110 "is not found in model.");
114 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
116 " is not of 1 dimension.");
120 if (!model.CheckIfTensorAlreadyExist(
fNInitial_h)) {
121 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
126 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
131 if (!model.CheckIfTensorAlreadyExist(
fNInitial_c)) {
132 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
137 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
142 if (!model.CheckIfTensorAlreadyExist(
fNP)) {
143 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " +
fNP +
" is not found in model.");
147 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " +
fNP +
" is not of 2 or 4 dimensions.");
151 auto original_data = model.GetInitializedTensorData(
fNP);
152 size_t num_directions =
fShapeW[0];
154 if (
fType ==
"float") {
155 float *original_p =
static_cast<float*
>(original_data.get());
156 float *new_p =
new float[num_directions * 3 * batch_size *
fAttrHiddenSize];
157 for (
size_t direction = 0; direction < num_directions; direction++) {
158 for (
size_t gate = 0; gate < 3; gate++) {
160 for (
size_t batch = 0; batch < batch_size; batch++) {
163 std::copy(original_p + p_offset, original_p + p_offset +
fAttrHiddenSize,
168 std::vector<size_t> new_p_shape = {num_directions, 3, batch_size,
fAttrHiddenSize};
169 std::shared_ptr<void> new_p_ptr(new_p, std::default_delete<
float[]>());
170 model.UpdateInitializedTensor(
fNP, model.GetTensorType(
fNP), new_p_shape, new_p_ptr);
177 if (!model.CheckIfTensorAlreadyExist(
fNY)) {
178 model.AddIntermediateTensor(
fNY, model.GetTensorType(
fNX),
fShapeY);
181 if (!
fNY_h.empty()) {
183 if (!model.CheckIfTensorAlreadyExist(
fNY_h)) {
187 if (!
fNY_c.empty()) {
189 if (!model.CheckIfTensorAlreadyExist(
fNY_c)) {
195 if (activation !=
"Relu" && activation !=
"Tanh" &&
196 activation !=
"Sigmoid" && activation !=
"Affine" &&
197 activation !=
"LeakyRelu" && activation !=
"ThresholdRelu" &&
198 activation !=
"ScaledTanh" && activation !=
"HardSigmoid" &&
199 activation !=
"Elu" && activation !=
"Softsign" &&
200 activation !=
"Softplus") {
201 throw std::runtime_error(
"TMVA SOFIE - Activation function " +
202 activation +
" not implemented");
207 throw std::runtime_error(
208 "TMVA SOFIE - Invalid LSTM direction fAttrDirection = " +
212 throw std::runtime_error(
213 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
214 std::to_string(
fShapeW[1] / 4));
217 throw std::runtime_error(
219 +
" must be 0 or 1.");
222 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " +
224 " must be 0 (timewise) or 1 (batchwise)");
285 OpName =
"op_" + OpName;
286 std::stringstream out;
290 size_t input_size =
fShapeX[2];
291 size_t num_directions =
fShapeW[0];
295 out <<
SP <<
fType <<
" *" << OpName <<
"_input = tensor_" <<
fNX <<
";\n";
298 out <<
SP <<
fType <<
" * " << OpName <<
"_input = fVec_" << OpName <<
"_input.data();\n";
300 out <<
SP <<
fType <<
" " << OpName <<
"_input[" << seq_length * batch_size * input_size <<
"] = {0};\n";
302 out <<
SP <<
"for(size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
303 out <<
SP <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
304 out <<
SP <<
SP <<
SP <<
"for(size_t i = 0; i < " << input_size <<
"; i++) {\n";
305 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input[seq * " << batch_size * input_size
306 <<
" + batch * " << input_size <<
" + i] = " <<
"tensor_" <<
fNX <<
"[batch * "
307 << seq_length * input_size <<
" + seq * " << input_size <<
" + i];\n";
308 out <<
SP <<
SP <<
SP <<
"}\n";
309 out <<
SP <<
SP <<
"}\n";
316 out <<
SP <<
fType <<
" *" << OpName <<
"_initial_hidden_state = " <<
" tensor_"
320 out <<
SP <<
fType <<
" * " << OpName <<
"_initial_hidden_state = fVec_" << OpName
321 <<
"_initial_hidden_state.data();\n";
323 out <<
SP <<
fType <<
" " << OpName <<
"_initial_hidden_state[" << num_directions * batch_size *
326 for (
size_t direction = 0; direction < num_directions; direction++) {
327 out <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
329 out <<
SP <<
SP <<
SP << OpName <<
"_initial_hidden_state["
333 out <<
SP <<
SP <<
"}\n";
342 out <<
SP <<
fType <<
" *" << OpName <<
"_initial_cell_state = " <<
" tensor_"
346 out <<
SP <<
fType <<
" * " << OpName <<
"_initial_cell_state = fVec_" << OpName
347 <<
"_initial_cell_state.data();\n";
349 out <<
SP <<
fType <<
" " << OpName <<
"_initial_cell_state[" << num_directions * batch_size *
352 for (
size_t direction = 0; direction < num_directions; direction++) {
353 out <<
SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
355 out <<
SP <<
SP <<
SP << OpName <<
"_initial_cell_state["
359 out <<
SP <<
SP <<
"}\n";
368 out <<
SP <<
fType <<
" * " << OpName <<
"_ff_input_gate = fVec_" << OpName <<
"_ff_input_gate.data();\n";
369 out <<
SP <<
fType <<
" * " << OpName <<
"_ff_output_gate = fVec_" << OpName <<
"_ff_output_gate.data();\n";
370 out <<
SP <<
fType <<
" * " << OpName <<
"_ff_cell_gate = fVec_" << OpName <<
"_ff_cell_gate.data();\n";
372 out <<
SP <<
fType <<
" * " << OpName <<
"_ff_forget_gate = fVec_" << OpName <<
"_ff_forget_gate.data();\n";
375 out <<
SP <<
fType <<
" " << OpName <<
"_ff_input_gate[" << ff_size <<
"] = {0};\n";
376 out <<
SP <<
fType <<
" " << OpName <<
"_ff_output_gate[" << ff_size <<
"] = {0};\n";
377 out <<
SP <<
fType <<
" " << OpName <<
"_ff_cell_gate[" << ff_size <<
"] = {0};\n";
379 out <<
SP <<
fType <<
" " << OpName <<
"_ff_forget_gate[" << ff_size <<
"] = {0};\n";
383 size_t hidden_state_size = seq_length * num_directions * batch_size *
fAttrHiddenSize;
385 out <<
SP <<
fType <<
" * " << OpName <<
"_input_gate = fVec_" << OpName <<
"_input_gate.data();\n";
386 out <<
SP <<
fType <<
" * " << OpName <<
"_output_gate = fVec_" << OpName <<
"_output_gate.data();\n";
387 out <<
SP <<
fType <<
" * " << OpName <<
"_cell_gate = fVec_" << OpName <<
"_cell_gate.data();\n";
389 out <<
SP <<
fType <<
" * " << OpName <<
"_forget_gate = fVec_" << OpName <<
"_forget_gate.data();\n";
392 out <<
SP <<
fType <<
" " << OpName <<
"_input_gate[" << hidden_state_size <<
"] = {0};\n";
393 out <<
SP <<
fType <<
" " << OpName <<
"_output_gate[" << hidden_state_size <<
"] = {0};\n";
394 out <<
SP <<
fType <<
" " << OpName <<
"_cell_gate[" << hidden_state_size <<
"] = {0};\n";
396 out <<
SP <<
fType <<
" " << OpName <<
"_forget_gate[" << hidden_state_size <<
"] = {0};\n";
401 out <<
SP <<
fType <<
" * " << OpName <<
"_cell_state = fVec_" << OpName <<
"_cell_state.data();\n";
402 out <<
SP <<
fType <<
" * " << OpName <<
"_new_cell_state = fVec_" << OpName <<
"_new_cell_state.data();\n";
404 out <<
SP <<
fType <<
" " << OpName <<
"_cell_state[" << hidden_state_size <<
"] = {0};\n";
405 out <<
SP <<
fType <<
" " << OpName <<
"_new_cell_state[" << hidden_state_size <<
"] = {0};\n";
410 out <<
SP <<
fType <<
" *" << OpName <<
"_hidden_state = tensor_" <<
fNY <<
";\n";
413 out <<
SP <<
fType <<
" * " << OpName <<
"_hidden_state = fVec_" << OpName <<
"_hidden_state.data();\n";
415 out <<
SP <<
fType <<
" " << OpName <<
"_hidden_state[" << hidden_state_size <<
"] = {0};\n";
419 out <<
SP <<
"char " << OpName <<
"_transA = 'N';\n";
420 out <<
SP <<
"char " << OpName <<
"_transB = 'T';\n";
421 out <<
SP <<
"int " << OpName <<
"_m = " << seq_length * batch_size <<
";\n";
423 out <<
SP <<
"int " << OpName <<
"_k = " << input_size <<
";\n";
424 if (
fType ==
"float") {
425 out <<
SP <<
fType <<
" " << OpName <<
"_alpha = 1.;\n";
426 out <<
SP <<
fType <<
" " << OpName <<
"_beta = 0.;\n";
429 out <<
SP <<
"int " << OpName <<
"_bias_size = " << seq_length * batch_size *
fAttrHiddenSize <<
";\n";
430 out <<
SP <<
"int " << OpName <<
"_incx = 1;\n";
431 out <<
SP <<
"int " << OpName <<
"_incy = 1;\n";
434 for (
size_t direction = 0; direction < num_directions; direction++) {
435 if (direction == 0) {
436 if (
fType ==
"float") {
438 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
439 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
440 <<
fNW <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &" << OpName <<
"_k, &"
441 << OpName <<
"_beta, " << OpName <<
"_ff_input_gate, &" << OpName <<
"_n);\n";
444 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
445 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
446 <<
fNW <<
" + " << wo_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
447 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_ff_output_gate, &" << OpName <<
"_n);\n";
450 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
451 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
452 <<
fNW <<
" + " << wc_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
453 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_ff_cell_gate, &" << OpName <<
"_n);\n";
456 if (
fType ==
"float") {
459 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
460 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
461 <<
fNW <<
" + " << wi_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
462 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_ff_input_gate, &" << OpName <<
"_n);\n";
465 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
466 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
467 <<
fNW <<
" + " << wo_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
468 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_ff_output_gate, &" << OpName <<
"_n);\n";
471 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
472 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
473 <<
fNW <<
" + " << wc_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
474 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_ff_cell_gate, &" << OpName <<
"_n);\n";
479 if (direction == 0) {
480 if (
fType ==
"float") {
482 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
483 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
484 <<
fNW <<
" + " << wf_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
485 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_ff_forget_gate, &" << OpName <<
"_n);\n";
488 if (
fType ==
"float") {
490 out <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
491 << OpName <<
"_n, &" << OpName <<
"_m, &" << OpName <<
"_k, &" << OpName <<
"_alpha, tensor_"
492 <<
fNW <<
" + " << wf_offset <<
", &" << OpName <<
"_k, " << OpName <<
"_input, &"
493 << OpName <<
"_k, &" << OpName <<
"_beta, " << OpName <<
"_ff_forget_gate, &" << OpName <<
"_n);\n";
500 if (direction == 0) {
501 if (
fType ==
"float") {
503 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
504 <<
fNB <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_input_gate, &" << OpName <<
"_incy);\n";
507 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
508 <<
fNB <<
" + " << bo_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_output_gate, &"
509 << OpName <<
"_incy);\n";
512 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
513 <<
fNB <<
" + " << bc_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_cell_gate, &"
514 << OpName <<
"_incy);\n";
517 if (
fType ==
"float") {
520 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
521 <<
fNB <<
" + " << bi_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_input_gate, &"
522 << OpName <<
"_incy);\n";
526 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
527 <<
fNB <<
" + " << bo_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_output_gate, &"
528 << OpName <<
"_incy);\n";
530 size_t bc_offset = 4 * num_directions * seq_length * batch_size *
fAttrHiddenSize
532 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
533 <<
fNB <<
" + " << bc_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_cell_gate, &"
534 << OpName <<
"_incy);\n";
539 if (direction == 0) {
540 if (
fType ==
"float") {
542 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
543 <<
fNB <<
" + " << bo_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_forget_gate, &"
544 << OpName <<
"_incy);\n";
547 if (
fType ==
"float") {
550 out <<
SP <<
"BLAS::saxpy_(&" << OpName <<
"_bias_size, &" << OpName <<
"_alpha, tensor_"
551 <<
fNB <<
" + " << bo_offset <<
", &" << OpName <<
"_incx, " << OpName <<
"_ff_forget_gate, &"
552 << OpName <<
"_incy);\n";
560 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
562 if (direction == 0) {
563 out <<
SP <<
SP <<
"size_t gate_offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
566 out <<
SP <<
SP <<
"size_t gate_offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
570 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_ff_input_gate + ff_offset, " << OpName
571 <<
"_ff_input_gate + ff_offset + " << ff_seq_size <<
", " << OpName <<
"_input_gate + gate_offset);\n";
572 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_ff_output_gate + ff_offset, " << OpName
573 <<
"_ff_output_gate + ff_offset + " << ff_seq_size <<
", " << OpName <<
"_output_gate + gate_offset);\n";
574 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_ff_cell_gate + ff_offset, " << OpName
575 <<
"_ff_cell_gate + ff_offset + " << ff_seq_size <<
", " << OpName <<
"_cell_gate + gate_offset);\n";
577 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_ff_forget_gate + ff_offset, " << OpName
578 <<
"_ff_forget_gate + ff_offset + " << ff_seq_size <<
", " << OpName <<
"_forget_gate + gate_offset);\n";
582 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
584 out <<
SP <<
SP <<
"size_t index = " << seq_length - 1 <<
" - seq;\n";
586 out <<
SP <<
SP <<
"size_t index = seq;\n";
588 out <<
SP <<
SP <<
"int m2 = " << batch_size <<
";\n";
589 if (direction == 0) {
590 out <<
SP <<
SP <<
"size_t offset = index * " << num_directions * batch_size *
fAttrHiddenSize
593 out <<
SP <<
SP <<
"size_t offset = index * " << num_directions * batch_size *
fAttrHiddenSize
598 out <<
SP <<
SP <<
"if (seq == 0) {\n";
600 if (direction == 0) {
601 if (
fType ==
"float") {
602 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
603 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
", &"
604 << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName <<
"_n, &" << OpName
605 <<
"_alpha, " << OpName <<
"_input_gate + offset, &" << OpName <<
"_n);\n";
607 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
608 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
609 << ro_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
610 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_output_gate + offset, &" << OpName <<
"_n);\n";
612 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
613 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
614 << rc_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
615 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_cell_gate + offset, &" << OpName <<
"_n);\n";
618 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
619 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
620 << rf_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
621 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_forget_gate + offset, &" << OpName <<
"_n);\n";
625 if (
fType ==
"float") {
627 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
628 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
629 << ri_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
630 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_input_gate + offset, &" << OpName <<
"_n);\n";
632 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
633 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
634 << ro_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
635 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_output_gate + offset, &" << OpName <<
"_n);\n";
637 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
638 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
639 << rc_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
640 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_cell_gate + offset, &" << OpName <<
"_n);\n";
643 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
644 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
645 << rf_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_initial_hidden_state, &" << OpName
646 <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_forget_gate + offset, &" << OpName <<
"_n);\n";
651 out <<
SP <<
SP <<
"} else {\n";
653 if (direction == 0) {
655 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
658 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
661 if (
fType ==
"float") {
662 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
663 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
", &"
664 << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &" << OpName <<
"_n, &"
665 << OpName <<
"_alpha, " << OpName <<
"_input_gate + offset, &" << OpName <<
"_n);\n";
667 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
668 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
669 << ro_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
670 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_output_gate + offset, &"
671 << OpName <<
"_n);\n";
673 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
674 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
675 << rc_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
676 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_cell_gate + offset, &"
677 << OpName <<
"_n);\n";
680 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
681 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
682 << rf_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
683 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_forget_gate + offset, &"
684 << OpName <<
"_n);\n";
688 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
690 if (
fType ==
"float") {
692 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
693 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
694 << ri_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
695 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_input_gate + offset, &"
696 << OpName <<
"_n);\n";
698 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
699 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
700 << ro_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
701 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_output_gate + offset, &"
702 << OpName <<
"_n);\n";
704 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
705 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
706 << rc_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
707 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_cell_gate + offset, &"
708 << OpName <<
"_n);\n";
711 out <<
SP <<
SP <<
SP <<
"BLAS::sgemm_(&" << OpName <<
"_transB, &" << OpName <<
"_transA, &"
712 << OpName <<
"_n, &m2, &" << OpName <<
"_n, &" << OpName <<
"_alpha, tensor_" <<
fNR <<
" + "
713 << rf_offset <<
", &" << OpName <<
"_n, " << OpName <<
"_hidden_state + previous_offset, &"
714 << OpName <<
"_n, &" << OpName <<
"_alpha, " << OpName <<
"_forget_gate + offset, &"
715 << OpName <<
"_n);\n";
719 out <<
SP <<
SP <<
"}\n";
723 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
724 if (
fType ==
"float") {
725 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_cell_gate[i] > " << -
fAttrClip <<
") ? "
726 << OpName <<
"_cell_gate[i] : " << -
fAttrClip <<
";\n";
728 out <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = (x < " <<
fAttrClip <<
") ? x : "
730 out <<
SP <<
SP <<
"}\n";
734 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
735 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_cell_gate[i] < 0.)\n";
736 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = 0.;\n";
737 out <<
SP <<
SP <<
"}\n";
739 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
740 if (
fType ==
"float") {
741 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_cell_gate[i]);\n";
743 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = (1. - ex) / (1. + ex);\n";
744 out <<
SP <<
SP <<
"}\n";
746 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
747 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = 1. / (1. + exp(-" << OpName
748 <<
"_cell_gate[i]));\n";
749 out <<
SP <<
SP <<
"}\n";
751 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
752 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = "
755 out <<
SP <<
SP <<
"}\n";
757 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
758 if (
fType ==
"float") {
760 <<
" * "<< OpName <<
"_cell_gate[i]);\n";
762 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = "
764 out <<
SP <<
SP <<
"}\n";
766 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
767 if (
fType ==
"float") {
770 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
772 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = (b < 1.) ? b : 1.;\n";
773 out <<
SP <<
SP <<
"}\n";
775 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
776 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_cell_gate[i] < 0.)\n";
777 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = "
779 out <<
SP <<
SP <<
"}\n";
781 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
782 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_cell_gate[i] < "
784 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = 0.;\n";
785 out <<
SP <<
SP <<
"}";
787 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
788 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_cell_gate[i] < 0.)\n";
789 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = "
791 out <<
SP <<
SP <<
"}\n";
793 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
794 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = " << OpName
795 <<
"_cell_gate[i] / (1. + abs(" << OpName <<
"_cell_gate[i]));\n";
796 out <<
SP <<
SP <<
"}\n";
798 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
799 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_gate[i] = log(1. + exp("
800 << OpName <<
"_cell_gate[i]));\n";
801 out <<
SP <<
SP <<
"}\n";
807 out <<
SP <<
SP <<
"if (seq == 0) {\n";
809 if (direction == 0) {
810 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
811 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i + offset] += tensor_" <<
fNP
812 <<
"[i] * " << OpName <<
"_initial_cell_state[i];\n";
813 out <<
SP <<
SP <<
SP <<
"}\n";
816 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
817 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i + offset] += tensor_" <<
fNP
818 <<
"[i + " << pf_offset <<
"] * " << OpName <<
"_initial_cell_state[i];\n";
819 out <<
SP <<
SP <<
SP <<
"}\n";
824 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
825 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i + offset] += tensor_" <<
fNP
826 <<
"[i + " << pi_offset <<
"] * " << OpName <<
"_initial_cell_state[i + " << initial_c_offset
828 out <<
SP <<
SP <<
SP <<
"}\n";
831 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
832 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i + offset] += tensor_" <<
fNP
833 <<
"[i + " << pf_offset <<
"] * " << OpName <<
"_initial_cell_state[i + " << initial_c_offset
835 out <<
SP <<
SP <<
SP <<
"}\n";
839 out <<
SP <<
SP <<
"} else {\n";
840 if (direction == 0) {
842 out <<
SP <<
SP <<
SP <<
"size_t c_offset = (index + 1) * "
845 out <<
SP <<
SP <<
SP <<
"size_t c_offset = (seq - 1) * "
848 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
849 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i + offset] += tensor_" <<
fNP
850 <<
"[i] * " << OpName <<
"_cell_state[i + c_offset];\n";
851 out <<
SP <<
SP <<
SP <<
"}\n";
854 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
855 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i + offset] += tensor_" <<
fNP
856 <<
"[i + " << pf_offset <<
"] * " << OpName <<
"_cell_state[i + c_offset];\n";
857 out <<
SP <<
SP <<
SP <<
"}\n";
861 out <<
SP <<
SP <<
SP <<
"size_t c_offset = (index + 1) * "
863 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
864 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i + offset] += tensor_" <<
fNP
865 <<
"[i + " << pi_offset <<
"] * " << OpName <<
"_cell_state[i + c_offset];\n";
866 out <<
SP <<
SP <<
SP <<
"}\n";
869 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
870 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i + offset] += tensor_" <<
fNP
871 <<
"[i + " << pf_offset <<
"] * " << OpName <<
"_cell_state[i + c_offset];\n";
872 out <<
SP <<
SP <<
SP <<
"}\n";
875 out <<
SP <<
SP <<
"}\n";
880 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
881 if (
fType ==
"float") {
882 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_input_gate[i] > " << -
fAttrClip <<
") ? "
883 << OpName <<
"_input_gate[i] : " << -
fAttrClip <<
";\n";
885 out <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = (x < " <<
fAttrClip <<
") ? x : "
887 out <<
SP <<
SP <<
"}\n";
891 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
892 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_input_gate[i] < 0.)\n";
893 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = 0.;\n";
894 out <<
SP <<
SP <<
"}\n";
896 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
897 if (
fType ==
"float") {
898 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_input_gate[i]);\n";
900 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = (1. - ex) / (1. + ex);\n";
901 out <<
SP <<
SP <<
"}\n";
903 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
904 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = 1. / (1. + exp(-" << OpName
905 <<
"_input_gate[i]));\n";
906 out <<
SP <<
SP <<
"}\n";
908 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
909 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = "
912 out <<
SP <<
SP <<
"}\n";
914 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
915 if (
fType ==
"float") {
917 <<
" * "<< OpName <<
"_input_gate[i]);\n";
919 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = "
921 out <<
SP <<
SP <<
"}\n";
923 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
924 if (
fType ==
"float") {
927 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
929 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = (b < 1.) ? b : 1.;\n";
930 out <<
SP <<
SP <<
"}\n";
932 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
933 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_input_gate[i] < 0.)\n";
934 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = "
936 out <<
SP <<
SP <<
"}\n";
938 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
939 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_input_gate[i] < "
941 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = 0.;\n";
942 out <<
SP <<
SP <<
"}";
944 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
945 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_input_gate[i] < 0.)\n";
946 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = "
948 out <<
SP <<
SP <<
"}\n";
950 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
951 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = " << OpName
952 <<
"_input_gate[i] / (1. + abs(" << OpName <<
"_input_gate[i]));\n";
953 out <<
SP <<
SP <<
"}\n";
955 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
956 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_input_gate[i] = log(1. + exp("
957 << OpName <<
"_input_gate[i]));\n";
958 out <<
SP <<
SP <<
"}\n";
964 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
965 if (
fType ==
"float") {
966 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_forget_gate[i] > "
969 out <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = (x < " <<
fAttrClip
971 out <<
SP <<
SP <<
"}\n";
975 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
976 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_forget_gate[i] < 0.)\n";
977 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = 0.;\n";
978 out <<
SP <<
SP <<
"}\n";
980 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
981 if (
fType ==
"float") {
982 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_forget_gate[i]);\n";
984 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = (1. - ex) / (1. + ex);\n";
985 out <<
SP <<
SP <<
"}\n";
987 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
988 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = 1. / (1. + exp(-"
989 << OpName <<
"_forget_gate[i]));\n";
990 out <<
SP <<
SP <<
"}\n";
992 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
993 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = "
996 out <<
SP <<
SP <<
"}\n";
998 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
999 if (
fType ==
"float") {
1001 <<
" * "<< OpName <<
"_forget_gate[i]);\n";
1003 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = "
1005 out <<
SP <<
SP <<
"}\n";
1007 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1008 if (
fType ==
"float") {
1011 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
1013 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = (b < 1.) ? b : 1.;\n";
1014 out <<
SP <<
SP <<
"}\n";
1016 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1017 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_forget_gate[i] < 0.)\n";
1018 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = "
1020 out <<
SP <<
SP <<
"}\n";
1022 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1023 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_forget_gate[i] < "
1025 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = 0.;\n";
1026 out <<
SP <<
SP <<
"}";
1028 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1029 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_forget_gate[i] < 0.)\n";
1030 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = "
1032 out <<
SP <<
SP <<
"}\n";
1034 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1035 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = " << OpName
1036 <<
"_forget_gate[i] / (1. + abs(" << OpName <<
"_forget_gate[i]));\n";
1037 out <<
SP <<
SP <<
"}\n";
1039 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1040 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_forget_gate[i] = log(1. + exp("
1041 << OpName <<
"_forget_gate[i]));\n";
1042 out <<
SP <<
SP <<
"}\n";
1047 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1048 out <<
SP <<
SP <<
SP << OpName <<
"_cell_state[i] = " << OpName <<
"_input_gate[i] * "
1049 << OpName <<
"_cell_gate[i];\n";
1050 out <<
SP <<
SP <<
"}\n";
1053 out <<
SP <<
SP <<
"if (seq == 0) {\n";
1056 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1057 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_state[i + offset] += "
1058 << OpName <<
"_forget_gate[i + offset] * " << OpName <<
"_initial_cell_state[i];\n";
1059 out <<
SP <<
SP <<
SP <<
"}\n";
1061 out <<
SP <<
SP <<
"} else {\n";
1063 if (direction == 0) {
1065 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
1068 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (seq - 1) * "
1072 out <<
SP <<
SP <<
SP <<
"size_t previous_offset = (index + 1) * "
1075 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1076 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_state[i + offset] += "
1077 << OpName <<
"_forget_gate[i + offset] * " << OpName <<
"_cell_state[i + previous_offset];\n";
1078 out <<
SP <<
SP <<
SP <<
"}\n";
1079 out <<
SP <<
SP <<
"}\n";
1084 if (direction == 0) {
1086 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1087 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i + offset] += tensor_"
1088 <<
fNP <<
"[i + " << p_offset <<
"] * " << OpName <<
"_cell_state[i + offset];\n";
1089 out <<
SP <<
SP <<
SP <<
"}\n";
1092 out <<
SP <<
SP <<
SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1093 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i + offset] += tensor_"
1094 <<
fNP <<
"[i + " << p_offset <<
"] * " << OpName <<
"_cell_state[i + offset];\n";
1095 out <<
SP <<
SP <<
SP <<
"}\n";
1101 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1102 if (
fType ==
"float") {
1103 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_output_gate[i] > " << -
fAttrClip
1104 <<
") ? " << OpName <<
"_output_gate[i] : " << -
fAttrClip <<
";\n";
1106 out <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = (x < " <<
fAttrClip <<
") ? x : "
1108 out <<
SP <<
SP <<
"}\n";
1112 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1113 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_output_gate[i] < 0.)\n";
1114 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = 0.;\n";
1115 out <<
SP <<
SP <<
"}\n";
1117 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1118 if (
fType ==
"float") {
1119 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_output_gate[i]);\n";
1121 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = (1. - ex) / (1. + ex);\n";
1122 out <<
SP <<
SP <<
"}\n";
1124 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1125 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = 1. / (1. + exp(-" << OpName
1126 <<
"_output_gate[i]));\n";
1127 out <<
SP <<
SP <<
"}\n";
1129 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1130 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = "
1133 out <<
SP <<
SP <<
"}\n";
1135 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1136 if (
fType ==
"float") {
1138 <<
" * "<< OpName <<
"_output_gate[i]);\n";
1140 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = "
1142 out <<
SP <<
SP <<
"}\n";
1144 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1145 if (
fType ==
"float") {
1148 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
1150 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = (b < 1.) ? b : 1.;\n";
1151 out <<
SP <<
SP <<
"}\n";
1153 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1154 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_output_gate[i] < 0.)\n";
1155 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = "
1157 out <<
SP <<
SP <<
"}\n";
1159 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1160 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_output_gate[i] < "
1162 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = 0.;\n";
1163 out <<
SP <<
SP <<
"}";
1165 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1166 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_output_gate[i] < 0.)\n";
1167 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = "
1169 out <<
SP <<
SP <<
"}\n";
1171 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1172 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = " << OpName
1173 <<
"_output_gate[i] / (1. + abs(" << OpName <<
"_output_gate[i]));\n";
1174 out <<
SP <<
SP <<
"}\n";
1176 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1177 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_output_gate[i] = log(1. + exp("
1178 << OpName <<
"_output_gate[i]));\n";
1179 out <<
SP <<
SP <<
"}\n";
1183 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1184 <<
"_cell_state + offset + " <<
size <<
", "<< OpName <<
"_new_cell_state + offset);\n";
1187 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1188 if (
fType ==
"float") {
1189 out <<
SP <<
SP <<
SP <<
"float x = (" << OpName <<
"_new_cell_state[i] > " << -
fAttrClip
1190 <<
") ? " << OpName <<
"_new_cell_state[i] : " << -
fAttrClip <<
";\n";
1192 out <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = (x < " <<
fAttrClip <<
") ? x : "
1194 out <<
SP <<
SP <<
"}\n";
1198 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1199 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_new_cell_state[i] < 0.)\n";
1200 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = 0.;\n";
1201 out <<
SP <<
SP <<
"}\n";
1203 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1204 if (
fType ==
"float") {
1205 out <<
SP <<
SP <<
SP <<
"float ex = exp(-2 * " << OpName <<
"_new_cell_state[i]);\n";
1207 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = (1. - ex) / (1. + ex);\n";
1208 out <<
SP <<
SP <<
"}\n";
1210 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1211 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = 1. / (1. + exp(-" << OpName
1212 <<
"_new_cell_state[i]));\n";
1213 out <<
SP <<
SP <<
"}\n";
1215 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1216 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = "
1219 out <<
SP <<
SP <<
"}\n";
1221 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1222 if (
fType ==
"float") {
1224 <<
" * "<< OpName <<
"_new_cell_state[i]);\n";
1226 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = "
1228 out <<
SP <<
SP <<
"}\n";
1230 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1231 if (
fType ==
"float") {
1234 out <<
SP <<
SP <<
SP <<
"float b = (a > 0.) ? a : 0.;\n";
1236 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = (b < 1.) ? b : 1.;\n";
1237 out <<
SP <<
SP <<
"}\n";
1239 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1240 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_new_cell_state[i] < 0.)\n";
1241 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = "
1243 out <<
SP <<
SP <<
"}\n";
1245 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1246 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_new_cell_state[i] < "
1248 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = 0.;\n";
1249 out <<
SP <<
SP <<
"}";
1251 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1252 out <<
SP <<
SP <<
SP <<
"if (" << OpName <<
"_new_cell_state[i] < 0.)\n";
1253 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = "
1254 <<
fAttrActivationAlpha[direction * 3 + 2] <<
" * exp(" << OpName <<
"_new_cell_state[i] - 1.);\n";
1255 out <<
SP <<
SP <<
"}\n";
1257 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1258 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = " << OpName
1259 <<
"_new_cell_state[i] / (1. + abs(" << OpName <<
"_new_cell_state[i]));\n";
1260 out <<
SP <<
SP <<
"}\n";
1262 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1263 out <<
SP <<
SP <<
SP <<
SP << OpName <<
"_new_cell_state[i] = log(1. + exp("
1264 << OpName <<
"_new_cell_state[i]));\n";
1265 out <<
SP <<
SP <<
"}\n";
1269 out <<
SP <<
SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1270 out <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[i] = " << OpName <<
"_output_gate[i] * "
1271 << OpName <<
"_new_cell_state[i];\n";
1272 out <<
SP <<
SP <<
"}\n";
1278 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
1279 out <<
SP <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1281 for (
size_t direction = 0; direction < num_directions; direction++) {
1283 out <<
SP <<
SP <<
SP <<
SP <<
SP <<
SP <<
"size_t idx = seq * "
1286 out <<
SP <<
SP <<
SP <<
SP <<
SP <<
SP << OpName <<
"_cell_state[idx] = 0.;\n";
1287 out <<
SP <<
SP <<
SP <<
SP <<
SP <<
SP << OpName <<
"_hidden_state[idx] = 0.;\n";
1290 out <<
SP <<
SP <<
SP <<
"}\n";
1291 out <<
SP <<
SP <<
"}\n";
1297 if (!
fNY_h.empty()) {
1302 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state, " << OpName <<
"_hidden_state + "
1303 << y_h_size <<
", tensor_" <<
fNY_h <<
");\n";
1306 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + " <<
offset <<
", " << OpName
1307 <<
"_hidden_state + " <<
offset <<
" + " << y_h_size <<
", tensor_" <<
fNY_h <<
");\n";
1309 if (num_directions == 2) {
1310 out <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + " << y_h_size <<
", " << OpName
1311 <<
"_hidden_state + " << 2 * y_h_size <<
", tensor_" <<
fNY_h <<
" + " << y_h_size <<
");\n";
1315 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1317 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1321 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1323 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1326 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1327 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1330 if (num_directions == 2) {
1331 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1336 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1337 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1342 if (!
fNY_c.empty()) {
1347 out <<
SP <<
"std::copy(" << OpName <<
"_cell_state, " << OpName <<
"_hidden_state + "
1348 << y_h_size <<
", tensor_" <<
fNY_c <<
");\n";
1351 out <<
SP <<
"std::copy(" << OpName <<
"_cell_state + " <<
offset <<
", " << OpName
1352 <<
"_cell_state + " <<
offset <<
" + " << y_h_size <<
", tensor_" <<
fNY_c <<
");\n";
1354 if (num_directions == 2) {
1355 out <<
SP <<
"std::copy(" << OpName <<
"_cell_state + " << y_h_size <<
", " << OpName
1356 <<
"_cell_state + " << 2 * y_h_size <<
", tensor_" <<
fNY_c <<
" + " << y_h_size <<
");\n";
1360 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1362 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1366 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1368 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1371 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1372 <<
"_cell_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_c <<
" + y_h_offset);\n";
1375 if (num_directions == 2) {
1376 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1381 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1382 <<
"_cell_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_c <<
" + y_h_offset);\n";
1390 for (
size_t direction = 0; direction < num_directions; direction++) {
1391 out <<
SP <<
"for (size_t seq = 0; seq < " << seq_length <<
"; seq++) {\n";
1392 out <<
SP <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1395 out <<
SP <<
SP <<
SP <<
"size_t y_offset = batch * " << seq_length * num_directions *
fAttrHiddenSize
1397 out <<
SP <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1398 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY <<
" + y_offset);\n";
1399 out <<
SP <<
SP <<
"}\n";
1403 if (!
fNY_h.empty()) {
1406 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1408 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1409 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1410 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1413 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1415 out <<
SP <<
SP <<
"size_t seq = " << seq_length - 1 <<
";\n";
1419 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1421 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1422 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1423 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1426 if (num_directions == 2) {
1427 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1430 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
" + "
1432 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_hidden_state + offset, " << OpName
1433 <<
"_hidden_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_h <<
" + y_h_offset);\n";
1438 if (!
fNY_c.empty()) {
1441 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1443 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1444 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1445 <<
"_cell_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_c <<
" + y_h_offset);\n";
1448 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1450 out <<
SP <<
SP <<
"size_t seq = " << seq_length - 1 <<
";\n";
1454 out <<
SP <<
SP <<
"size_t offset = seq * " << num_directions * batch_size *
fAttrHiddenSize
1456 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
";\n";
1457 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1458 <<
"_cell_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_c <<
" + y_h_offset);\n";
1461 if (num_directions == 2) {
1462 out <<
SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
1465 out <<
SP <<
SP <<
"size_t y_h_offset = batch * " << num_directions *
fAttrHiddenSize <<
" + "
1467 out <<
SP <<
SP <<
"std::copy(" << OpName <<
"_cell_state + offset, " << OpName
1468 <<
"_cell_state + offset + " <<
fAttrHiddenSize <<
", tensor_" <<
fNY_c <<
" + y_h_offset);\n";