16#ifndef TMVA_RSOFIEREADER
17#define TMVA_RSOFIEREADER
34namespace Experimental {
51 RSofieReader(
const std::string &path, std::vector<std::vector<size_t>> inputShape = {},
int verbose = 0)
54 enum EModelType {kONNX, kKeras, kPt, kROOT, kNotDef};
55 EModelType
type = kNotDef;
57 auto pos1 = path.rfind(
"/");
58 auto pos2 = path.find(
".onnx");
59 if (pos2 != std::string::npos) {
62 pos2 = path.find(
".h5");
63 if (pos2 != std::string::npos) {
66 pos2 = path.find(
".pt");
67 if (pos2 != std::string::npos) {
71 pos2 = path.find(
".root");
72 if (pos2 != std::string::npos) {
78 if (
type == kNotDef) {
79 throw std::runtime_error(
"Input file is not an ONNX or Keras or PyTorch file");
81 if (pos1 == std::string::npos)
85 std::string modelName = path.substr(pos1,pos2-pos1);
86 std::string fileType = path.substr(pos2+1, path.length()-pos2-1);
87 if (verbose) std::cout <<
"Parsing SOFIE model " << modelName <<
" of type " << fileType << std::endl;
91 std::string parserCode;
95 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
97 gInterpreter->Declare(
"#include \"TMVA/RModelParser_ONNX.hxx\"");
98 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
100 parserCode +=
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path +
"\",true); \n";
102 parserCode +=
"TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path +
"\"); \n";
104 else if (
type == kKeras) {
107 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with Keras since libPyMVA is missing");
109 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path +
"\"); \n";
111 else if (
type == kPt) {
114 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since libPyMVA is missing");
116 if (inputShape.size() == 0) {
117 throw std::runtime_error(
"RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
119 std::string inputShapeStr =
"{";
120 for (
unsigned int i = 0; i < inputShape.size(); i++) {
121 inputShapeStr +=
"{ ";
122 for (
unsigned int j = 0; j < inputShape[i].size(); j++) {
124 if (j < inputShape[i].
size()-1) inputShapeStr +=
", ";
126 inputShapeStr +=
"}";
127 if (i < inputShape.size()-1) inputShapeStr +=
", ";
129 inputShapeStr +=
"}";
130 parserCode +=
"{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path +
"\", "
131 + inputShapeStr +
"); \n";
133 else if (
type == kROOT) {
135 parserCode +=
"{\nauto fileRead = TFile::Open(\"" + path +
"\",\"READ\");\n";
136 parserCode +=
"TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
137 parserCode +=
"auto keyList = fileRead->GetListOfKeys(); TString name;\n";
138 parserCode +=
"for (const auto&& k : *keyList) { \n";
139 parserCode +=
" TString cname = ((TKey*)k)->GetClassName(); if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
140 parserCode +=
"fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
141 parserCode +=
"TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
145 if (inputShape.size() > 0 && inputShape[0].size() > 0) {
146 batchSize = inputShape[0][0];
147 if (batchSize < 1) batchSize = 1;
149 if (verbose) std::cout <<
"generating the code with batch size = " << batchSize <<
" ...\n";
150 parserCode +=
"model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
153 parserCode +=
"model.PrintGenerated(); \n";
154 parserCode +=
"model.OutputGenerated();\n";
157 parserCode +=
"return 1;\n }\n";
159 if (verbose) std::cout <<
"//ParserCode being executed:\n" << parserCode << std::endl;
161 auto iret =
gROOT->ProcessLine(parserCode.c_str());
163 std::string msg =
"RSofieReader: error processing the parser code: \n" + parserCode;
164 throw std::runtime_error(msg);
168 std::string modelHeader = modelName +
".hxx";
169 if (verbose) std::cout <<
"compile generated code from file " <<modelHeader << std::endl;
171 std::string msg =
"RSofieReader: input header file " + modelHeader +
" is not existing";
172 throw std::runtime_error(msg);
174 if (verbose) std::cout <<
"Creating Inference function for model " << modelName << std::endl;
175 std::string declCode;
176 declCode +=
"#pragma cling optimize(2)\n";
177 declCode +=
"#include \"" + modelHeader +
"\"\n";
179 std::string sessionClassName =
"TMVA_SOFIE_" + modelName +
"::Session";
181 std::string uidName = uuid.
AsString();
182 uidName.erase(std::remove_if(uidName.begin(), uidName.end(),
183 [](
char const&
c ) ->
bool { return !std::isalnum(c); } ), uidName.end());
185 std::string sessionName =
"session_" + uidName;
186 declCode += sessionClassName +
" " + sessionName +
";";
188 if (verbose) std::cout <<
"//global session declaration\n" << declCode << std::endl;
192 std::string msg =
"RSofieReader: error compiling inference code and creating session class\n" + declCode;
193 throw std::runtime_error(msg);
199 std::stringstream ifuncCode;
200 std::string funcName =
"SofieInference_" + uidName;
201 ifuncCode <<
"std::vector<float> " + funcName +
"( void * ptr, float * data) {\n";
202 ifuncCode <<
" " << sessionClassName <<
" * s = " <<
"(" << sessionClassName <<
"*) (ptr);\n";
203 ifuncCode <<
" return s->infer(data);\n";
206 if (verbose) std::cout <<
"//Inference function code using global session instance\n"
207 << ifuncCode.str() << std::endl;
211 std::string msg =
"RSofieReader: error compiling inference function\n" + ifuncCode.str();
212 throw std::runtime_error(msg);
215 fFuncPtr =
reinterpret_cast<std::vector<float> (*)(
void *,
const float *)
>(fptr);
220 std::vector<float>
Compute(
const std::vector<float> &
x)
223 return std::vector<float>();
242 const auto nrows =
x.GetShape()[0];
243 const auto rowsize =
x.GetStrides()[0];
250 for (
size_t i = 1; i < nrows; i++) {
261 std::function<std::vector<float> (
void *,
const float *)>
fFuncPtr;
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
R__EXTERN TSystem * gSystem
#define R__WRITE_LOCKGUARD(mutex)
TMVA::RSofieReader class for reading external Machine Learning models in ONNX files,...
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor The shape of the input tensor should be {nevents,...
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
std::function< std::vector< float >(void *, const float *)> fFuncPtr
RSofieReader(const std::string &path, std::vector< std::vector< size_t > > inputShape={}, int verbose=0)
Create TMVA model from ONNX file print level can be 0 (minimal) 1 with info , 2 with all ONNX parsing...
RTensor is a container with contiguous memory and shape information.
virtual int Load(const char *module, const char *entry="", Bool_t system=kFALSE)
Load a shared library.
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
This class defines a UUID (Universally Unique IDentifier), also known as GUIDs (Globally Unique IDent...
const char * AsString() const
Return UUID as string. Copy string immediately since it will be reused.
std::string ToString(const T &val)
Utility function for conversion to strings.
R__EXTERN TVirtualRWMutex * gCoreMutex
create variable transformations