Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_PyTorch.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Author: Sanjiban Sengupta 2021
3
4/**********************************************************************************
5 * Project : TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package : TMVA *
7 * Function: TMVA::Experimental::SOFIE::PyTorch::Parse *
8 * *
9 * Description: *
10 * Parser function for translating PyTorch .pt model to RModel object *
11 * *
12 * Example Usage: *
13 * ~~~ {.cpp} *
 14 * using namespace TMVA::Experimental::SOFIE; *
15 * // Building the vector of input tensor shapes *
16 * std::vector<size_t> s1{120,1}; *
17 * std::vector<std::vector<size_t>> inputShape{s1}; *
18 * RModel model = PyTorch::Parse("trained_model_dense.pt",inputShape); *
19 * ~~~ *
20 * *
21 **********************************************************************************/
22
23
25
26#include <Python.h>
27
29
30namespace {
31
32// Utility functions (taken from PyMethodBase in PyMVA)
33
/// Executes a snippet of Python code with the given global and local namespaces.
/// On failure (fPyReturn is null) a "Python error message" banner is printed to stdout
/// and std::runtime_error is thrown with the failing code appended to its message.
/// NOTE(review): listing lines 36 (which defines fPyReturn) and 39 (between the banner
/// and the throw) are not visible in this listing; presumably a PyRun_String call and a
/// Python error printout — confirm. No release of fPyReturn is visible, so a successful
/// result object may be leaked on every call.
34void PyRunString(TString code, PyObject *globalNS, PyObject *localNS)
35{
37 if (!fPyReturn) {
38 std::cout << "\nPython error message:\n";
40 throw std::runtime_error("\nFailed to run python code: " + code);
41 }
42}
43
/// Returns a C-string view of a Python string object.
/// NOTE(review): listing lines 46-47, which compute cstring, are not visible in this
/// listing; the Doxygen index for this file references PyBytes_AsString, so presumably
/// the string is encoded to a bytes object whose buffer is returned — confirm, and check
/// who owns (and releases) that object, since the pointer is only valid while it lives.
/// Most callers in this file copy the result into a std::string immediately.
44const char *PyStringAsString(PyObject *string)
45{
48 return cstring;
49}
50
/// Collects the items of a Python list into a std::vector<size_t>.
/// Used for list-valued node attributes such as Conv's dilations, kernel_shape,
/// pads and strides.
/// NOTE(review): listing lines 54-55 (the loop that fills listVec) are not visible in
/// this listing; presumably each item is converted with PyLong_AsLong — confirm. No guard
/// against a NULL listObject (e.g. a missing attribute) is visible.
51std::vector<size_t> GetDataFromList(PyObject *listObject)
52{
53 std::vector<size_t> listVec;
56 }
57 return listVec;
58}
59
60} // namespace
61
62
63namespace INTERNAL{
64
65// For searching and calling specific preparatory function for PyTorch ONNX Graph's node
66std::unique_ptr<ROperator> MakePyTorchNode(PyObject* fNode);
67
68std::unique_ptr<ROperator> MakePyTorchGemm(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Gemm operator
69std::unique_ptr<ROperator> MakePyTorchConv(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Conv operator
70std::unique_ptr<ROperator> MakePyTorchRelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Relu operator
71std::unique_ptr<ROperator> MakePyTorchSelu(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Selu operator
72std::unique_ptr<ROperator> MakePyTorchSigmoid(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Sigmoid operator
73std::unique_ptr<ROperator> MakePyTorchTranspose(PyObject* fNode); // For instantiating ROperator for PyTorch ONNX's Transpose operator
74
75// For mapping PyTorch ONNX Graph's Node with the preparatory functions for ROperators
76using PyTorchMethodMap = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(PyObject* fNode)>;
77
// Dispatch table consulted by MakePyTorchNode: each key is a PyTorch ONNX node kind
// (the i.kind() string that Parse stores under 'nodeType') and each value is the
// factory that builds the matching ROperator. To support a new operator, declare its
// factory above, add an entry here, and check whether Parse must register extra BLAS
// routines or standard headers for it.
// NOTE(review): the declaration line (listing line 78) is not visible in this listing;
// MakePyTorchNode refers to this table as mapPyTorchNode.
79{
80 {"onnx::Gemm", &MakePyTorchGemm},
81 {"onnx::Conv", &MakePyTorchConv},
82 {"onnx::Relu", &MakePyTorchRelu},
83 {"onnx::Selu", &MakePyTorchSelu},
84 {"onnx::Sigmoid", &MakePyTorchSigmoid},
85 {"onnx::Transpose", &MakePyTorchTranspose}
86};
87
88
89//////////////////////////////////////////////////////////////////////////////////
90/// \brief Prepares equivalent ROperator with respect to PyTorch ONNX node.
91///
92/// \param[in] fNode Python PyTorch ONNX Graph node
93/// \return unique pointer to ROperator object
94///
95/// Function searches for the passed PyTorch ONNX Graph node in the map, and calls
96/// the specific preparatory function, subsequently returning the ROperator object.
97///
98/// For developing new preparatory functions for supporting PyTorch ONNX Graph nodes
99/// in future, all one needs is to extract the required properties and attributes
100/// from the fNode dictionary which contains all the information about any PyTorch ONNX
101// Graph node and after any required transformations, these are passed for instantiating
102/// the ROperator object.
103///
104/// The fNode dictionary which holds all the information about a PyTorch ONNX Graph's node has
105/// following structure:-
106///
107/// dict fNode { 'nodeType' : Type of node (operator)
108/// 'nodeAttributes' : Attributes of the node
109/// 'nodeInputs' : List of names of input tensors
110/// 'nodeOutputs' : List of names of output tensors
111/// 'nodeDType' : Data-type of the operator node
112/// }
113///
114std::unique_ptr<ROperator> MakePyTorchNode(PyObject* fNode){
115 std::string fNodeType = PyStringAsString(PyDict_GetItemString(fNode,"nodeType"));
116 auto findNode = mapPyTorchNode.find(fNodeType);
117 if(findNode == mapPyTorchNode.end()){
118 throw std::runtime_error("TMVA::SOFIE - Parsing PyTorch node " +fNodeType+" is not yet supported ");
119 }
120 return (findNode->second)(fNode);
121}
122
123
124//////////////////////////////////////////////////////////////////////////////////
125/// \brief Prepares a ROperator_Gemm object
126///
127/// \param[in] fNode Python PyTorch ONNX Graph node
128/// \return Unique pointer to ROperator object
129///
130/// For PyTorch's Linear layer having Gemm operation in its ONNX graph,
131/// the names of the input tensor, output tensor are extracted, and then
132/// are passed to instantiate a ROperator_Gemm object using the required attributes.
133/// fInputs is a list of tensor names, which includes the names of the input tensor
134/// and the weight tensors.
135std::unique_ptr<ROperator> MakePyTorchGemm(PyObject* fNode){
136 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
137 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
138 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
139 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
140
141 // Extracting the parameters for Gemm Operator
142 std::string fNameA = PyStringAsString(PyList_GetItem(fInputs,0));
143 std::string fNameB = PyStringAsString(PyList_GetItem(fInputs,1));
144 std::string fNameC = PyStringAsString(PyList_GetItem(fInputs,2));
145 std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
146 float fAttrAlpha = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"alpha")));
147 float fAttrBeta = (float)(PyFloat_AsDouble(PyDict_GetItemString(fAttributes,"beta")));
148 int_t fAttrTransA;
149 int_t fAttrTransB;
150
151 if(PyDict_Contains(fAttributes,PyUnicode_FromString("transB"))){
152 fAttrTransB = (int_t)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"transB")));
153 fAttrTransA = !fAttrTransB;
154 }
155 else{
156 fAttrTransA=(int_t)(PyLong_AsLong(PyDict_GetItemString(fAttributes,"transA")));
157 fAttrTransB = !fAttrTransA;
158 }
159
160 std::unique_ptr<ROperator> op;
162 case ETensorType::FLOAT: {
163 op.reset(new ROperator_Gemm<float>(fAttrAlpha, fAttrBeta, fAttrTransA, fAttrTransB, fNameA, fNameB, fNameC, fNameY ));
164 break;
165 }
166 default:
167 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Gemm does not yet support input type " + fNodeDType);
168 }
169 return op;
170}
171
172//////////////////////////////////////////////////////////////////////////////////
173/// \brief Prepares a ROperator_Relu object
174///
175/// \param[in] fNode Python PyTorch ONNX Graph node
176/// \return Unique pointer to ROperator object
/// \throws std::runtime_error (message naming fNodeDType) when the dispatched type is
///         not ETensorType::FLOAT
177///
178/// For instantiating a ROperator_Relu object, the names of
179/// input & output tensors and the data-type of the Graph node
180/// are extracted. Only the first input, the first output and the first
/// nodeDType entry are used.
/// NOTE(review): listing lines 188 (the switch header) and 190 (the statement that
/// creates the operator) are not visible in this listing; presumably a switch on the
/// converted fNodeDType and op.reset(new ROperator_Relu<float>(fNameX, fNameY)) — confirm.
181std::unique_ptr<ROperator> MakePyTorchRelu(PyObject* fNode){
182 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
183 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
184 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
185 std::string fNameX = PyStringAsString(PyList_GetItem(fInputs,0));
186 std::string fNameY = PyStringAsString(PyList_GetItem(fOutputs,0));
187 std::unique_ptr<ROperator> op;
189 case ETensorType::FLOAT: {
191 break;
192 }
193 default:
194 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Relu does not yet support input type " + fNodeDType);
195 }
196 return op;
197}
198
199//////////////////////////////////////////////////////////////////////////////////
200/// \brief Prepares a ROperator_Selu object
201///
202/// \param[in] fNode Python PyTorch ONNX Graph node
203/// \return Unique pointer to ROperator object
204///
205/// For instantiating a ROperator_Selu object, the names of
206/// input & output tensors and the data-type of the Graph node
207/// are extracted.
208std::unique_ptr<ROperator> MakePyTorchSelu(PyObject* fNode){
209 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
210 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
211 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
212
213 std::unique_ptr<ROperator> op;
215 case ETensorType::FLOAT: {
216 op.reset(new ROperator_Selu<float>(PyStringAsString(PyList_GetItem(fInputs,0)), PyStringAsString(PyList_GetItem(fOutputs,0))));
217 break;
218 }
219 default:
220 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Selu does not yet support input type " + fNodeDType);
221 }
222 return op;
223}
224
225//////////////////////////////////////////////////////////////////////////////////
226/// \brief Prepares a ROperator_Sigmoid object
227///
228/// \param[in] fNode Python PyTorch ONNX Graph node
229/// \return Unique pointer to ROperator object
230///
231/// For instantiating a ROperator_Sigmoid object, the names of
232/// input & output tensors and the data-type of the Graph node
233/// are extracted.
234std::unique_ptr<ROperator> MakePyTorchSigmoid(PyObject* fNode){
235 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
236 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
237 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
238
239 std::unique_ptr<ROperator> op;
241 case ETensorType::FLOAT: {
242 op.reset(new ROperator_Sigmoid<float>(PyStringAsString(PyList_GetItem(fInputs,0)), PyStringAsString(PyList_GetItem(fOutputs,0))));
243 break;
244 }
245 default:
246 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Sigmoid does not yet support input type " + fNodeDType);
247 }
248 return op;
249}
250
251
252//////////////////////////////////////////////////////////////////////////////////
253/// \brief Prepares a ROperator_Transpose object
254///
255/// \param[in] fNode Python PyTorch ONNX Graph node
256/// \return Unique pointer to ROperator object
257///
258/// For Transpose Operator of PyTorch's ONNX Graph, the permute dimensions are found,
259/// and are passed in instantiating the ROperator object.
/// The permutation comes from the node's "perm" attribute. The first input and the
/// first output name are used as the data and output tensors.
/// NOTE(review): listing lines 269-270 (which fill fAttrPermute from fPermute) are not
/// visible in this listing. The visible lines do not guard against "perm" being absent
/// (fPermute NULL). fNodeDType is computed but not referenced in any visible line.
260std::unique_ptr<ROperator> MakePyTorchTranspose(PyObject* fNode){
261 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
262 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
263 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
264 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
265
266 // Extracting the Permute dimensions for transpose
267 std::vector<int_t> fAttrPermute;
268 PyObject* fPermute=PyDict_GetItemString(fAttributes,"perm");
271 }
272 std::string fNameData = PyStringAsString(PyList_GetItem(fInputs,0));
273 std::string fNameOutput = PyStringAsString(PyList_GetItem(fOutputs,0));
274
275 std::unique_ptr<ROperator> op = std::make_unique<ROperator_Transpose>(fAttrPermute, fNameData, fNameOutput);
276 return op;
277}
278
279
280//////////////////////////////////////////////////////////////////////////////////
281/// \brief Prepares a ROperator_Conv object
282///
283/// \param[in] fNode Python PyTorch ONNX Graph node
284/// \return Unique pointer to ROperator object
285///
286/// For Conv Operator of PyTorch's ONNX Graph, attributes like dilations, group,
287/// kernel shape, pads and strides are found, and are passed in instantiating the
288/// ROperator object with autopad default to `NOTSET`.
289std::unique_ptr<ROperator> MakePyTorchConv(PyObject* fNode){
290 PyObject* fAttributes = PyDict_GetItemString(fNode,"nodeAttributes");
291 PyObject* fInputs = PyDict_GetItemString(fNode,"nodeInputs");
292 PyObject* fOutputs = PyDict_GetItemString(fNode,"nodeOutputs");
293 std::string fNodeDType = PyStringAsString(PyList_GetItem(PyDict_GetItemString(fNode,"nodeDType"),0));
294
295 // Extracting the Conv Node Attributes
296 PyObject* fDilations = PyDict_GetItemString(fAttributes,"dilations");
297 PyObject* fGroup = PyDict_GetItemString(fAttributes,"group");
298 PyObject* fKernelShape = PyDict_GetItemString(fAttributes,"kernel_shape");
299 PyObject* fPads = PyDict_GetItemString(fAttributes,"pads");
300 PyObject* fStrides = PyDict_GetItemString(fAttributes,"strides");
301
302 std::string fAttrAutopad = "NOTSET";
303 std::vector<size_t> fAttrDilations = GetDataFromList(fDilations);
304 size_t fAttrGroup = PyLong_AsLong(fGroup);
305 std::vector<size_t> fAttrKernelShape = GetDataFromList(fKernelShape);
306 std::vector<size_t> fAttrPads = GetDataFromList(fPads);
307 std::vector<size_t> fAttrStrides = GetDataFromList(fStrides);
308 std::string nameX = PyStringAsString(PyList_GetItem(fInputs,0));
309 std::string nameW = PyStringAsString(PyList_GetItem(fInputs,1));
310 std::string nameB = PyStringAsString(PyList_GetItem(fInputs,2));
311 std::string nameY = PyStringAsString(PyList_GetItem(fOutputs,0));
312
313 std::unique_ptr<ROperator> op;
315 case ETensorType::FLOAT: {
316 op.reset(new ROperator_Conv<float>(fAttrAutopad, fAttrDilations, fAttrGroup, fAttrKernelShape, fAttrPads, fAttrStrides, nameX, nameW, nameB, nameY));
317 break;
318 }
319 default:
320 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Conv does not yet support input type " + fNodeDType);
321 }
322 return op;
323}
324}//INTERNAL
325
326
327//////////////////////////////////////////////////////////////////////////////////
328/// \param[in] filename file location of PyTorch .pt model
329/// \param[in] inputShapes vector of input shape vectors
330/// \param[in] inputDTypes vector of ETensorType for data-types of Input tensors
331/// \return Parsed RModel object
332///
333/// The `Parse()` function defined in `TMVA::Experimental::SOFIE::PyTorch` will
334/// parse a trained PyTorch .pt model into a RModel Object. The parser uses
335/// internal functions of PyTorch to convert any PyTorch model into its
336/// equivalent ONNX Graph. For this conversion, dummy inputs are built which are
337/// passed through the model and the applied operators are recorded for populating
338/// the ONNX graph. The `Parse()` function requires the shapes and data-types of
339/// the input tensors which are used for building the dummy inputs.
340/// After the said conversion, the nodes of the ONNX graph are then traversed to
341/// extract properties like Node type, Attributes, input & output tensor names.
342/// Function `AddOperator()` is then called on the extracted nodes to add the
343/// operator into the RModel object. The nodes are also checked for adding any
344/// required routines for executing the generated Inference code.
345///
346/// The internal function used to convert the model to graph object returns a list
347/// which contains a Graph object and a dictionary of weights. This dictionary is
348/// used to extract the Initialized tensors for the model. The names and data-types
349/// of the Initialized tensors are extracted along with their values in NumPy array,
350/// and after appropriate type-conversions, they are added into the RModel object.
351///
352/// For adding the Input tensor infos, the names of the input tensors are extracted
353/// from the PyTorch ONNX graph object. The vector of shapes & data-types passed
354/// into the `Parse()` function are used to extract the data-type and the shape
355/// of the input tensors. Extracted input tensor infos are then added into the
356/// RModel object by calling the `AddInputTensorInfo()` function.
357///
358/// For the output tensor infos, names of the output tensors are also extracted
359/// from the Graph object and are then added into the RModel object by calling the
360/// AddOutputTensorNameList() function.
361///
362/// Example Usage:
363/// ~~~ {.cpp}
364/// using namespace TMVA::Experimental::SOFIE;
365/// //Building the vector of input tensor shapes
366/// std::vector<size_t> s1{120,1};
367/// std::vector<std::vector<size_t>> inputShape{s1};
368/// RModel model = PyTorch::Parse("trained_model_dense.pt",inputShape);
369/// ~~~
// Converts a TorchScript .pt model into an RModel by passing random dummy inputs of the
// given shapes through PyTorch's _model_to_graph (ONNX graph export).
// NOTE(review): several listing lines of this function are not visible in this listing
// (e.g. 393, 396, 463, 468-469, 501, 505-507, 510, 518, 537-538, 541, 559); the comments
// below describe only the visible lines.
370RModel Parse(std::string filename, std::vector<std::vector<size_t>> inputShapes, std::vector<ETensorType> inputDTypes){
371
372 char sep = '/';
373 #ifdef _WIN32
374 sep = '\\';
375 #endif
376
 // Strip the directory part; the bare file name is used in messages
377 size_t isep = filename.rfind(sep, filename.length());
378 std::string filename_nodir = filename;
379 if (isep != std::string::npos){
380 filename_nodir = (filename.substr(isep+1, filename.length() - isep));
381 }
382
383 //Check on whether the PyTorch .pt file exists
384 if(!std::ifstream(filename).good()){
385 throw std::runtime_error("Model file "+filename_nodir+" not found!");
386 }
387
388
 // NOTE(review): std::gmtime/std::asctime return pointers to shared static storage
 // (not thread-safe), and the asctime text ends with a newline that parsetime keeps.
389 std::time_t ttime = std::time(0);
390 std::tm* gmt_time = std::gmtime(&ttime);
391 std::string parsetime (std::asctime(gmt_time));
392
394
395 //Initializing Python Interpreter and scope dictionaries
397 PyObject* main = PyImport_AddModule("__main__");
398 PyObject* fGlobalNS = PyModule_GetDict(main);
 // NOTE(review): fLocalNS is a new dictionary; no Py_DECREF of it is visible in this
 // function, so presumably it (and everything stored in it) leaks on every call.
399 PyObject* fLocalNS = PyDict_New();
400 if (!fGlobalNS) {
401 throw std::runtime_error("Can't init global namespace for Python");
402 }
403 if (!fLocalNS) {
404 throw std::runtime_error("Can't init local namespace for Python");
405 }
406
407
408 //Extracting model information
409 //Model is converted to ONNX graph format
410 //using PyTorch's internal function with the input shape provided
411 PyRunString("import torch",fGlobalNS,fLocalNS);
412 PyRunString("print('Torch Version: '+torch.__version__)",fGlobalNS,fLocalNS);
413 PyRunString("from torch.onnx.utils import _model_to_graph",fGlobalNS,fLocalNS);
414 //PyRunString("from torch.onnx.symbolic_helper import _set_onnx_shape_inference",fGlobalNS,fLocalNS);
 // NOTE(review): the path is pasted verbatim between single quotes into Python source.
 // A path containing ' or backslashes (e.g. Windows separators, where "\t" turns into a
 // tab) breaks or silently alters the generated code; consider passing it via repr().
415 PyRunString(TString::Format("model= torch.jit.load('%s')",filename.c_str()),fGlobalNS,fLocalNS);
 // Copy the names defined so far into the global namespace, presumably so that nested
 // scopes (comprehensions, functions) in later snippets can resolve them.
416 PyRunString("globals().update(locals())",fGlobalNS,fLocalNS);
417 PyRunString("model.cpu()",fGlobalNS,fLocalNS);
418 PyRunString("model.eval()",fGlobalNS,fLocalNS);
419
420 //Building dummy inputs for the model
 // One torch.rand tensor per entry of inputShapes, built one dimension at a time
421 PyRunString("dummyInputs=[]",fGlobalNS,fLocalNS);
422 for(long unsigned int it=0;it<inputShapes.size();++it){
423 PyRunString("inputShape=[]",fGlobalNS,fLocalNS);
424 for(long unsigned int itr=0;itr<inputShapes[it].size();++itr){
425 PyRunString(TString::Format("inputShape.append(%d)",(int)inputShapes[it][itr]),fGlobalNS,fLocalNS);
426 }
427 PyRunString("dummyInputs.append(torch.rand(*inputShape))",fGlobalNS,fLocalNS);
428 }
429
430
431 //Getting the ONNX graph from model using the dummy inputs (graph[0] is the graph, graph[1] the weights dictionary)
432 //PyRunString("_set_onnx_shape_inference(True)",fGlobalNS,fLocalNS);
433 PyRunString("graph=_model_to_graph(model,dummyInputs)",fGlobalNS,fLocalNS);
434
435
436 //Extracting the model information in list modelData
437 PyRunString("modelData=[]",fGlobalNS,fLocalNS);
438 // The '_node_get' helper function is used to avoid dependency on onnx submodule
439 // (for the subscript operator of torch._C.Node), as done in https://github.com/pytorch/pytorch/pull/82628
440 PyRunString("def _node_get(node, key):\n"
441 " sel = node.kindOf(key)\n"
442 " return getattr(node, sel)(key)\n",
443 fGlobalNS, fLocalNS);
 // Build one dictionary per graph node with the keys the INTERNAL::MakePyTorch* factories read
444 PyRunString("for i in graph[0].nodes():\n"
445 " globals().update(locals())\n"
446 " nodeData={}\n"
447 " nodeData['nodeType']=i.kind()\n"
448 " nodeAttributeNames=[x for x in i.attributeNames()]\n"
449 " nodeAttributes={j: _node_get(i, j) for j in nodeAttributeNames}\n"
450 " nodeData['nodeAttributes']=nodeAttributes\n"
451 " nodeInputs=[x for x in i.inputs()]\n"
452 " nodeInputNames=[x.debugName() for x in nodeInputs]\n"
453 " nodeData['nodeInputs']=nodeInputNames\n"
454 " nodeOutputs=[x for x in i.outputs()]\n"
455 " nodeOutputNames=[x.debugName() for x in nodeOutputs]\n"
456 " nodeData['nodeOutputs']=nodeOutputNames\n"
457 " nodeDType=[x.type().scalarType() for x in nodeOutputs]\n"
458 " nodeData['nodeDType']=nodeDType\n"
459 " modelData.append(nodeData)",
460 fGlobalNS, fLocalNS);
461
462 PyObject* fPModel = PyDict_GetItemString(fLocalNS,"modelData");
464 PyObject *fNode;
465 std::string fNodeType;
466
467 //Adding operators into the RModel object
 // NOTE(review): listing lines 468-469 (presumably the loop over fPModel that sets fNode)
 // are not visible in this listing.
470 fNodeType = PyStringAsString(PyDict_GetItemString(fNode,"nodeType"));
471
472 // Adding required routines for inference code generation
473 if(fNodeType == "onnx::Gemm"){
474 rmodel.AddBlasRoutines({"Gemm", "Gemv"});
475 }
476 else if(fNodeType == "onnx::Selu" || fNodeType == "onnx::Sigmoid"){
477 rmodel.AddNeededStdLib("cmath");
478 }
479 else if (fNodeType == "onnx::Conv") {
480 rmodel.AddBlasRoutines({"Gemm", "Axpy"});
481 }
482 rmodel.AddOperator(INTERNAL::MakePyTorchNode(fNode));
483 }
484
485
486 // Extracting model weights to add the initialized tensors to the RModel
487 //
488 // The weight shapes and the raw contiguous bytes are extracted on the Python side
489 // (NumPy's `shape` attribute and `tobytes()`), so the parser does not depend on the
490 // NumPy C API and hence is not tied to a specific NumPy ABI/version.
491 PyRunString("weightNames=[k for k in graph[1].keys()]",fGlobalNS,fLocalNS);
492 PyRunString("weightValues=[v.numpy() for v in graph[1].values()]",fGlobalNS,fLocalNS);
493 PyRunString("weightShapes=[list(v.shape) for v in weightValues]",fGlobalNS,fLocalNS);
494 PyRunString("weightBytes=[v.tobytes() for v in weightValues]",fGlobalNS,fLocalNS);
 // The [6:-6] slice presumably strips the 'torch.' prefix and 'Tensor' suffix of type
 // strings such as 'torch.FloatTensor', leaving e.g. 'Float' — TODO confirm.
