Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RSofieReader.hxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA * *
4 * *
5 * Description: *
6 * *
7 * Authors: *
8 * Lorenzo Moneta *
9 * *
10 * Copyright (c) 2022: *
11 * CERN, Switzerland *
12 * *
13 **********************************************************************************/
14
15
16#ifndef TMVA_RSOFIEREADER
17#define TMVA_RSOFIEREADER
18
19
#include <algorithm>  // std::remove_if
#include <cctype>     // std::isalnum
#include <cstddef>    // std::size_t
#include <functional> // std::function
#include <iostream>
#include <memory>     // std::unique_ptr
#include <sstream>    // std::stringstream
#include <stdexcept>  // std::runtime_error
#include <string>
#include <vector>

#include "TROOT.h"
#include "TSystem.h"
#include "TError.h"
#include "TInterpreter.h"
#include "TUUID.h"
#include "TMVA/RTensor.hxx"
#include "Math/Util.h"
32
33namespace TMVA {
34namespace Experimental {
35
36
37
38
39/// TMVA::RSofieReader class for reading external Machine Learning models
40/// in ONNX files, Keras .h5 files or PyTorch .pt files
41/// and performing the inference using SOFIE
42/// It is recommended to use ONNX if possible since there is a larger support for
43/// model operators.
44
46
47
48public:
49 /// Create TMVA model from ONNX file
50 /// print level can be 0 (minimal) 1 with info , 2 with all ONNX parsing info
51 RSofieReader(const std::string &path, std::vector<std::vector<size_t>> inputShapes = {}, int verbose = 0)
52 {
53
54 enum EModelType {kONNX, kKeras, kPt, kROOT, kNotDef}; // type of model
55 EModelType type = kNotDef;
56
57 auto pos1 = path.rfind("/");
58 auto pos2 = path.find(".onnx");
59 if (pos2 != std::string::npos) {
60 type = kONNX;
61 } else {
62 pos2 = path.find(".h5");
63 if (pos2 != std::string::npos) {
64 type = kKeras;
65 } else {
66 pos2 = path.find(".pt");
67 if (pos2 != std::string::npos) {
68 type = kPt;
69 }
70 else {
71 pos2 = path.find(".root");
72 if (pos2 != std::string::npos) {
73 type = kROOT;
74 }
75 }
76 }
77 }
78 if (type == kNotDef) {
79 throw std::runtime_error("Input file is not an ONNX or Keras or PyTorch file");
80 }
81 if (pos1 == std::string::npos)
82 pos1 = 0;
83 else
84 pos1 += 1;
85 std::string modelName = path.substr(pos1,pos2-pos1);
86 std::string fileType = path.substr(pos2+1, path.length()-pos2-1);
87 if (verbose) std::cout << "Parsing SOFIE model " << modelName << " of type " << fileType << std::endl;
88
89 // create code for parsing model and generate C++ code for inference
90 // make it in a separate scope to avoid polluting global interpreter space
91 std::string parserCode;
92 if (type == kONNX) {
93 // check first if we can load the SOFIE parser library
94 if (gSystem->Load("libROOTTMVASofieParser") < 0) {
95 throw std::runtime_error("RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
96 }
97 gInterpreter->Declare("#include \"TMVA/RModelParser_ONNX.hxx\"");
98 parserCode += "{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
99 if (verbose == 2)
100 parserCode += "TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + "\",true); \n";
101 else
102 parserCode += "TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + "\"); \n";
103 }
104 else if (type == kKeras) {
105 // use Keras direct parser
106 if (gSystem->Load("libPyMVA") < 0) {
107 throw std::runtime_error("RSofieReader: cannot use SOFIE with Keras since libPyMVA is missing");
108 }
109 parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path + "\"); \n";
110 }
111 else if (type == kPt) {
112 // use PyTorch direct parser
113 if (gSystem->Load("libPyMVA") < 0) {
114 throw std::runtime_error("RSofieReader: cannot use SOFIE with PyTorch since libPyMVA is missing");
115 }
116 if (inputShapes.size() == 0) {
117 throw std::runtime_error("RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
118 }
119 std::string inputShapesStr = "{";
120 for (unsigned int i = 0; i < inputShapes.size(); i++) {
121 inputShapesStr += "{ ";
122 for (unsigned int j = 0; j < inputShapes[i].size(); j++) {
123 inputShapesStr += ROOT::Math::Util::ToString(inputShapes[i][j]);
124 if (j < inputShapes[i].size()-1) inputShapesStr += ", ";
125 }
126 inputShapesStr += "}";
127 if (i < inputShapes.size()-1) inputShapesStr += ", ";
128 }
129 inputShapesStr += "}";
130 parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path + "\", "
131 + inputShapesStr + "); \n";
132 }
133 else if (type == kROOT) {
134 // use parser from ROOT
135 parserCode += "{\nauto fileRead = TFile::Open(\"" + path + "\",\"READ\");\n";
136 parserCode += "TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
137 parserCode += "auto keyList = fileRead->GetListOfKeys(); TString name;\n";
138 parserCode += "for (const auto&& k : *keyList) { \n";
139 parserCode += " TString cname = ((TKey*)k)->GetClassName(); if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
140 parserCode += "fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
141 parserCode += "TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
142 }
143
144 int batchSize = 1;
145 if (inputShapes.size() > 0 && inputShapes[0].size() > 0) {
146 batchSize = inputShapes[0][0];
147 if (batchSize < 1) batchSize = 1;
148 }
149 if (verbose) std::cout << "generating the code with batch size = " << batchSize << " ...\n";
150 parserCode += "model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
151 + ROOT::Math::Util::ToString(batchSize) + "); \n";
152 if (verbose > 1)
153 parserCode += "model.PrintGenerated(); \n";
154 parserCode += "model.OutputGenerated();\n";
155
156 //end of parsing code, close the scope and return 1 to indicate a success
157 parserCode += "return 1;\n }\n";
158
159 if (verbose) std::cout << "//ParserCode being executed:\n" << parserCode << std::endl;
160
161 auto iret = gROOT->ProcessLine(parserCode.c_str());
162 if (iret != 1) {
163 std::string msg = "RSofieReader: error processing the parser code: \n" + parserCode;
164 throw std::runtime_error(msg);
165 }
166
167 // compile now the generated code and create Session class
168 std::string modelHeader = modelName + ".hxx";
169 if (verbose) std::cout << "compile generated code from file " <<modelHeader << std::endl;
170 if (gSystem->AccessPathName(modelHeader.c_str())) {
171 std::string msg = "RSofieReader: input header file " + modelHeader + " is not existing";
172 throw std::runtime_error(msg);
173 }
174 if (verbose) std::cout << "Creating Inference function for model " << modelName << std::endl;
175 std::string declCode;
176 declCode += "#pragma cling optimize(2)\n";
177 declCode += "#include \"" + modelHeader + "\"\n";
178 // create global session instance: use UUID to have an unique name
179 std::string sessionClassName = "TMVA_SOFIE_" + modelName + "::Session";
180 TUUID uuid;
181 std::string uidName = uuid.AsString();
182 uidName.erase(std::remove_if(uidName.begin(), uidName.end(),
183 []( char const& c ) -> bool { return !std::isalnum(c); } ), uidName.end());
184
185 std::string sessionName = "session_" + uidName;
186 declCode += sessionClassName + " " + sessionName + ";";
187
188 if (verbose) std::cout << "//global session declaration\n" << declCode << std::endl;
189
190 bool ret = gInterpreter->Declare(declCode.c_str());
191 if (!ret) {
192 std::string msg = "RSofieReader: error compiling inference code and creating session class\n" + declCode;
193 throw std::runtime_error(msg);
194 }
195
196 fSessionPtr = (void*) gInterpreter->Calc(sessionName.c_str());
197
198 // define a function to be called for inference
199 std::stringstream ifuncCode;
200 std::string funcName = "SofieInference_" + uidName;
201 ifuncCode << "std::vector<float> " + funcName + "( void * ptr, float * data) {\n";
202 ifuncCode << " " << sessionClassName << " * s = " << "(" << sessionClassName << "*) (ptr);\n";
203 ifuncCode << " return s->infer(data);\n";
204 ifuncCode << "}\n";
205
206 if (verbose) std::cout << "//Inference function code using global session instance\n"
207 << ifuncCode.str() << std::endl;
208
209 ret = gInterpreter->Declare(ifuncCode.str().c_str());
210 if (!ret) {
211 std::string msg = "RSofieReader: error compiling inference function\n" + ifuncCode.str();
212 throw std::runtime_error(msg);
213 }
214 auto fptr = gInterpreter->Calc(funcName.c_str());
215 fFuncPtr = reinterpret_cast<std::vector<float> (*)(void *, const float *)>(fptr);
216 fInitialized = true;
217 }
218
219 /// Compute model prediction on vector
220 std::vector<float> Compute(const std::vector<float> &x)
221 {
222 if(!fInitialized) {
223 return std::vector<float>();
224 }
225
226 // Take lock to protect model evaluation
228
229 // Evaluate TMVA model (need to add support for multiple outputs)
230 auto result = fFuncPtr(fSessionPtr, x.data());
231 return result;
232
233 }
234 /// Compute model prediction on input RTensor
235 /// The shape of the input tensor should be {nevents, nfeatures}
236 /// and the return shape will be {nevents, noutputs}
238 {
239 if(!fInitialized) {
240 return RTensor<float>({0});
241 }
242 const auto nrows = x.GetShape()[0];
243 const auto rowsize = x.GetStrides()[0];
244 auto result = fFuncPtr(fSessionPtr, x.GetData());
245
246 RTensor<float> y({nrows, result.size()}, MemoryLayout::ColumnMajor);
247 std::copy(result.begin(),result.end(), y.GetData());
248 //const bool layout = x.GetMemoryLayout() == MemoryLayout::ColumnMajor ? false : true;
249 // assume column major layout
250 for (size_t i = 1; i < nrows; i++) {
251 result = fFuncPtr(fSessionPtr, x.GetData() + i*rowsize);
252 std::copy(result.begin(),result.end(), y.GetData() + i*result.size());
253 }
254 return y;
255 }
256
private:

   bool fInitialized = false;   // set to true only after the constructor parsed, generated and compiled the model
   void * fSessionPtr = nullptr;   // opaque pointer to the interpreter-created global Session instance
   std::function<std::vector<float> (void *, const float *)> fFuncPtr;   // JIT-compiled inference entry point (session ptr, input data) -> outputs

};
264
265} // namespace Experimental
266} // namespace TMVA
267
268#endif // TMVA_RSOFIEREADER
#define c(i)
Definition RSha256.hxx:101
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
#define gInterpreter
#define gROOT
Definition TROOT.h:405
R__EXTERN TSystem * gSystem
Definition TSystem.h:560
#define R__WRITE_LOCKGUARD(mutex)
TMVA::RSofieReader class for reading external Machine Learning models in ONNX files,...
RSofieReader(const std::string &path, std::vector< std::vector< size_t > > inputShapes={}, int verbose=0)
Create TMVA model from ONNX file print level can be 0 (minimal) 1 with info , 2 with all ONNX parsing...
RTensor< float > Compute(RTensor< float > &x)
Compute model prediction on input RTensor The shape of the input tensor should be {nevents,...
std::vector< float > Compute(const std::vector< float > &x)
Compute model prediction on vector.
std::function< std::vector< float >(void *, const float *)> fFuncPtr
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
const Shape_t & GetShape() const
Definition RTensor.hxx:242
virtual int Load(const char *module, const char *entry="", Bool_t system=kFALSE)
Load a shared library.
Definition TSystem.cxx:1858
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1299
This class defines a UUID (Universally Unique IDentifier), also known as GUIDs (Globally Unique IDent...
Definition TUUID.h:42
const char * AsString() const
Return UUID as string. Copy string immediately since it will be reused.
Definition TUUID.cxx:571
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
std::string ToString(const T &val)
Utility function for conversion to strings.
Definition Util.h:50
R__EXTERN TVirtualRWMutex * gCoreMutex
create variable transformations