Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_PyTorch.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Author: Sanjiban Sengupta 2021
3
4/**********************************************************************************
5 * Project : TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package : TMVA *
7 * Function: TMVA::Experimental::SOFIE::PyTorch::Parse *
8 * *
9 * Description: *
10 * Parser function for translating PyTorch .pt model to RModel object *
11 * *
12 * Example Usage: *
13 * ~~~ {.cpp} *
14 * using namespace TMVA::Experimental::SOFIE; *
15 * // Building the vector of input tensor shapes *
16 * std::vector<size_t> s1{120,1}; *
17 * std::vector<std::vector<size_t>> inputShape{s1}; *
18 * RModel model = PyTorch::Parse("trained_model_dense.pt",inputShape); *
19 * ~~~ *
20 * *
21 **********************************************************************************/
22
23
25
26#include <Python.h>
27
28#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
29#include <numpy/arrayobject.h>
30
32
33namespace {
34
35// Utility functions (taken from PyMethodBase in PyMVA)
36
37void PyRunString(TString code, PyObject *globalNS, PyObject *localNS)
38{
40 if (!fPyReturn) {
41 std::cout << "\nPython error message:\n";
43 throw std::runtime_error("\nFailed to run python code: " + code);
44 }
45}
46
47const char *PyStringAsString(PyObject *string)
48{
51 return cstring;
52}
53
54std::vector<size_t> GetDataFromList(PyObject *listObject)
55{
56 std::vector<size_t> listVec;
59 }
60 return listVec;
61}
62
63} // namespace
64
65
66namespace INTERNAL{
67
68// For searching and calling specific preparatory function for PyTorch ONNX Graph's node
69std::unique_ptr<ROperator> MakePyTorchNode(PyObject* fNode);
70
71std::unique_ptr<ROperator> MakePyTorchGemm(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Gemm operator
72std::unique_ptr<ROperator> MakePyTorchConv(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Conv operator
73std::unique_ptr<ROperator> MakePyTorchRelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Relu operator
74std::unique_ptr<ROperator> MakePyTorchSelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Selu operator
75std::unique_ptr<ROperator> MakePyTorchSigmoid(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Sigmoid operator
76std::unique_ptr<ROperator> MakePyTorchTranspose(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Transpose operator
77
78// For mapping PyTorch ONNX Graph's Node with the preparatory functions for ROperators
79using PyTorchMethodMap = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(PyObject* fNode)>;
80
82{
83 {"onnx::Gemm", &MakePyTorchGemm},
84 {"onnx::Conv", &MakePyTorchConv},
85 {"onnx::Relu", &MakePyTorchRelu},
86 {"onnx::Selu", &MakePyTorchSelu},
87 {"onnx::Sigmoid", &MakePyTorchSigmoid},
88 {"onnx::Transpose", &MakePyTorchTranspose}
89};
90
91
92//////////////////////////////////////////////////////////////////////////////////
93/// \brief Prepares equivalent ROperator with respect to PyTorch ONNX node.
94///
95/// \param[in] fNode Python PyTorch ONNX Graph node
96/// \return unique pointer to ROperator object
97///
98/// Function searches for the passed PyTorch ONNX Graph node in the map, and calls
99/// the specific preparatory function, subsequently returning the ROperator object.
100///
101/// For developing new preparatory functions for supporting PyTorch ONNX Graph nodes
102/// in future, all one needs is to extract the required properties and attributes
103/// from the fNode dictionary which contains all the information about any PyTorch ONNX
104// Graph node and after any required transformations, these are passed for instantiating
105/// the ROperator object.
106///
107/// The fNode dictionary which holds all the information about a PyTorch ONNX Graph's node has
108/// following structure:-
109///
110/// dict fNode { 'nodeType' : Type of node (operator)
111/// 'nodeAttributes' : Attributes of the node
112/// 'nodeInputs' : List of names of input tensors
113/// 'nodeOutputs' : List of names of output tensors
114/// 'nodeDType' : Data-type of the operator node
115/// }
116///
117std::unique_ptr<ROperator> MakePyTorchNode(PyObject* fNode){
118 std::string fNodeType = PyStringAsString(PyDict_GetItemString(fNode,"nodeType"));
119 auto findNode = mapPyTorchNode.find(fNodeType);
120 if(findNode == mapPyTorchNode.end()){
121 throw std::runtime_error("TMVA::SOFIE - Parsing PyTorch node " +fNodeType+" is not yet supported ");
122 }
123 return (findNode->second)(fNode);
124}
125
126
127//////////////////////////////////////////////////////////////////////////////////
128/// \brief Prepares a ROperator_Gemm object
129///
130/// \param[in] fNode Python PyTorch ONNX Graph node
131/// \return Unique pointer to ROperator object
132///
133/// For PyTorch's Linear layer having Gemm operation in its ONNX graph,
134/// the names of the input tensor, output tensor are extracted, and then
135/// are passed to instantiate a ROperator_Gemm object using the required attributes.
136/// fInputs is a list of tensor names, which includes the names of the input tensor
137/// and the weight tensors.
138std::unique_ptr<ROperator> MakePyTorchGemm(PyObject* fNode){
139 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
140 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
141 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
142 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
143
144 // Extracting the parameters for Gemm Operator
145 std::string fNameA = PyStringAsString(PyList_GetItem(fInputs,0));
146 std::string fNameB = PyStringAsString(PyList_GetItem(fInputs,1));
147 std::string fNameC = PyStringAsString(PyList_GetItem(fInputs,2));
148 std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
149 float fAttrAlpha = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"alpha")));
150 float fAttrBeta = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"beta")));
151 int_t fAttrTransA;
152 int_t fAttrTransB;
153
154 if(PyDict_Contains(fAttributes,PyUnicode_FromString("transB"))){
155 fAttrTransB = (int_t)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"transB")));
156 fAttrTransA = !fAttrTransB;
157 }
158 else{
159 fAttrTransA=(int_t)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"transA")));
160 fAttrTransB = !fAttrTransA;
161 }
162
163 std::unique_ptr<ROperator> op;
165 case ETensorType::FLOAT: {
166 op.reset(new ROperator_Gemm<float>(fAttrAlpha, fAttrBeta, fAttrTransA, fAttrTransB, fNameA, fNameB, fNameC, fNameY ));
167 break;
168 }
169 default:
170 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Gemm does not yet support input type " + fNodeDType);
171 }
172 return op;
173}
174
175//////////////////////////////////////////////////////////////////////////////////
176/// \brief Prepares a ROperator_Relu object
177///
178/// \param[in] fNode Python PyTorch ONNX Graph node
179/// \return Unique pointer to ROperator object
180///
181/// For instantiating a ROperator_Relu object, the names of
182/// input & output tensors and the data-type of the Graph node
183/// are extracted.
184std::unique_ptr<ROperator> MakePyTorchRelu(PyObject* fNode){
185 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
186 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
187 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
188 std::string fNameX = PyStringAsString(PyList_GetItem(fInputs,0));
189 std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
190 std::unique_ptr<ROperator> op;
192 case ETensorType::FLOAT: {
194 break;
195 }
196 default:
197 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Relu does not yet support input type " + fNodeDType);
198 }
199 return op;
200}
201
202//////////////////////////////////////////////////////////////////////////////////
203/// \brief Prepares a ROperator_Selu object
204///
205/// \param[in] fNode Python PyTorch ONNX Graph node
206/// \return Unique pointer to ROperator object
207///
208/// For instantiating a ROperator_Selu object, the names of
209/// input & output tensors and the data-type of the Graph node
210/// are extracted.
211std::unique_ptr<ROperator> MakePyTorchSelu(PyObject* fNode){
212 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
213 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
214 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
215
216 std::unique_ptr<ROperator> op;
218 case ETensorType::FLOAT: {
219 op.reset(new ROperator_Selu<float>(PyStringAsString(PyList_GetItem(fInputs,0)), PyStringAsString(PyList_GetItem(fOutputs,0))));
220 break;
221 }
222 default:
223 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Selu does not yet support input type " + fNodeDType);
224 }
225 return op;
226}
227
228//////////////////////////////////////////////////////////////////////////////////
229/// \brief Prepares a ROperator_Sigmoid object
230///
231/// \param[in] fNode Python PyTorch ONNX Graph node
232/// \return Unique pointer to ROperator object
233///
234/// For instantiating a ROperator_Sigmoid object, the names of
235/// input & output tensors and the data-type of the Graph node
236/// are extracted.
237std::unique_ptr<ROperator> MakePyTorchSigmoid(PyObject* fNode){
238 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
239 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
240 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
241
242 std::unique_ptr<ROperator> op;
244 case ETensorType::FLOAT: {
245 op.reset(new ROperator_Sigmoid<float>(PyStringAsString(PyList_GetItem(fInputs,0)), PyStringAsString(PyList_GetItem(fOutputs,0))));
246 break;
247 }
248 default:
249 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Sigmoid does not yet support input type " + fNodeDType);
250 }
251 return op;
252}
253
254
255//////////////////////////////////////////////////////////////////////////////////
256/// \brief Prepares a ROperator_Transpose object
257///
258/// \param[in] fNode Python PyTorch ONNX Graph node
259/// \return Unique pointer to ROperator object
260///
261/// For Transpose Operator of PyTorch's ONNX Graph, the permute dimensions are found,
262/// and are passed in instantiating the ROperator object.
263std::unique_ptr<ROperator> MakePyTorchTranspose(PyObject* fNode){
264 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
265 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
266 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
267 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
268
269 // Extracting the Permute dimensions for transpose
270 std::vector<int_t> fAttrPermute;
271 PyObject* fPermute=PyDict_GetItemString(fAttributes,"perm");
274 }
275 std::string fNameData = PyStringAsString(PyList_GetItem(fInputs,0));
276 std::string fNameOutput = PyStringAsString(PyList_GetItem(fOutputs,0));
277
278 std::unique_ptr<ROperator> op = std::make_unique<ROperator_Transpose>(fAttrPermute, fNameData, fNameOutput);
279 return op;
280}
281
282
283//////////////////////////////////////////////////////////////////////////////////
284/// \brief Prepares a ROperator_Conv object
285///
286/// \param[in] fNode Python PyTorch ONNX Graph node
287/// \return Unique pointer to ROperator object
288///
289/// For Conv Operator of PyTorch's ONNX Graph, attributes like dilations, group,
290/// kernel shape, pads and strides are found, and are passed in instantiating the
291/// ROperator object with autopad default to `NOTSET`.
292std::unique_ptr<ROperator> MakePyTorchConv(PyObject* fNode){
293 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
294 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
295 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
296 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
297
298 // Extracting the Conv Node Attributes
299 PyObject* fDilations = PyDict_GetItemString(fAttributes,"dilations");
300 PyObject* fGroup = PyDict_GetItemString(fAttributes,"group");
301 PyObject* fKernelShape = PyDict_GetItemString(fAttributes,"kernel_shape");
302 PyObject* fPads = PyDict_GetItemString(fAttributes,"pads");
303 PyObject* fStrides = PyDict_GetItemString(fAttributes,"strides");
304
305 std::string fAttrAutopad = "NOTSET";
306 std::vector<size_t> fAttrDilations = GetDataFromList(fDilations);
307 size_t fAttrGroup = PyLong_AsLong(fGroup);
308 std::vector<size_t> fAttrKernelShape = GetDataFromList(fKernelShape);
309 std::vector<size_t> fAttrPads = GetDataFromList(fPads);
310 std::vector<size_t> fAttrStrides = GetDataFromList(fStrides);
311 std::string nameX = PyStringAsString(PyList_GetItem(fInputs,0));
312 std::string nameW = PyStringAsString(PyList_GetItem(fInputs,1));
313 std::string nameB = PyStringAsString(PyList_GetItem(fInputs,2));
314 std::string nameY = PyStringAsString(PyList_GetItem(fOutputs,0));
315
316 std::unique_ptr<ROperator> op;
318 case ETensorType::FLOAT: {
319 op.reset(new ROperator_Conv<float>(fAttrAutopad, fAttrDilations, fAttrGroup, fAttrKernelShape, fAttrPads, fAttrStrides, nameX, nameW, nameB, nameY));
320 break;
321 }
322 default:
323 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Conv does not yet support input type " + fNodeDType);
324 }
325 return op;
326}
327}//INTERNAL
328
329
330//////////////////////////////////////////////////////////////////////////////////
331/// \param[in] filename file location of PyTorch .pt model
332/// \param[in] inputShapes vector of input shape vectors
333/// \param[in] inputDTypes vector of ETensorType for data-types of Input tensors
334/// \return Parsed RModel object
335///
336/// The `Parse()` function defined in `TMVA::Experimental::SOFIE::PyTorch` will
337/// parse a trained PyTorch .pt model into a RModel Object. The parser uses
338/// internal functions of PyTorch to convert any PyTorch model into its
339/// equivalent ONNX Graph. For this conversion, dummy inputs are built which are
340/// passed through the model and the applied operators are recorded for populating
341/// the ONNX graph. The `Parse()` function requires the shapes and data-types of
342/// the input tensors which are used for building the dummy inputs.
343/// After the said conversion, the nodes of the ONNX graph are then traversed to
344/// extract properties like Node type, Attributes, input & output tensor names.
345/// Function `AddOperator()` is then called on the extracted nodes to add the
346/// operator into the RModel object. The nodes are also checked for adding any
347/// required routines for executing the generated Inference code.
348///
349/// The internal function used to convert the model to graph object returns a list
350/// which contains a Graph object and a dictionary of weights. This dictionary is
351/// used to extract the Initialized tensors for the model. The names and data-types
352/// of the Initialized tensors are extracted along with their values in NumPy array,
353/// and after appropriate type-conversions, they are added into the RModel object.
354///
355/// For adding the Input tensor infos, the names of the input tensors are extracted
356/// from the PyTorch ONNX graph object. The vector of shapes & data-types passed
357/// into the `Parse()` function are used to extract the data-type and the shape
358/// of the input tensors. Extracted input tensor infos are then added into the
359/// RModel object by calling the `AddInputTensorInfo()` function.
360///
361/// For the output tensor infos, names of the output tensors are also extracted
362/// from the Graph object and are then added into the RModel object by calling the
363/// AddOutputTensorNameList() function.
364///
365/// Example Usage:
366/// ~~~ {.cpp}
367/// using namespace TMVA::Experimental::SOFIE;
368/// //Building the vector of input tensor shapes
369/// std::vector<size_t> s1{120,1};
370/// std::vector<std::vector<size_t>> inputShape{s1};
371/// RModel model = PyTorch::Parse("trained_model_dense.pt",inputShape);
372/// ~~~
373RModel Parse(std::string filename, std::vector<std::vector<size_t>> inputShapes, std::vector<ETensorType> inputDTypes){
374
375 char sep = '/';
376 #ifdef _WIN32
377 sep = '\\';
378 #endif
379
380 size_t isep = filename.rfind(sep, filename.length());
381 std::string filename_nodir = filename;
382 if (isep != std::string::npos){
383 filename_nodir = (filename.substr(isep+1, filename.length() - isep));
384 }
385
386 //Check on whether the PyTorch .pt file exists
387 if(!std::ifstream(filename).good()){
388 throw std::runtime_error("Model file "+filename_nodir+" not found!");
389 }
390
391
392 std::time_t ttime = std::time(0);
393 std::tm* gmt_time = std::gmtime(&ttime);
394 std::string parsetime (std::asctime(gmt_time));
395
397
398 //Intializing Python Interpreter and scope dictionaries
400 PyObject* main = PyImport_AddModule("__main__");
401 PyObject* fGlobalNS = PyModule_GetDict(main);
402 PyObject* fLocalNS = PyDict_New();
403 if (!fGlobalNS) {
404 throw std::runtime_error("Can't init global namespace for Python");
405 }
406 if (!fLocalNS) {
407 throw std::runtime_error("Can't init local namespace for Python");
408 }
409
410
411 //Extracting model information
412 //Model is converted to ONNX graph format
413 //using PyTorch's internal function with the input shape provided
414 PyRunString("import torch",fGlobalNS,fLocalNS);
415 PyRunString("print('Torch Version: '+torch.__version__)",fGlobalNS,fLocalNS);
416 PyRunString("from torch.onnx.utils import _model_to_graph",fGlobalNS,fLocalNS);
417 //PyRunString("from torch.onnx.symbolic_helper import _set_onnx_shape_inference",fGlobalNS,fLocalNS);
418 PyRunString(TString::Format("model= torch.jit.load('%s')",filename.c_str()),fGlobalNS,fLocalNS);
419 PyRunString("globals().update(locals())",fGlobalNS,fLocalNS);
420 PyRunString("model.cpu()",fGlobalNS,fLocalNS);
421 PyRunString("model.eval()",fGlobalNS,fLocalNS);
422
423 //Building dummy inputs for the model
424 PyRunString("dummyInputs=[]",fGlobalNS,fLocalNS);
425 for(long unsigned int it=0;it<inputShapes.size();++it){
426 PyRunString("inputShape=[]",fGlobalNS,fLocalNS);
427 for(long unsigned int itr=0;itr<inputShapes[it].size();++itr){
428 PyRunString(TString::Format("inputShape.append(%d)",(int)inputShapes[it][itr]),fGlobalNS,fLocalNS);
429 }
430 PyRunString("dummyInputs.append(torch.rand(*inputShape))",fGlobalNS,fLocalNS);
431 }
432
433
434 //Getting the ONNX graph from model using the dummy inputs and example outputs
435 //PyRunString("_set_onnx_shape_inference(True)",fGlobalNS,fLocalNS);
436 PyRunString("graph=_model_to_graph(model,dummyInputs)",fGlobalNS,fLocalNS);
437
438
439 //Extracting the model information in list modelData
440 PyRunString("modelData=[]",fGlobalNS,fLocalNS);
441 // The '_node_get' helper function is used to avoid dependency on onnx submodule
442 // (for the subscript operator of torch._C.Node), as done in https://github.com/pytorch/pytorch/pull/82628
443 PyRunString("def _node_get(node, key):\n"
444 " sel = node.kindOf(key)\n"
445 " return getattr(node, sel)(key)\n",
446 fGlobalNS, fLocalNS);
447 PyRunString("for i in graph[0].nodes():\n"
448 " globals().update(locals())\n"
449 " nodeData={}\n"
450 " nodeData['nodeType']=i.kind()\n"
451 " nodeAttributeNames=[x for x in i.attributeNames()]\n"
452 " nodeAttributes={j: _node_get(i, j) for j in nodeAttributeNames}\n"
453 " nodeData['nodeAttributes']=nodeAttributes\n"
454 " nodeInputs=[x for x in i.inputs()]\n"
455 " nodeInputNames=[x.debugName() for x in nodeInputs]\n"
456 " nodeData['nodeInputs']=nodeInputNames\n"
457 " nodeOutputs=[x for x in i.outputs()]\n"
458 " nodeOutputNames=[x.debugName() for x in nodeOutputs]\n"
459 " nodeData['nodeOutputs']=nodeOutputNames\n"
460 " nodeDType=[x.type().scalarType() for x in nodeOutputs]\n"
461 " nodeData['nodeDType']=nodeDType\n"
462 " modelData.append(nodeData)",
463 fGlobalNS, fLocalNS);
464
465 PyObject* fPModel = PyDict_GetItemString(fLocalNS,"modelData");
467 PyObject *fNode;
468 std::string fNodeType;
469
470 //Adding operators into the RModel object
473 fNodeType = PyStringAsString(PyDict_GetItemString(fNode,"nodeType"));
474
475 // Adding required routines for inference code generation
476 if(fNodeType == "onnx::Gemm"){
477 rmodel.AddBlasRoutines({"Gemm", "Gemv"});
478 }
479 else if(fNodeType == "onnx::Selu" || fNodeType == "onnx::Sigmoid"){
480 rmodel.AddNeededStdLib("cmath");
481 }
482 else if (fNodeType == "onnx::Conv") {
483 rmodel.AddBlasRoutines({"Gemm", "Axpy"});
484 }
485 rmodel.AddOperator(INTERNAL::MakePyTorchNode(fNode));
486 }
487
488
489 //Extracting model weights to add the initialized tensors to the RModel
490 PyRunString("weightNames=[k for k in graph[1].keys()]",fGlobalNS,fLocalNS);
491 PyRunString("weights=[v.numpy() for v in graph[1].values()]",fGlobalNS,fLocalNS);
492 PyRunString("weightDTypes=[v.type()[6:-6] for v in graph[1].values()]",fGlobalNS,fLocalNS);
493 PyObject* fPWeightNames = PyDict_GetItemString(fLocalNS,"weightNames");
494 PyObject* fPWeightTensors = PyDict_GetItemString(fLocalNS,"weights");
495 PyObject* fPWeightDTypes = PyDict_GetItemString(fLocalNS,"weightDTypes");
497 std::string fWeightName;
499 std::vector<std::size_t> fWeightShape;
500 std::size_t fWeightSize;
501
506 fWeightSize = 1;
507 fWeightShape.clear();
508 for(int j=0; j<PyArray_NDIM(fWeightTensor); ++j){
509 fWeightShape.push_back((std::size_t)(PyArray_DIM(fWeightTensor,j)));
510 fWeightSize*=(std::size_t)(PyArray_DIM(fWeightTensor,j));
511 }
512 switch(fWeightDType){
513 case ETensorType::FLOAT:{
514 float* fWeightValue = (float*)PyArray_DATA(fWeightTensor);
515 std::shared_ptr<void> fData(malloc(fWeightSize * sizeof(float)), free);
516 std::memcpy(fData.get(),fWeightValue,fWeightSize * sizeof(float));
517 rmodel.AddInitializedTensor(fWeightName, ETensorType::FLOAT,fWeightShape,fData);
518 break;
519 }
520 default:
521 throw std::runtime_error("Type error: TMVA SOFIE does not yet supports weights of data type"+ConvertTypeToString(fWeightDType));
522 }
523 }
524
525
526 //Extracting Input tensor info
527 PyRunString("inputs=[x for x in model.graph.inputs()]",fGlobalNS,fLocalNS);
528 PyRunString("inputs=inputs[1:]",fGlobalNS,fLocalNS);
529 PyRunString("inputNames=[x.debugName() for x in inputs]",fGlobalNS,fLocalNS);
530 PyObject* fPInputs= PyDict_GetItemString(fLocalNS,"inputNames");
531 std::string fInputName;
532 std::vector<size_t>fInputShape;
535 fInputName = PyStringAsString(PyList_GetItem(fPInputs,inputIter));
536 fInputShape = inputShapes[inputIter];
538 switch(fInputDType){
539 case(ETensorType::FLOAT): {
540 rmodel.AddInputTensorInfo(fInputName, ETensorType::FLOAT, fInputShape);
541 rmodel.AddInputTensorName(fInputName);
542 break;
543 }
544 default:
545 throw std::runtime_error("Type Error: TMVA SOFIE does not yet support the input tensor data type"+ConvertTypeToString(fInputDType));
546 }
547 }
548
549
550 //Extracting output tensor names
551 PyRunString("outputs=[x for x in graph[0].outputs()]",fGlobalNS,fLocalNS);
552 PyRunString("outputNames=[x.debugName() for x in outputs]",fGlobalNS,fLocalNS);
553 PyObject* fPOutputs= PyDict_GetItemString(fLocalNS,"outputNames");
554 std::vector<std::string> fOutputNames;
556 fOutputNames.push_back(PyStringAsString(PyList_GetItem(fPOutputs,outputIter)));
557 }
558 rmodel.AddOutputTensorNameList(fOutputNames);
559
560 return rmodel;
561}
562
563//////////////////////////////////////////////////////////////////////////////////
564/// \param[in] filepath file location of PyTorch .pt model
565/// \param[in] inputShapes vector of input shape vectors
566/// \return Parsed RModel object
567///
568/// Overloaded Parser function for translating PyTorch .pt model to RModel object.
569/// Function only requires the inputShapes vector as a parameter. Function
570/// builds the vector of Data-types for the input tensors using Float as default,
571/// Function calls the `Parse()` function with the vector of data-types included,
572/// subsequently returning the parsed RModel object.
573RModel Parse(std::string filepath,std::vector<std::vector<size_t>> inputShapes){
574 std::vector<ETensorType> dtype(inputShapes.size(),ETensorType::FLOAT);
576}
577
578} // namespace TMVA::Experimental::SOFIE::PyTorch
int Py_ssize_t
Definition CPyCppyy.h:215
#define PyBytes_AsString
Definition CPyCppyy.h:64
int main()
Definition Prototype.cxx:12
_object PyObject
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
#define malloc
Definition civetweb.c:1575
const_iterator end() const
Basic string class.
Definition TString.h:138
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2384
std::unique_ptr< ROperator > MakePyTorchGemm(PyObject *fNode)
Prepares a ROperator_Gemm object.
std::unique_ptr< ROperator > MakePyTorchNode(PyObject *fNode)
Prepares equivalent ROperator with respect to PyTorch ONNX node.
std::unique_ptr< ROperator > MakePyTorchConv(PyObject *fNode)
Prepares a ROperator_Conv object.
std::unique_ptr< ROperator > MakePyTorchSigmoid(PyObject *fNode)
Prepares a ROperator_Sigmoid object.
std::unique_ptr< ROperator > MakePyTorchSelu(PyObject *fNode)
Prepares a ROperator_Selu object.
std::unique_ptr< ROperator > MakePyTorchRelu(PyObject *fNode)
Prepares a ROperator_Relu object.
std::unordered_map< std::string, std::unique_ptr< ROperator >(*)(PyObject *fNode)> PyTorchMethodMap
std::unique_ptr< ROperator > MakePyTorchTranspose(PyObject *fNode)
Prepares a ROperator_Transpose object.
RModel Parse(std::string filepath, std::vector< std::vector< size_t > > inputShapes, std::vector< ETensorType > dtype)
Parser function for translating PyTorch .pt model into a RModel object.
std::string ConvertTypeToString(ETensorType type)
ETensorType ConvertStringToType(std::string type)