495 PyRunString("weightDTypes=[v.type()[6:-6] for v in graph[1].values()]",fGlobalNS,fLocalNS);
496 PyObject* fPWeightNames = PyDict_GetItemString(fLocalNS,"weightNames");
497 PyObject* fPWeightShapes = PyDict_GetItemString(fLocalNS,"weightShapes");
498 PyObject* fPWeightBytes = PyDict_GetItemString(fLocalNS,"weightBytes");
499 PyObject* fPWeightDTypes = PyDict_GetItemString(fLocalNS,"weightDTypes");
500 std::string fWeightName;
502 std::vector<std::size_t> fWeightShape;
503 std::size_t fWeightSize;
504
 // NOTE(review): listing lines 505-507 and 510 (the loop over the weights, and the
 // lines setting fWeightName, fWeightDType and fShapeList) are not visible in this listing.
508 fWeightSize = 1;
509 fWeightShape.clear();
 // Shape and element count of the current weight, from its shape list
511 for(Py_ssize_t j=0; j<PyList_Size(fShapeList); ++j){
512 std::size_t dim = (std::size_t)PyLong_AsLong(PyList_GetItem(fShapeList,j));
513 fWeightShape.push_back(dim);
514 fWeightSize*=dim;
515 }
516 switch(fWeightDType){
517 case ETensorType::FLOAT:{
 // fWeightValue is set on a line not visible in this listing. The copy assumes it
 // points at exactly fWeightSize float32 values.
519 std::shared_ptr<void> fData(malloc(fWeightSize * sizeof(float)), free);
520 std::memcpy(fData.get(),fWeightValue,fWeightSize * sizeof(float));
521 rmodel.AddInitializedTensor(fWeightName, ETensorType::FLOAT,fWeightShape,fData);
522 break;
523 }
524 default:
 // NOTE(review): no space separates "data type" from the appended type name here.
525 throw std::runtime_error("Type error: TMVA SOFIE does not yet supports weights of data type"+ConvertTypeToString(fWeightDType));
526 }
527 }
528
529
530 //Extracting Input tensor info
 // Input names come from the TorchScript module graph (model.graph); the first input
 // is dropped, presumably the module's `self` argument.
