Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_PyTorch.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Author: Sanjiban Sengupta 2021
3
4/**********************************************************************************
5 * Project : TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package : TMVA *
7 * Function: TMVA::Experimental::SOFIE::PyTorch::Parse *
8 * *
9 * Description: *
10 * Parser function for translating PyTorch .pt model to RModel object *
11 * *
12 * Example Usage: *
13 * ~~~ {.cpp} *
14 * using namespace TMVA::Experimental::SOFIE; *
15 * // Building the vector of input tensor shapes *
16 * std::vector<size_t> s1{120,1}; *
17 * std::vector<std::vector<size_t>> inputShape{s1}; *
18 * RModel model = PyTorch::Parse("trained_model_dense.pt",inputShape); *
19 * ~~~ *
20 * *
21 **********************************************************************************/
22
23
25
26#include <Python.h>
27
28#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
29#include <numpy/arrayobject.h>
30
32
33namespace {
34
35// Utility functions (taken from PyMethodBase in PyMVA)
36
37void PyRunString(TString code, PyObject *globalNS, PyObject *localNS)
38{
40 if (!fPyReturn) {
41 std::cout << "\nPython error message:\n";
43 throw std::runtime_error("\nFailed to run python code: " + code);
44 }
45}
46
47const char *PyStringAsString(PyObject *string)
48{
51 return cstring;
52}
53
54std::vector<size_t> GetDataFromList(PyObject *listObject)
55{
56 std::vector<size_t> listVec;
59 }
60 return listVec;
61}
62
63} // namespace
64
65
66namespace INTERNAL{
67
68// For searching and calling specific preparatory function for PyTorch ONNX Graph's node
69std::unique_ptr<ROperator> MakePyTorchNode(PyObject* fNode);
70
71std::unique_ptr<ROperator> MakePyTorchGemm(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Gemm operator
72std::unique_ptr<ROperator> MakePyTorchConv(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Conv operator
73std::unique_ptr<ROperator> MakePyTorchRelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Relu operator
74std::unique_ptr<ROperator> MakePyTorchSelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Selu operator
75std::unique_ptr<ROperator> MakePyTorchSigmoid(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Sigmoid operator
76std::unique_ptr<ROperator> MakePyTorchTranspose(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Transpose operator
77
78// For mapping PyTorch ONNX Graph's Node with the preparatory functions for ROperators
79using PyTorchMethodMap = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(PyObject* fNode)>;
80
82{
83 {"onnx::Gemm", &MakePyTorchGemm},
84 {"onnx::Conv", &MakePyTorchConv},
85 {"onnx::Relu", &MakePyTorchRelu},
86 {"onnx::Selu", &MakePyTorchSelu},
87 {"onnx::Sigmoid", &MakePyTorchSigmoid},
88 {"onnx::Transpose", &MakePyTorchTranspose}
89};
90
91
92//////////////////////////////////////////////////////////////////////////////////
93/// \brief Prepares equivalent ROperator with respect to PyTorch ONNX node.
94///
95/// \param[in] fNode Python PyTorch ONNX Graph node
96/// \return unique pointer to ROperator object
97///
98/// Function searches for the passed PyTorch ONNX Graph node in the map, and calls
99/// the specific preparatory function, subsequently returning the ROperator object.
100///
101/// To support a new PyTorch ONNX Graph node type in future, one only needs to write a
102/// preparatory function that extracts the required properties and attributes
103/// from the fNode dictionary (which contains all the information about any PyTorch ONNX
104/// Graph node), applies any required transformations, passes them on to instantiate
105/// the ROperator object, and registers the function in mapPyTorchNode.
106///
107/// The fNode dictionary, which holds all the information about a PyTorch ONNX Graph's node, has
108/// the following structure:
109///
110/// dict fNode { 'nodeType' : Type of node (operator)
111/// 'nodeAttributes' : Attributes of the node
112/// 'nodeInputs' : List of names of input tensors
113/// 'nodeOutputs' : List of names of output tensors
114/// 'nodeDType' : List of data-types of the node's output tensors
115/// }
116///
117std::unique_ptr<ROperator> MakePyTorchNode(PyObject* fNode){
118 std::string fNodeType = PyStringAsString(PyDict_GetItemString(fNode,"nodeType"));
119 auto findNode = mapPyTorchNode.find(fNodeType);
120 if(findNode == mapPyTorchNode.end()){
121 throw std::runtime_error("TMVA::SOFIE - Parsing PyTorch node " +fNodeType+" is not yet supported ");
122 }
123 return (findNode->second)(fNode);
124}
125
126
127//////////////////////////////////////////////////////////////////////////////////
128/// \brief Prepares a ROperator_Gemm object
129///
130/// \param[in] fNode Python PyTorch ONNX Graph node
131/// \return Unique pointer to ROperator object
132///
133/// For PyTorch's Linear layer having Gemm operation in its ONNX graph,
134/// the names of the input tensor, output tensor are extracted, and then
135/// are passed to instantiate a ROperator_Gemm object using the required attributes.
136/// fInputs is a list of tensor names, which includes the names of the input tensor
137/// and the weight tensors.
138std::unique_ptr<ROperator> MakePyTorchGemm(PyObject* fNode){
139 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
140 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
141 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
142 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
143
144 // Extracting the parameters for Gemm Operator
145 std::string fNameA = PyStringAsString(PyList_GetItem(fInputs,0));
146 std::string fNameB = PyStringAsString(PyList_GetItem(fInputs,1));
147 std::string fNameC = PyStringAsString(PyList_GetItem(fInputs,2));
148 std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
149 float fAttrAlpha = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"alpha")));
150 float fAttrBeta = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"beta")));
151 int_t fAttrTransA;
152 int_t fAttrTransB;
153
154 if(PyDict_Contains(fAttributes,PyUnicode_FromString("transB"))){
155 fAttrTransB = (int_t)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"transB")));
156 fAttrTransA = !fAttrTransB;
157 }
158 else{
159 fAttrTransA=(int_t)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"transA")));
160 fAttrTransB = !fAttrTransA;
161 }
162
163 std::unique_ptr<ROperator> op;
165 case ETensorType::FLOAT: {
166 op.reset(new ROperator_Gemm<float>(fAttrAlpha, fAttrBeta, fAttrTransA, fAttrTransB, fNameA, fNameB, fNameC, fNameY ));
167 break;
168 }
169 default:
170 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Gemm does not yet support input type " + fNodeDType);
171 }
172 return op;
173}
174
175//////////////////////////////////////////////////////////////////////////////////
176/// \brief Prepares a ROperator_Relu object
177///
178/// \param[in] fNode Python PyTorch ONNX Graph node
179/// \return Unique pointer to ROperator object
180///
181/// For instantiating a ROperator_Relu object, the names of
182/// input & output tensors and the data-type of the Graph node
183/// are extracted.
184std::unique_ptr<ROperator> MakePyTorchRelu(PyObject* fNode){
185 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
186 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
187 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
188 std::string fNameX = PyStringAsString(PyList_GetItem(fInputs,0));
189 std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
190 std::unique_ptr<ROperator> op;
192 case ETensorType::FLOAT: {
194 break;
195 }
196 default:
197 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Relu does not yet support input type " + fNodeDType);
198 }
199 return op;
200}
201
202//////////////////////////////////////////////////////////////////////////////////
203/// \brief Prepares a ROperator_Selu object
204///
205/// \param[in] fNode Python PyTorch ONNX Graph node
206/// \return Unique pointer to ROperator object
207///
208/// For instantiating a ROperator_Selu object, the names of
209/// input & output tensors and the data-type of the Graph node
210/// are extracted.
211std::unique_ptr<ROperator> MakePyTorchSelu(PyObject* fNode){
212 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
213 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
214 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
215
216 std::unique_ptr<ROperator> op;
218 case ETensorType::FLOAT: {
219 op.reset(new ROperator_Selu<float>(PyStringAsString(PyList_GetItem(fInputs,0)), PyStringAsString(PyList_GetItem(fOutputs,0))));
220 break;
221 }
222 default:
223 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Selu does not yet support input type " + fNodeDType);
224 }
225 return op;
226}
227
228//////////////////////////////////////////////////////////////////////////////////
229/// \brief Prepares a ROperator_Sigmoid object
230///
231/// \param[in] fNode Python PyTorch ONNX Graph node
232/// \return Unique pointer to ROperator object
233///
234/// For instantiating a ROperator_Sigmoid object, the names of
235/// input & output tensors and the data-type of the Graph node
236/// are extracted.
237std::unique_ptr<ROperator> MakePyTorchSigmoid(PyObject* fNode){
238 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
239 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
240 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
241
242 std::unique_ptr<ROperator> op;
244 case ETensorType::FLOAT: {
245 op.reset(new ROperator_Sigmoid<float>(PyStringAsString(PyList_GetItem(fInputs,0)), PyStringAsString(PyList_GetItem(fOutputs,0))));
246 break;
247 }
248 default:
249 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Sigmoid does not yet support input type " + fNodeDType);
250 }
251 return op;
252}
253
254
255//////////////////////////////////////////////////////////////////////////////////
256/// \brief Prepares a ROperator_Transpose object
257///
258/// \param[in] fNode Python PyTorch ONNX Graph node
259/// \return Unique pointer to ROperator object
260///
261/// For Transpose Operator of PyTorch's ONNX Graph, the permute dimensions are found,
262/// and are passed in instantiating the ROperator object.
263std::unique_ptr<ROperator> MakePyTorchTranspose(PyObject* fNode){
264 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
265 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
266 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
267 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
268
269 // Extracting the Permute dimensions for transpose
270 std::vector<int_t> fAttrPermute;
271 PyObject* fPermute=PyDict_GetItemString(fAttributes,"perm");
274 }
275 std::string fNameData = PyStringAsString(PyList_GetItem(fInputs,0));
276 std::string fNameOutput = PyStringAsString(PyList_GetItem(fOutputs,0));
277
278 std::unique_ptr<ROperator> op;
280 case ETensorType::FLOAT: {
282 break;
283 }
284 default:
285 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Transpose does not yet support input type " + fNodeDType);
286 }
287 return op;
288}
289
290
291//////////////////////////////////////////////////////////////////////////////////
292/// \brief Prepares a ROperator_Conv object
293///
294/// \param[in] fNode Python PyTorch ONNX Graph node
295/// \return Unique pointer to ROperator object
296///
297/// For Conv Operator of PyTorch's ONNX Graph, attributes like dilations, group,
298/// kernel shape, pads and strides are found, and are passed in instantiating the
299/// ROperator object with autopad default to `NOTSET`.
300std::unique_ptr<ROperator> MakePyTorchConv(PyObject* fNode){
301 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
302 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
303 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
304 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
305
306 // Extracting the Conv Node Attributes
307 PyObject* fDilations = PyDict_GetItemString(fAttributes,"dilations");
308 PyObject* fGroup = PyDict_GetItemString(fAttributes,"group");
309 PyObject* fKernelShape = PyDict_GetItemString(fAttributes,"kernel_shape");
310 PyObject* fPads = PyDict_GetItemString(fAttributes,"pads");
311 PyObject* fStrides = PyDict_GetItemString(fAttributes,"strides");
312
313 std::string fAttrAutopad = "NOTSET";
314 std::vector<size_t> fAttrDilations = GetDataFromList(fDilations);
315 size_t fAttrGroup = PyLong_AsLong(fGroup);
316 std::vector<size_t> fAttrKernelShape = GetDataFromList(fKernelShape);
317 std::vector<size_t> fAttrPads = GetDataFromList(fPads);
318 std::vector<size_t> fAttrStrides = GetDataFromList(fStrides);
319 std::string nameX = PyStringAsString(PyList_GetItem(fInputs,0));
320 std::string nameW = PyStringAsString(PyList_GetItem(fInputs,1));
321 std::string nameB = PyStringAsString(PyList_GetItem(fInputs,2));
322 std::string nameY = PyStringAsString(PyList_GetItem(fOutputs,0));
323
324 std::unique_ptr<ROperator> op;
326 case ETensorType::FLOAT: {
327 op.reset(new ROperator_Conv<float>(fAttrAutopad, fAttrDilations, fAttrGroup, fAttrKernelShape, fAttrPads, fAttrStrides, nameX, nameW, nameB, nameY));
328 break;
329 }
330 default:
331 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Conv does not yet support input type " + fNodeDType);
332 }
333 return op;
334}
335}//INTERNAL
336
337
338//////////////////////////////////////////////////////////////////////////////////
339/// \param[in] filename file location of PyTorch .pt model
340/// \param[in] inputShapes vector of input shape vectors
341/// \param[in] inputDTypes vector of ETensorType for data-types of Input tensors
342/// \return Parsed RModel object
343///
344/// The `Parse()` function defined in `TMVA::Experimental::SOFIE::PyTorch` will
345/// parse a trained PyTorch .pt model into a RModel Object. The parser uses
346/// internal functions of PyTorch to convert any PyTorch model into its
347/// equivalent ONNX Graph. For this conversion, dummy inputs are built which are
348/// passed through the model and the applied operators are recorded for populating
349/// the ONNX graph. The `Parse()` function requires the shapes and data-types of
350/// the input tensors which are used for building the dummy inputs.
351/// After the said conversion, the nodes of the ONNX graph are then traversed to
352/// extract properties like Node type, Attributes, input & output tensor names.
353/// Function `AddOperator()` is then called on the extracted nodes to add the
354/// operator into the RModel object. The nodes are also checked for adding any
355/// required routines for executing the generated Inference code.
356///
357/// The internal function used to convert the model to graph object returns a list
358/// which contains a Graph object and a dictionary of weights. This dictionary is
359/// used to extract the Initialized tensors for the model. The names and data-types
360/// of the Initialized tensors are extracted along with their values in NumPy array,
361/// and after appropriate type-conversions, they are added into the RModel object.
362///
363/// For adding the Input tensor infos, the names of the input tensors are extracted
364/// from the PyTorch ONNX graph object. The vector of shapes & data-types passed
365/// into the `Parse()` function are used to extract the data-type and the shape
366/// of the input tensors. Extracted input tensor infos are then added into the
367/// RModel object by calling the `AddInputTensorInfo()` function.
368///
369/// For the output tensor infos, names of the output tensors are also extracted
370/// from the Graph object and are then added into the RModel object by calling the
371/// AddOutputTensorNameList() function.
372///
373/// Example Usage:
374/// ~~~ {.cpp}
375/// using namespace TMVA::Experimental::SOFIE;
376/// //Building the vector of input tensor shapes
377/// std::vector<size_t> s1{120,1};
378/// std::vector<std::vector<size_t>> inputShape{s1};
379/// RModel model = PyTorch::Parse("trained_model_dense.pt",inputShape);
380/// ~~~
381RModel Parse(std::string filename, std::vector<std::vector<size_t>> inputShapes, std::vector<ETensorType> inputDTypes){
382
383 char sep = '/';
384 #ifdef _WIN32
385 sep = '\\';
386 #endif
387
388 size_t isep = filename.rfind(sep, filename.length());
389 std::string filename_nodir = filename;
390 if (isep != std::string::npos){
391 filename_nodir = (filename.substr(isep+1, filename.length() - isep));
392 }
393
394 //Check on whether the PyTorch .pt file exists
395 if(!std::ifstream(filename).good()){
396 throw std::runtime_error("Model file "+filename_nodir+" not found!");
397 }
398
399
400 std::time_t ttime = std::time(0);
401 std::tm* gmt_time = std::gmtime(&ttime);
402 std::string parsetime (std::asctime(gmt_time));
403
405
406 //Intializing Python Interpreter and scope dictionaries
408 PyObject* main = PyImport_AddModule("__main__");
409 PyObject* fGlobalNS = PyModule_GetDict(main);
410 PyObject* fLocalNS = PyDict_New();
411 if (!fGlobalNS) {
412 throw std::runtime_error("Can't init global namespace for Python");
413 }
414 if (!fLocalNS) {
415 throw std::runtime_error("Can't init local namespace for Python");
416 }
417
418
419 //Extracting model information
420 //Model is converted to ONNX graph format
421 //using PyTorch's internal function with the input shape provided
422 PyRunString("import torch",fGlobalNS,fLocalNS);
423 PyRunString("print('Torch Version: '+torch.__version__)",fGlobalNS,fLocalNS);
424 PyRunString("from torch.onnx.utils import _model_to_graph",fGlobalNS,fLocalNS);
425 //PyRunString("from torch.onnx.symbolic_helper import _set_onnx_shape_inference",fGlobalNS,fLocalNS);
426 PyRunString(TString::Format("model= torch.jit.load('%s')",filename.c_str()),fGlobalNS,fLocalNS);
427 PyRunString("globals().update(locals())",fGlobalNS,fLocalNS);
428 PyRunString("model.cpu()",fGlobalNS,fLocalNS);
429 PyRunString("model.eval()",fGlobalNS,fLocalNS);
430
431 //Building dummy inputs for the model
432 PyRunString("dummyInputs=[]",fGlobalNS,fLocalNS);
433 for(long unsigned int it=0;it<inputShapes.size();++it){
434 PyRunString("inputShape=[]",fGlobalNS,fLocalNS);
435 for(long unsigned int itr=0;itr<inputShapes[it].size();++itr){
436 PyRunString(TString::Format("inputShape.append(%d)",(int)inputShapes[it][itr]),fGlobalNS,fLocalNS);
437 }
438 PyRunString("dummyInputs.append(torch.rand(*inputShape))",fGlobalNS,fLocalNS);
439 }
440
441
442 //Getting the ONNX graph from model using the dummy inputs and example outputs
443 //PyRunString("_set_onnx_shape_inference(True)",fGlobalNS,fLocalNS);
444 PyRunString("graph=_model_to_graph(model,dummyInputs)",fGlobalNS,fLocalNS);
445
446
447 //Extracting the model information in list modelData
448 PyRunString("modelData=[]",fGlobalNS,fLocalNS);
449 // The '_node_get' helper function is used to avoid dependency on onnx submodule
450 // (for the subscript operator of torch._C.Node), as done in https://github.com/pytorch/pytorch/pull/82628
451 PyRunString("def _node_get(node, key):\n"
452 " sel = node.kindOf(key)\n"
453 " return getattr(node, sel)(key)\n",
454 fGlobalNS, fLocalNS);
455 PyRunString("for i in graph[0].nodes():\n"
456 " globals().update(locals())\n"
457 " nodeData={}\n"
458 " nodeData['nodeType']=i.kind()\n"
459 " nodeAttributeNames=[x for x in i.attributeNames()]\n"
460 " nodeAttributes={j: _node_get(i, j) for j in nodeAttributeNames}\n"
461 " nodeData['nodeAttributes']=nodeAttributes\n"
462 " nodeInputs=[x for x in i.inputs()]\n"
463 " nodeInputNames=[x.debugName() for x in nodeInputs]\n"
464 " nodeData['nodeInputs']=nodeInputNames\n"
465 " nodeOutputs=[x for x in i.outputs()]\n"
466 " nodeOutputNames=[x.debugName() for x in nodeOutputs]\n"
467 " nodeData['nodeOutputs']=nodeOutputNames\n"
468 " nodeDType=[x.type().scalarType() for x in nodeOutputs]\n"
469 " nodeData['nodeDType']=nodeDType\n"
470 " modelData.append(nodeData)",
471 fGlobalNS, fLocalNS);
472
473 PyObject* fPModel = PyDict_GetItemString(fLocalNS,"modelData");
475 PyObject *fNode;
476 std::string fNodeType;
477
478 //Adding operators into the RModel object
481 fNodeType = PyStringAsString(PyDict_GetItemString(fNode,"nodeType"));
482
483 // Adding required routines for inference code generation
484 if(fNodeType == "onnx::Gemm"){
485 rmodel.AddBlasRoutines({"Gemm", "Gemv"});
486 }
487 else if(fNodeType == "onnx::Selu" || fNodeType == "onnx::Sigmoid"){
488 rmodel.AddNeededStdLib("cmath");
489 }
490 else if (fNodeType == "onnx::Conv") {
491 rmodel.AddBlasRoutines({"Gemm", "Axpy"});
492 }
493 rmodel.AddOperator(INTERNAL::MakePyTorchNode(fNode));
494 }
495
496
497 //Extracting model weights to add the initialized tensors to the RModel
498 PyRunString("weightNames=[k for k in graph[1].keys()]",fGlobalNS,fLocalNS);
499 PyRunString("weights=[v.numpy() for v in graph[1].values()]",fGlobalNS,fLocalNS);
500 PyRunString("weightDTypes=[v.type()[6:-6] for v in graph[1].values()]",fGlobalNS,fLocalNS);
501 PyObject* fPWeightNames = PyDict_GetItemString(fLocalNS,"weightNames");
502 PyObject* fPWeightTensors = PyDict_GetItemString(fLocalNS,"weights");
503 PyObject* fPWeightDTypes = PyDict_GetItemString(fLocalNS,"weightDTypes");
505 std::string fWeightName;
507 std::vector<std::size_t> fWeightShape;
508 std::size_t fWeightSize;
509
514 fWeightSize = 1;
515 fWeightShape.clear();
516 for(int j=0; j<PyArray_NDIM(fWeightTensor); ++j){
517 fWeightShape.push_back((std::size_t)(PyArray_DIM(fWeightTensor,j)));
518 fWeightSize*=(std::size_t)(PyArray_DIM(fWeightTensor,j));
519 }
520 switch(fWeightDType){
521 case ETensorType::FLOAT:{
522 float* fWeightValue = (float*)PyArray_DATA(fWeightTensor);
523 std::shared_ptr<void> fData(malloc(fWeightSize * sizeof(float)), free);
524 std::memcpy(fData.get(),fWeightValue,fWeightSize * sizeof(float));
525 rmodel.AddInitializedTensor(fWeightName, ETensorType::FLOAT,fWeightShape,fData);
526 break;
527 }
528 default:
529 throw std::runtime_error("Type error: TMVA SOFIE does not yet supports weights of data type"+ConvertTypeToString(fWeightDType));
530 }
531 }
532
533
534 //Extracting Input tensor info
535 PyRunString("inputs=[x for x in model.graph.inputs()]",fGlobalNS,fLocalNS);
536 PyRunString("inputs=inputs[1:]",fGlobalNS,fLocalNS);
537 PyRunString("inputNames=[x.debugName() for x in inputs]",fGlobalNS,fLocalNS);
538 PyObject* fPInputs= PyDict_GetItemString(fLocalNS,"inputNames");
539 std::string fInputName;
540 std::vector<size_t>fInputShape;
543 fInputName = PyStringAsString(PyList_GetItem(fPInputs,inputIter));
544 fInputShape = inputShapes[inputIter];
546 switch(fInputDType){
547 case(ETensorType::FLOAT): {
548 rmodel.AddInputTensorInfo(fInputName, ETensorType::FLOAT, fInputShape);
549 rmodel.AddInputTensorName(fInputName);
550 break;
551 }
552 default:
553 throw std::runtime_error("Type Error: TMVA SOFIE does not yet support the input tensor data type"+ConvertTypeToString(fInputDType));
554 }
555 }
556
557
558 //Extracting output tensor names
559 PyRunString("outputs=[x for x in graph[0].outputs()]",fGlobalNS,fLocalNS);
560 PyRunString("outputNames=[x.debugName() for x in outputs]",fGlobalNS,fLocalNS);
561 PyObject* fPOutputs= PyDict_GetItemString(fLocalNS,"outputNames");
562 std::vector<std::string> fOutputNames;
564 fOutputNames.push_back(PyStringAsString(PyList_GetItem(fPOutputs,outputIter)));
565 }
566 rmodel.AddOutputTensorNameList(fOutputNames);
567
568 return rmodel;
569}
570
571//////////////////////////////////////////////////////////////////////////////////
572/// \param[in] filepath file location of PyTorch .pt model
573/// \param[in] inputShapes vector of input shape vectors
574/// \return Parsed RModel object
575///
576/// Overloaded Parser function for translating PyTorch .pt model to RModel object.
577/// Function only requires the inputShapes vector as a parameter. Function
578/// builds the vector of Data-types for the input tensors using Float as default,
579/// and calls the `Parse()` function with the vector of data-types included,
580/// subsequently returning the parsed RModel object.
581RModel Parse(std::string filepath,std::vector<std::vector<size_t>> inputShapes){
582 std::vector<ETensorType> dtype(inputShapes.size(),ETensorType::FLOAT);
584}
585
586} // namespace TMVA::Experimental::SOFIE::PyTorch
int Py_ssize_t
Definition CPyCppyy.h:215
#define PyBytes_AsString
Definition CPyCppyy.h:64
int main()
Definition Prototype.cxx:12
_object PyObject
#define Py_single_input
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
#define malloc
Definition civetweb.c:1575
const_iterator end() const
Basic string class.
Definition TString.h:138
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2384
std::unique_ptr< ROperator > MakePyTorchGemm(PyObject *fNode)
Prepares a ROperator_Gemm object.
std::unique_ptr< ROperator > MakePyTorchNode(PyObject *fNode)
Prepares equivalent ROperator with respect to PyTorch ONNX node.
std::unique_ptr< ROperator > MakePyTorchConv(PyObject *fNode)
Prepares a ROperator_Conv object.
std::unique_ptr< ROperator > MakePyTorchSigmoid(PyObject *fNode)
Prepares a ROperator_Sigmoid object.
std::unique_ptr< ROperator > MakePyTorchSelu(PyObject *fNode)
Prepares a ROperator_Selu object.
std::unique_ptr< ROperator > MakePyTorchRelu(PyObject *fNode)
Prepares a ROperator_Relu object.
std::unordered_map< std::string, std::unique_ptr< ROperator >(*)(PyObject *fNode)> PyTorchMethodMap
std::unique_ptr< ROperator > MakePyTorchTranspose(PyObject *fNode)
Prepares a ROperator_Transpose object.
RModel Parse(std::string filepath, std::vector< std::vector< size_t > > inputShapes, std::vector< ETensorType > dtype)
Parser function for translating PyTorch .pt model into a RModel object.
std::string ConvertTypeToString(ETensorType type)
ETensorType ConvertStringToType(std::string type)