1#ifndef TMVA_SOFIE_ROPERATOR_GEMM
2#define TMVA_SOFIE_ROPERATOR_GEMM
57 static_assert(std::is_same_v<T, float>,
58 "TMVA::SOFIE - Unsupported type parsing a Gemm operator");
81 if (
input.size() > 3)
throw std::runtime_error(
"TMVA SOFIE Gemm Op Shape Inference only need 2 or 3 input tensor");
86 throw std::runtime_error(
"TMVA SOFIE Gemm Op Shape Inference only accept input tensor with >=2 dimensions");
91 if (
input.size() == 3){
96 int ioffset =
input[0].size()-2;
98 std::vector<U> s_a(
input[0].begin() + ioffset,
input[0].begin() + ioffset + 2);
99 std::vector<U> s_b(
input[1].begin() + ioffset,
input[1].begin() + ioffset + 2);
102 std::reverse(s_a.begin(), s_a.end());
105 std::reverse(s_b.begin(), s_b.end());
113 for (
size_t i = 0; i <
input[0].size()-2; i++) {
117 if (valueB.
GetVal() ==
"1")
118 s_y.push_back(
input[0][i]);
119 else if (valueA.
GetVal() ==
"1")
120 s_y.push_back(
input[1][i]);
122 throw std::runtime_error(
"TMVA SOFIE Gemm Op - invalid input shapes " + valueA.
GetVal() +
" and "
126 auto & dimNames =
fModel->GetDimShapeNames();
127 auto p1 = std::find(dimNames.begin(), dimNames.end(), valueA.
param);
128 auto p2 = std::find(dimNames.begin(), dimNames.end(), valueB.
param);
129 if (p1 < p2) s_y.push_back(
input[0][i]);
130 else s_y.push_back(
input[1][i]);
133 s_y.push_back(
input[0][i]);
135 s_y.push_back(
input[1][i]);
137 throw std::runtime_error(
"TMVA SOFIE Gemm Op - invalid input shapes " + valueA.
GetVal() +
" and "
141 s_y.push_back(
input[0][i]);
145 s_y.push_back(s_a[0]);
146 s_y.push_back(s_b[1]);
151 std::vector<std::vector<size_t>> ret;
166 throw std::runtime_error(
"TMVA SOFIE Gemm Op Input Tensor " +
fNA +
" or " +
fNB +
" is not found in model");
170 throw std::runtime_error(
"TMVA SOFIE Gemm Op Input Tensor " +
fNC +
" is not found in model");
181 bool prependOne =
false;
196 bool appendOne =
false;
224 bool broadcast_needed =
false;
226 broadcast_needed =
true;
229 broadcast_needed = (
fShapeC != shapeY);
232 if (broadcast_needed) {
238 if ((
r.first & 2) == 2) {
240 }
else if (
r.first == 4) {
253 shapeY.erase(shapeY.begin());
259 shapeY.erase(shapeY.end()-1);
268 std::cout <<
"Gemm (or MatMul) " <<
" ---> " <<
fNY <<
" shape ";
278 std::string
Generate(std::string opName)
override {
279 opName =
"op_" + opName;
284 std::stringstream out;
292 if (dimA != dimB || dimA != dimY || (
fBroadcastBias && dimC != dimY)) {
297 throw std::runtime_error(
"TMVA SOFIE Gemm(MatMul) has invalid shape for inputs or output");
306 std::vector<Dim> sExtraY;
307 for (int64_t i = 0; i < dimY-2; i++) {
312 std::string lengthExtra_C;
313 std::vector<Dim> sExtraC;
315 bool haveExtraC =
false;
318 for (int64_t i = 0; i < dimC-2; i++) {
322 if (lengthExtra_C !=
"1") haveExtraC =
true;
323 }
else if (dimC > 0) {
324 for (int64_t i = 0; i < dimC; i++) {
338 ( haveExtraC && std::stoi(lengthExtra_Y) != std::stoi(lengthExtra_C)))
339 throw std::runtime_error(
"TMVA SOFIE Gemm Op " + opName +
" Bias tensor " +
fNC +
" has not correct size "
344 if (haveExtraC) out <<
SP <<
"assert(" << lengthExtra_Y <<
" == " << lengthExtra_C <<
");\n";
354 std::cout <<
"WARNING: TMVA SOFIE Gemm Op " + opName +
" Bias tensor is not present but beta value in Gemm is not zero - force it to zero\n";
360 bool doStackMul = dimY > 2 && (
fIsDynamic || std::stoi(lengthExtra_Y) > 1);
362 std::string lengthExtra_A;
363 std::string lengthExtra_B;
364 std::string increment_A;
365 std::string increment_B;
378 bool extraA = (doStackMul && lengthExtra_A !=
"1");
379 bool extraB = (doStackMul && lengthExtra_B !=
"1");
382 std::string biasShapeType = opName +
"_biasShapeType";
388 out <<
SP <<
"int " << biasShapeType <<
" = 0;\n";
390 if (sC[0].GetVal() !=
"1" && sC[1].GetVal() != sY[1].GetVal())
391 out <<
SP <<
"if (" << sC[0] <<
" == 1 && " << sC[1] <<
" == " << sY[1] <<
")\n";
392 else if (sC[0].GetVal() ==
"1")
393 out <<
SP <<
"if (" << sC[1] <<
" == " << sY[1] <<
")\n";
394 else if (sC[1].GetVal() == sY[1].GetVal())
395 out <<
SP <<
"if (" << sC[0] <<
" == 1)\n";
397 out <<
SP <<
SP << biasShapeType <<
" = 1;\n";
400 if (sC[1].GetVal() !=
"1" && sC[0].GetVal() != sY[0].GetVal())
401 out <<
SP <<
"else if (" << sC[1] <<
" == 1 && " << sC[0] <<
" == " << sY[0] <<
")\n";
402 else if (sC[1].GetVal() ==
"1")
403 out <<
SP <<
"else if (" << sC[0] <<
" == " << sY[0] <<
")\n";
404 else if (sC[0].GetVal() == sY[0].GetVal())
405 out <<
SP <<
"else if (" << sC[1] <<
" == 1)\n";
407 out <<
SP <<
SP << biasShapeType <<
" = 2;\n";
410 if (sC[0].GetVal() !=
"1" && sC[1].GetVal() !=
"1")
411 out <<
SP <<
"else if (" << sC[0] <<
" == 1 && " << sC[1] <<
" == 1 )\n";
412 else if (sC[0].GetVal() ==
"1")
413 out <<
SP <<
"else if (" << sC[1] <<
" == 1)\n";
414 else if (sC[1].GetVal() ==
"1")
415 out <<
SP <<
"else if (" << sC[0] <<
" == 1)\n";
416 out <<
SP <<
SP << biasShapeType <<
" = 3;\n";
417 out <<
SP <<
"else\n";
418 out <<
SP <<
SP <<
"throw std::runtime_error(\"TMVA SOFIE Gemm Op - bias tensor "
424 out <<
SP <<
"size_t " << opName <<
"_y_offset = 0;\n";
426 out <<
SP <<
"size_t " << opName <<
"_A_offset = 0;\n";
428 out <<
SP <<
"size_t " << opName <<
"_B_offset = 0;\n";
430 out <<
SP <<
"size_t " << opName <<
"_C_offset = 0;\n";
431 out <<
SP <<
"for (size_t i = 0; i < " << lengthExtra_Y <<
"; i++){\n";
441 out << SP2 <<
"for (size_t j = 0; j < " << sY[0] <<
"; j++) { \n";
442 out << SP2 <<
SP <<
"size_t y_index = ";
444 out << opName <<
"_y_offset + ";
445 if (sY[1].GetVal() !=
"1")
446 out << sY[1] <<
" * j;\n";
450 std::string prefix = SP2 +
SP +
"TMVA::Experimental::SOFIE::";
452 if (sC.size() != 2) {
454 }
if (sC[0].GetVal() ==
"1" && sC[1].GetVal() == sY[1].GetVal()) {
455 out << prefix <<
"Copy(" <<
target <<
" + y_index, tensor_" <<
fNC <<
", " << sY[1] <<
");\n";
456 }
else if (sC[1].GetVal() ==
"1" && sC[0].GetVal() == sY[0].GetVal()) {
457 out << prefix <<
"Fill(" <<
target <<
" + y_index, tensor_" <<
fNC <<
"[j], " << sY[1] <<
");\n";
458 }
else if (sC[0].GetVal() ==
"1" && sC[1].GetVal() ==
"1") {
460 out << prefix <<
"Fill(" <<
target <<
" + y_index, tensor_" <<
fNC <<
"[0], " << sY[1] <<
");\n";
465 out << SP2 <<
SP <<
"if (" << biasShapeType <<
" == 1)\n";
466 out <<
SP << prefix <<
"Copy(" <<
target <<
" + y_index, tensor_" <<
fNC <<
", " << sY[1] <<
");\n";
467 out << SP2 <<
SP <<
"else if (" << biasShapeType <<
" == 2)\n";
468 out <<
SP << prefix <<
"Fill(" <<
target <<
" + y_index, tensor_" <<
fNC <<
"[j], " << sY[1] <<
");\n";
469 out << SP2 <<
SP <<
"else \n";
470 out <<
SP << prefix <<
"Fill(" <<
target <<
" + y_index, tensor_" <<
fNC <<
"[0], " << sY[1] <<
");\n";
478 if (
fType ==
"float"){
480 out << SP2 <<
"TMVA::Experimental::SOFIE::Gemm_Call(" <<
"tensor_" <<
fNY;
481 if (doStackMul) out <<
" + " << opName <<
"_y_offset";
485 <<
n <<
", " <<
m <<
", " << k <<
", ";
486 out << std::setprecision(std::numeric_limits<float>::max_digits10) <<
fAttrAlpha <<
", tensor_" <<
fNB;
487 if (extraB) out <<
" + " << opName <<
"_B_offset";
488 out <<
", tensor_" <<
fNA;
489 if (extraA) out <<
" + " << opName <<
"_A_offset";
490 out <<
", " << std::setprecision(std::numeric_limits<float>::max_digits10) <<
fAttrBeta <<
",";
493 out <<
"tensor_" <<
fNC;
495 out <<
" + " << opName <<
"_C_offset";
505 out <<
SP <<
SP << opName <<
"_y_offset += " << lengthGemm <<
";\n";
506 if (lengthExtra_A !=
"1")
507 out <<
SP <<
SP << opName <<
"_A_offset += " << increment_A <<
";\n";
508 if (lengthExtra_B !=
"1")
509 out <<
SP <<
SP << opName <<
"_B_offset += " << increment_B <<
";\n";
512 out <<
SP <<
SP << opName <<
"_C_offset += " << lengthGemm <<
";\n";
518 out <<
SP <<
"//--- applying RELU to output\n";
519 std::string tnsr =
"tensor_" +
fNY;
521 out <<
SP <<
"TMVA::Experimental::SOFIE::Relu(" << tnsr <<
", " << tnsr <<
", " << reluSize <<
");\n";
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t target
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
void AddNeededStdLib(std::string libname)
std::vector< size_t > GetTensorShape(const std::string &name) const
bool IsDynamicTensor(const std::string &name) const
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
bool CheckIfTensorAlreadyExist(std::string tensor_name)
void AddDynamicTensor(std::string tensor_name, ETensorType type, std::vector< Dim > shape)
bool IsDimInputTensor(const std::string &name) const
std::vector< Dim > GetDynamicTensorShape(const std::string &name) const
ETensorType GetTensorType(std::string name) const
ROperator_Gemm(float alpha, float beta, int_t transA, int_t transB, std::string nameA, std::string nameB, std::string nameC, std::string nameY, EActivationType activation=EActivationType::UNDEFINED)
std::vector< Dim > DynamicShapeInference(const std::vector< std::vector< Dim > > &input)
std::vector< Dim > fShapeY
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
ROperator_Gemm(float alpha, float beta, int_t transA, int_t transB, std::string nameA, std::string nameB, std::string nameY, EActivationType activation=EActivationType::UNDEFINED)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
std::vector< U > DoShapeInference(const std::vector< std::vector< U > > &input)
std::vector< Dim > fShapeA
std::vector< Dim > fShapeB
std::vector< size_t > fShapeC
std::string Generate(std::string opName) override
void Initialize(RModel &model) override
std::vector< Dim > fDimShapeC
bool fCheckBiasShapeAtRuntime
EActivationType fActivation
std::vector< std::string > GetBlasRoutines() override
std::vector< std::string_view > fInputTensorNames
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
std::vector< size_t > MultidirectionalBroadcastShape(std::vector< std::vector< size_t > >)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::vector< Dim > ConvertShapeToDim(const std::vector< size_t > &shape)
Convert shape from integer format to dynamic one (based on Dim)
std::vector< size_t > ConvertShapeToInt(const std::vector< Dim > &shape)
Convert shape based on Dim to integer format.
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations
std::string GetVal() const