531 PyRunString("inputs=[x for x in model.graph.inputs()]",fGlobalNS,fLocalNS);
532 PyRunString("inputs=inputs[1:]",fGlobalNS,fLocalNS);
533 PyRunString("inputNames=[x.debugName() for x in inputs]",fGlobalNS,fLocalNS);
534 PyObject* fPInputs= PyDict_GetItemString(fLocalNS,"inputNames");
535 std::string fInputName;
536 std::vector<size_t>fInputShape;
 // NOTE(review): listing lines 537-538 and 541 (the loop header and the line setting
 // fInputDType) are not visible in this listing. inputShapes is indexed by the position
 // in inputNames, and no visible check ensures the caller passed one shape per input.
539 fInputName = PyStringAsString(PyList_GetItem(fPInputs,inputIter));
540 fInputShape = inputShapes[inputIter];
542 switch(fInputDType){
543 case(ETensorType::FLOAT): {
544 rmodel.AddInputTensorInfo(fInputName, ETensorType::FLOAT, fInputShape);
545 rmodel.AddInputTensorName(fInputName);
546 break;
547 }
548 default:
549 throw std::runtime_error("Type Error: TMVA SOFIE does not yet support the input tensor data type"+ConvertTypeToString(fInputDType));
550 }
551 }
552
553
554 //Extracting output tensor names
 // NOTE(review): listing line 559 (presumably the loop over fPOutputs) is not visible in this listing.
