40 fUseSession = model.UseSession();
42 if (!model.CheckIfTensorAlreadyExist(fNX)) {
43 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " + fNX +
44 " is not found in model.");
46 fShapeX = model.GetTensorShape(fNX);
47 if (fShapeX.size() != 3) {
48 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " + fNX +
49 " is not of 3 dimensions.");
51 if (!model.CheckIfTensorAlreadyExist(fNW)) {
52 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " + fNW +
53 " is not found in model.");
55 fShapeW = model.GetTensorShape(fNW);
56 if (fShapeW.size() != 3) {
57 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " + fNW +
58 " is not of 3 dimensions.");
60 if (!model.CheckIfTensorAlreadyExist(fNR)) {
61 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " + fNR +
62 " is not found in model.");
64 fShapeR = model.GetTensorShape(fNR);
65 if (fShapeR.size() != 3) {
66 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " + fNR +
67 " is not of 3 dimensions.");
70 if (!model.CheckIfTensorAlreadyExist(fNB)) {
71 throw std::runtime_error(
"TMVA SOFIE RNN op input tensor " + fNB +
72 " is not found in model.");
74 fShapeB = model.GetTensorShape(fNB);
75 if (fShapeB.size() != 2 && fShapeB.size() != 4) {
76 throw std::runtime_error(
"TMVA SOFIE RNN op input tensor " + fNB +
77 " is not of 2 or 4 dimensions.");
79 if (fShapeB.size() == 2) {
83 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
84 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
85 if (fType ==
"float") {
88 batch_size * fAttrHiddenSize];
89 std::vector<float>
sum(fAttrHiddenSize);
92 for (
size_t h = 0;
h < fAttrHiddenSize;
h++) {
100 seq * batch_size * fAttrHiddenSize +
batch * fAttrHiddenSize;
106 batch_size, fAttrHiddenSize};
108 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB),
110 fShapeB = model.GetTensorShape(fNB);
114 if (!fNSequence_lens.empty()) {
115 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
116 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " +
117 fNSequence_lens +
"is not found in model.");
119 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
120 if (fShapeSequence_lens.size() != 1) {
121 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " +
122 fNSequence_lens +
" is not of 1 dimension.");
125 if (!fNInitial_h.empty()) {
126 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
127 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " +
128 fNInitial_h +
" is not found in model.");
130 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
131 if (fShapeInitial_h.size() != 3) {
132 throw std::runtime_error(
"TMVA SOFIE RNN Op input tensor " +
133 fNInitial_h +
" is not of 3 dimensions.");
137 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
138 if (!model.CheckIfTensorAlreadyExist(fNY)) {
139 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
142 if (!fNY_h.empty()) {
143 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
144 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
145 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX),
157 throw std::runtime_error(
"TMVA SOFIE - Activation function " +
161 if (fAttrDirection !=
"forward" && fAttrDirection !=
"backward" &&
162 fAttrDirection !=
"bidirectional") {
163 throw std::runtime_error(
164 "TMVA SOFIE - Invalid RNN direction fAttrDirection = " +
167 if (fAttrHiddenSize != fShapeW[1]) {
168 throw std::runtime_error(
169 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
170 std::to_string(fShapeW[1]));
172 if (fAttrLayout > 1) {
173 throw std::runtime_error(
174 "TMVA SOFIE - Layout fAttrLayout = " + std::to_string(fAttrLayout) +
175 " must be 0 (timewise) or 1 (batchwise)");
177 if (fAttrActivations.empty()) {
178 if (fAttrDirection ==
"bidirectional") {
179 fAttrActivations = {
"Tanh",
"Tanh"};
181 fAttrActivations = {
"Tanh"};
185 model.AddNeededStdLib(
"cmath");
224 std::stringstream out;
226 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
227 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
232 if (fAttrLayout == 0) {
233 if (fType ==
"float") {
234 out << SP <<
"float *" <<
OpName <<
"_input = tensor_" << fNX <<
";\n";
238 out << SP << fType <<
" * " <<
OpName <<
"_input = fVec_" <<
OpName <<
"_input.data();\n";
241 out << SP <<
"for(size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
242 out << SP << SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
243 out << SP << SP << SP <<
"for(size_t i = 0; i < " <<
input_size <<
"; i++) {\n";
244 out << SP << SP << SP << SP <<
OpName <<
"_input[seq * " << batch_size *
input_size
245 <<
" + batch * " <<
input_size <<
" + i] = " <<
"tensor_" << fNX <<
"[batch * "
247 out << SP << SP << SP <<
"}\n";
248 out << SP << SP <<
"}\n";
253 if (!fNInitial_h.empty()) {
254 if (fAttrLayout == 0) {
255 out << SP << fType <<
" *" <<
OpName <<
"_initial_hidden_state = " <<
" tensor_"
256 << fNInitial_h <<
";\n";
259 out << SP << fType <<
" * " <<
OpName <<
"_initial_hidden_state = fVec_" <<
OpName
260 <<
"_initial_hidden_state.data();\n";
263 fAttrHiddenSize <<
"] = {0};\n";
266 out << SP <<
"for(size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
267 out << SP << SP <<
"for(size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
268 out << SP << SP << SP <<
OpName <<
"_initial_hidden_state["
269 <<
direction * batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize
270 <<
" + h] = tensor_" << fNInitial_h <<
"[batch * " <<
num_directions * fAttrHiddenSize
271 <<
" + " <<
direction * fAttrHiddenSize <<
" + h];\n";
272 out << SP << SP <<
"}\n";
279 out << SP << fType <<
" * " <<
OpName <<
"_feedforward = fVec_" <<
OpName
280 <<
"_feedforward.data();\n";
282 out << SP << fType <<
" " <<
OpName <<
"_feedforward[" <<
seq_length * batch_size * fAttrHiddenSize <<
"] = {0};\n";
285 if (fAttrLayout == 0 && !fNY.empty()) {
286 out << SP << fType <<
" *" <<
OpName <<
"_hidden_state = tensor_" << fNY <<
";\n";
289 out << SP << fType <<
" * " <<
OpName <<
"_hidden_state = fVec_" <<
OpName <<
"_hidden_state.data();\n";
292 batch_size * fAttrHiddenSize <<
"] = {0};\n";
295 out << SP <<
"char " <<
OpName <<
"_transA = 'N';\n";
296 out << SP <<
"char " <<
OpName <<
"_transB = 'T';\n";
297 out << SP <<
"int " <<
OpName <<
"_m = " <<
seq_length * batch_size <<
";\n";
298 out << SP <<
"int " <<
OpName <<
"_n = " << fAttrHiddenSize <<
";\n";
300 if (fType ==
"float") {
301 out << SP <<
"float " <<
OpName <<
"_alpha = 1.;\n";
302 out << SP <<
"float " <<
OpName <<
"_beta = .0;\n";
305 out << SP <<
"int " <<
OpName <<
"_bias_size = " <<
seq_length * batch_size * fAttrHiddenSize <<
";\n";
306 out << SP <<
"int " <<
OpName <<
"_incx = 1;\n";
307 out << SP <<
"int " <<
OpName <<
"_incy = 1;\n";
312 if (fType ==
"float") {
314 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
316 <<
"_alpha, tensor_" << fNW <<
", &" <<
OpName <<
"_k, " <<
OpName
318 <<
"_feedforward, &" <<
OpName <<
"_n);\n";
320 out << SP <<
"size_t " <<
OpName <<
"_w_offset = " << fAttrHiddenSize *
input_size
322 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
324 <<
"_alpha, tensor_" << fNW <<
" + " <<
OpName <<
"_w_offset, &" <<
OpName
331 if (fType ==
"float") {
333 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
334 << fNB <<
", &" <<
OpName <<
"_incx, " <<
OpName <<
"_feedforward, &" <<
OpName <<
"_incy);\n";
336 out << SP <<
"size_t " <<
OpName <<
"_bias_offset = "
337 <<
seq_length * batch_size * fAttrHiddenSize <<
";\n";
338 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
340 <<
"_feedforward, &" <<
OpName <<
"_incy);\n";
346 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
347 out << SP << SP <<
"size_t offset = seq * " << batch_size * fAttrHiddenSize <<
";\n";
348 out << SP << SP <<
"size_t size = " << batch_size * fAttrHiddenSize <<
";\n";
349 out << SP << SP <<
"size_t h_offset = seq * "
351 <<
direction * batch_size * fAttrHiddenSize <<
";\n";
352 out << SP << SP <<
"std::copy(" <<
OpName <<
"_feedforward + offset, " <<
OpName
353 <<
"_feedforward + offset + size, " <<
OpName <<
"_hidden_state + h_offset);\n";
357 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
358 if (fAttrDirection ==
"backward" ||
direction == 1) {
359 out << SP << SP <<
"size_t index = " <<
seq_length - 1 <<
" - seq;\n";
361 out << SP << SP <<
"size_t index = seq;\n";
364 out << SP << SP <<
"int m2 = " << batch_size <<
";\n";
365 out << SP << SP <<
"size_t offset = index * "
367 <<
direction * batch_size * fAttrHiddenSize <<
";\n";
368 out << SP << SP <<
"size_t size = " << batch_size * fAttrHiddenSize <<
";\n";
369 out << SP << SP <<
"if (seq == 0) {\n";
370 if (!fNInitial_h.empty()) {
372 out << SP << SP << SP <<
"size_t r_offset = "
373 <<
direction * fAttrHiddenSize * fAttrHiddenSize <<
";\n";
374 out << SP << SP << SP <<
"size_t initial_hidden_state_offset = "
375 <<
direction * batch_size * fAttrHiddenSize <<
";\n";
376 if (fType ==
"float") {
377 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName
379 <<
"_alpha, tensor_" << fNR <<
" + r_offset, &" <<
OpName <<
"_n, " <<
OpName
380 <<
"_initial_hidden_state + initial_hidden_state_offset, &" <<
OpName <<
"_n, &"
381 <<
OpName <<
"_alpha, " <<
OpName <<
"_hidden_state + offset, &" <<
OpName <<
"_n);\n";
384 out << SP << SP <<
"} else {\n";
386 out << SP << SP << SP <<
"size_t r_offset = "
387 <<
direction * fAttrHiddenSize * fAttrHiddenSize <<
";\n";
388 if (fAttrDirection ==
"backward" ||
direction == 1) {
389 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
391 <<
" + " <<
direction * batch_size * fAttrHiddenSize <<
";\n";
393 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
395 <<
" + " <<
direction * batch_size * fAttrHiddenSize <<
";\n";
397 if (fType ==
"float") {
398 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
400 <<
" + r_offset, &" <<
OpName <<
"_n, " <<
OpName <<
"_hidden_state + previous_offset, &"
404 out << SP << SP <<
"}\n";
407 if (fAttrClip > .0) {
408 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
409 if (fType ==
"float") {
410 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_hidden_state[i] > " << -fAttrClip
411 <<
") ? " <<
OpName <<
"_hidden_state[i] : " << -fAttrClip <<
";\n";
413 out << SP << SP << SP <<
OpName <<
"_hidden_state[i] = (x < " << fAttrClip
414 <<
") ? x : " << fAttrClip <<
";\n";
415 out << SP << SP <<
"}\n";
419 if (fAttrActivations[
direction] ==
"Relu") {
420 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
421 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_state[i] < 0.)\n";
422 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = 0.;\n";
423 out << SP << SP <<
"}\n";
424 }
else if (fAttrActivations[
direction] ==
"Tanh") {
425 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
426 if (fType ==
"float") {
427 out << SP << SP << SP <<
"float ex = std::exp(-2 * " <<
OpName <<
"_hidden_state[i]);\n";
429 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = (1. - ex) / (1. + ex);\n";
430 out << SP << SP <<
"}\n";
431 }
else if (fAttrActivations[
direction] ==
"Sigmoid") {
432 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
433 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = 1. / (1. + std::exp(-" <<
OpName
434 <<
"_hidden_state[i]));\n";
435 out << SP << SP <<
"}\n";
436 }
else if (fAttrActivations[
direction] ==
"Affine") {
437 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
438 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = " << fAttrActivationAlpha[
direction]
439 <<
" * " <<
OpName <<
"_hidden_state[i] + " << fAttrActivationBeta[
direction] <<
";\n";
440 out << SP << SP <<
"}\n";
441 }
else if (fAttrActivations[
direction] ==
"ScaledTanh") {
442 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
443 if (fType ==
"float") {
444 out << SP << SP << SP <<
"float ex = std::exp(-2 * " << fAttrActivationBeta[
direction]
445 <<
" * "<<
OpName <<
"_hidden_state[i]);\n";
447 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = " << fAttrActivationAlpha[
direction]
448 <<
" * (1. - ex) / (1. + ex);\n";
449 out << SP << SP <<
"}\n";
450 }
else if (fAttrActivations[
direction] ==
"HardSigmoid") {
451 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
452 if (fType ==
"float") {
453 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction] <<
" * "
454 <<
OpName <<
"_hidden_state[i] + " << fAttrActivationBeta[
direction] <<
";\n";
455 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
457 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = (b < 1.) ? b : 1.;\n";
458 out << SP << SP <<
"}\n";
459 }
else if (fAttrActivations[
direction] ==
"LeakyRelu") {
460 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
461 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_state[i] < 0.)\n";
462 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = " << fAttrActivationAlpha[
direction]
463 <<
" * " <<
OpName <<
"_hidden_state[i];\n";
464 out << SP << SP <<
"}\n";
465 }
else if (fAttrActivations[
direction] ==
"ThresholdRelu") {
466 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
467 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_state[i] < "
468 << fAttrActivationAlpha[
direction] <<
")\n";
469 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = 0.;\n";
470 out << SP << SP <<
"}";
471 }
else if (fAttrActivations[
direction] ==
"Elu") {
472 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
473 out << SP << SP << SP <<
"if (" <<
OpName <<
"_hidden_state[i] < 0.)\n";
474 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = " << fAttrActivationAlpha[
direction]
475 <<
" * std::exp(" <<
OpName <<
"_hidden_state[i] - 1.);\n";
476 out << SP << SP <<
"}\n";
477 }
else if (fAttrActivations[
direction] ==
"Softsign") {
478 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
479 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = " <<
OpName
480 <<
"_hidden_state[i] / (1. + abs(" <<
OpName <<
"_hidden_state[i]));\n";
481 out << SP << SP <<
"}\n";
483 out << SP << SP <<
"for (size_t i = offset; i < offset + size; i++) {\n";
484 out << SP << SP << SP << SP <<
OpName <<
"_hidden_state[i] = log(1. + std::exp("
485 <<
OpName <<
"_hidden_state[i]));\n";
486 out << SP << SP <<
"}\n";
493 if (!fNSequence_lens.empty()) {
494 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
495 out << SP << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
496 out << SP << SP << SP <<
"if (seq >= tensor_" << fNSequence_lens <<
"[batch]) {\n";
497 out << SP << SP << SP << SP <<
"for (size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
499 out << SP << SP << SP << SP << SP <<
OpName <<
"_hidden_state[seq * "
501 << fAttrHiddenSize <<
" + h] = 0.;\n";
503 out << SP << SP << SP << SP << SP <<
OpName <<
"_hidden_state[seq * "
505 << fAttrHiddenSize <<
" + h] = 0.;\n";
506 out << SP << SP << SP << SP << SP <<
OpName <<
"_hidden_state[seq * "
507 <<
num_directions * batch_size * fAttrHiddenSize <<
" + " << batch_size * fAttrHiddenSize
508 <<
" + batch * " << fAttrHiddenSize <<
" + h] = 0.;\n";
510 out << SP << SP << SP << SP <<
"}\n";
511 out << SP << SP << SP <<
"}\n";
512 out << SP << SP <<
"}\n";
517 if (fAttrLayout == 0) {
518 if (!fNY_h.empty()) {
519 if (fNSequence_lens.empty()) {
520 size_t yh_size = batch_size * fAttrHiddenSize;
521 if (fAttrDirection ==
"backward") {
522 out << SP <<
"std::copy(" <<
OpName <<
"_hidden_state, " <<
OpName <<
"_hidden_state + "
523 <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
527 <<
"_hidden_state + " <<
offset <<
" + " <<
yh_size <<
", tensor_" << fNY_h <<
");\n";
531 <<
"_hidden_state + " << 2 *
yh_size <<
", tensor_" << fNY_h <<
" + " <<
yh_size <<
");\n";
534 if (fAttrDirection ==
"backward") {
535 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
536 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
537 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
538 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + offset);\n";
541 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
542 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
543 out << SP << SP <<
"size_t offset = seq * " <<
num_directions * batch_size * fAttrHiddenSize
544 <<
" + batch * " << fAttrHiddenSize <<
";\n";
545 out << SP << SP <<
"size_t yh_offset = batch * " << fAttrHiddenSize <<
";\n";
546 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
547 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
551 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
552 out << SP << SP <<
"size_t offset = " << batch_size * fAttrHiddenSize
553 <<
" + batch * " << fAttrHiddenSize <<
";\n";
554 out << SP << SP <<
"size_t yh_offset = " << batch_size * fAttrHiddenSize
555 <<
" + batch * " << fAttrHiddenSize <<
";\n";
556 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
557 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
565 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
566 out << SP << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
567 out << SP << SP << SP <<
"size_t offset = seq * " <<
num_directions * batch_size * fAttrHiddenSize
568 <<
" + " <<
direction * batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize <<
";\n";
571 out << SP << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
572 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY <<
" + y_offset);\n";
573 out << SP << SP <<
"}\n";
577 if (!fNY_h.empty()) {
578 if (fAttrDirection ==
"backward") {
579 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
580 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
581 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
582 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
583 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
586 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
587 if (fNSequence_lens.empty()) {
588 out << SP << SP <<
"size_t seq = " <<
seq_length - 1 <<
";\n";
590 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
592 out << SP << SP <<
"size_t offset = seq * " <<
num_directions * batch_size * fAttrHiddenSize
593 <<
" + batch * " << fAttrHiddenSize <<
";\n";
594 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
595 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
596 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";
600 out << SP <<
"for (size_t batch = 0; batch < " << batch_size <<
"; batch++) {\n";
601 out << SP << SP <<
"size_t offset = " << batch_size * fAttrHiddenSize <<
" + batch * "
602 << fAttrHiddenSize <<
";\n";
603 out << SP << SP <<
"size_t yh_offset = batch * " <<
num_directions * fAttrHiddenSize <<
" + "
604 << fAttrHiddenSize <<
";\n";
605 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
606 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + yh_offset);\n";