555 PyRunString("outputs=[x for x in graph[0].outputs()]",fGlobalNS,fLocalNS);
556 PyRunString("outputNames=[x.debugName() for x in outputs]",fGlobalNS,fLocalNS);
557 PyObject* fPOutputs= PyDict_GetItemString(fLocalNS,"outputNames");
558 std::vector<std::string> fOutputNames;
560 fOutputNames.push_back(PyStringAsString(PyList_GetItem(fPOutputs,outputIter)));
561 }
562 rmodel.AddOutputTensorNameList(fOutputNames);
563
564 return rmodel;
565}
566
567//////////////////////////////////////////////////////////////////////////////////
568/// \param[in] filepath file location of PyTorch .pt model
569/// \param[in] inputShapes vector of input shape vectors
570/// \return Parsed RModel object
571///
572/// Overloaded Parser function for translating PyTorch .pt model to RModel object.
573/// Function only requires the inputShapes vector as a parameter. Function
574/// builds the vector of Data-types for the input tensors using Float as default,
575/// Function calls the `Parse()` function with the vector of data-types included,
576/// subsequently returning the parsed RModel object.
577RModel Parse(std::string filepath,std::vector<std::vector<size_t>> inputShapes){
578 std::vector<ETensorType> dtype(inputShapes.size(),ETensorType::FLOAT);
580}
581
582} // namespace TMVA::Experimental::SOFIE::PyTorch
int Py_ssize_t
Definition CPyCppyy.h:215
#define PyBytes_AsString
Definition CPyCppyy.h:64
int main()
Definition Prototype.cxx:12
_object PyObject
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
#define malloc
Definition civetweb.c:1575
const_iterator end() const
Basic string class.
Definition TString.h:138
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2385
std::unique_ptr< ROperator > MakePyTorchGemm(PyObject *fNode)
Prepares a ROperator_Gemm object.
std::unique_ptr< ROperator > MakePyTorchNode(PyObject *fNode)
Prepares equivalent ROperator with respect to PyTorch ONNX node.
std::unique_ptr< ROperator > MakePyTorchConv(PyObject *fNode)
Prepares a ROperator_Conv object.
std::unique_ptr< ROperator > MakePyTorchSigmoid(PyObject *fNode)
Prepares a ROperator_Sigmoid object.
std::unique_ptr< ROperator > MakePyTorchSelu(PyObject *fNode)
Prepares a ROperator_Selu object.
std::unique_ptr< ROperator > MakePyTorchRelu(PyObject *fNode)
Prepares a ROperator_Relu object.
std::unordered_map< std::string, std::unique_ptr< ROperator >(*)(PyObject *fNode)> PyTorchMethodMap
std::unique_ptr< ROperator > MakePyTorchTranspose(PyObject *fNode)
Prepares a ROperator_Transpose object.
RModel Parse(std::string filepath, std::vector< std::vector< size_t > > inputShapes, std::vector< ETensorType > dtype)
Parser function for translating PyTorch .pt model into a RModel object.
std::string ConvertTypeToString(ETensorType type)
ETensorType ConvertStringToType(std::string type)