Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_ONNX.cxx
Go to the documentation of this file.
#include "onnx_proto3.pb.h"

#include "TMVA/SOFIE_common.hxx"

#include <cassert>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
12
13namespace TMVA {
14namespace Experimental {
15namespace SOFIE {
16
17// Declaration of operators
18// Unary operators
23// Binary operators
29// Nary operators
34// Reduce operators
38// Others
// Declaration of fused operators
67
68// Definition of RModelParser_ONNX::OperatorsMap
70 // Registered operators
71 std::unordered_map<std::string, ParserFuncSignature> fOperatorsMap;
72};
73
74// Constructor of the parser
75RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_unique<OperatorsMapImpl>()) {
76 // Register operators
77 // Unary operators
79 RegisterOperator("Reciprocal", ParseReciprocal);
82 // Binary operators
88 // Nary operators
93 // Reduce operators
94 RegisterOperator("ReduceMean", ParseReduceMean);
95 RegisterOperator("ReduceSumsquare", ParseReduceSumsquare);
96 RegisterOperator("ReduceProd", ParseReduceProd);
97 // Others
98 RegisterOperator("BatchNormalization", ParseBatchNormalization);
100 RegisterOperator("Concat", ParseConcat);
102 RegisterOperator("ConvTranspose", ParseConvTranspose);
105 RegisterOperator("Identity", ParseIdentity);
106 RegisterOperator("LeakyRelu", ParseLeakyRelu);
108 RegisterOperator("AveragePool", ParsePool);
109 RegisterOperator("GlobalAveragePool", ParsePool);
110 RegisterOperator("MaxPool", ParsePool);
112 RegisterOperator("Reshape", ParseReshape);
113 RegisterOperator("Flatten", ParseReshape);
114 RegisterOperator("Squeeze", ParseReshape);
115 RegisterOperator("Unsqueeze", ParseReshape);
119 RegisterOperator("Sigmoid", ParseSigmoid);
121 RegisterOperator("Softmax", ParseSoftmax);
123 RegisterOperator("Softmax", ParseSoftmax);
125 RegisterOperator("Transpose", ParseTranspose);
126 RegisterOperator("LayerNormalization", ParseLayerNormalization);
127 RegisterOperator("Expand", ParseExpand);
128 RegisterOperator("Gather", ParseGather);
129}
130
131// Destructor of the parser
133
135{
136 fOperatorsMapImpl->fOperatorsMap[name] = func;
137}
138
140{
141 return fOperatorsMapImpl->fOperatorsMap.find(name) != fOperatorsMapImpl->fOperatorsMap.end();
142}
143
145{
146 std::vector<std::string> ops;
147 ops.reserve(fOperatorsMapImpl->fOperatorsMap.size());
148 for (auto &it : fOperatorsMapImpl->fOperatorsMap) {
149 ops.emplace_back(it.first);
150 }
151 return ops;
152}
153
155{
157}
158
160{
162}
163
165{
167}
168
169// Parse an operator
170std::unique_ptr<ROperator>
171RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphproto, const std::vector<size_t> &nodes)
172{
173 int idx = (nodes.size() > i) ? nodes[i] : (int)i;
174 const auto &nodeproto = graphproto.node(idx);
175 const std::string op_type = nodeproto.op_type();
176 if (fVerbose)
177 std::cout << "Parsing an operator " << op_type << std::endl;
178
179
180 if (op_type == "MatMul") {
181 // Fuse MatMul and Add
182 int idx2 = (nodes.size() > i + 1) ? nodes[i + 1] : (int)i + 1;
183 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") {
184 return ParseFuseMatMulAdd(*this, graphproto.node(idx), graphproto.node(idx2));
185 }
186 } else if (nodeproto.op_type() == "Conv" || nodeproto.op_type() == "ConvTranspose") {
187 // Fuse Conv or ConvTranspose without bias and Add
188 int j = (nodes.size() > i + 1) ? nodes[i + 1] : (int)i + 1;
189 if (j < graphproto.node_size() && graphproto.node(j).op_type() == "Add") {
190 if (nodeproto.op_type() == "Conv") {
191 return ParseFuseConvAdd(*this, graphproto.node(idx), graphproto.node(j));
192 } else {
193 return ParseFuseConvTransposeAdd(*this, graphproto.node(idx), graphproto.node(j));
194 }
195 }
196 }
197
198 // skip then the following Add
199 if (idx > 0 && op_type == "Add") {
200 int idx0 = (nodes.size() > i) ? nodes[i - 1] : (int)i - 1;
201 if (graphproto.node(idx0).op_type() == "MatMul")
202 return nullptr;
203 else if (graphproto.node(idx0).op_type() == "ConvTranspose")
204 return nullptr;
205 }
206
207 auto it = fOperatorsMapImpl->fOperatorsMap.find(op_type);
208 if (it == fOperatorsMapImpl->fOperatorsMap.end()) {
209 throw std::runtime_error("TMVA::SOFIE Operator type " + op_type + " is not yet supported");
210 }
211 if (fVerbose) {
212 std::cout << "\tCreating operator " << op_type << std::endl;
213 }
214 return it->second(*this, nodeproto);
215}
216
217// Parse a model
218RModel RModelParser_ONNX::Parse(std::string filename, bool verbose)
219{
220 fVerbose = verbose;
221 char sep = '/';
222#ifdef _WIN32
223 sep = '\\';
224#endif
225 size_t isep = filename.rfind(sep, filename.length());
226 std::string filename_nodir = filename;
227 if (isep != std::string::npos) {
228 filename_nodir = (filename.substr(isep + 1, filename.length() - isep));
229 }
230
231 std::time_t ttime = std::time(0);
232 std::tm *gmt_time = std::gmtime(&ttime);
233 std::string parsetime(std::asctime(gmt_time));
234
235 GOOGLE_PROTOBUF_VERIFY_VERSION;
236 // model I/O
237 onnx::ModelProto model;
238 RModel rmodel(filename_nodir, parsetime);
239
240 fTensorTypeMap.clear();
241
242 std::fstream input(filename, std::ios::in | std::ios::binary);
243 if (!model.ParseFromIstream(&input)) {
244 throw std::runtime_error("TMVA::SOFIE - Failed to parse onnx file " + filename);
245 }
246
247 const onnx::GraphProto &graph = model.graph(); // not a memory leak. model freed automatically at the end.
248 google::protobuf::ShutdownProtobufLibrary();
249
250 // ONNX version is ir_version() - model_version() returns 0
251 if (fVerbose) {
252 std::cout << "ONNX Version " << model.ir_version() << std::endl;
253 }
254
255 std::unordered_set<std::string> initializer_names;
256 for (int i = 0; i < graph.initializer_size(); i++) {
257 initializer_names.insert(graph.initializer(i).name());
258 }
259
260 if (verbose)
261 std::cout << "Parsing model inputs...." << std::endl;
262 /// Loop on model inputs
263 for (int i = 0; i < graph.input_size(); i++) {
264 RegisterTensorType(graph.input(i).name(),
265 static_cast<ETensorType>(graph.input(i).type().tensor_type().elem_type()));
266
267 if (verbose)
268 std::cout << "\tgraph input " << i << " name " << graph.input(i).name() << " type "
269 << graph.input(i).type().tensor_type().elem_type() << std::endl;
270
271 if (initializer_names.find(graph.input(i).name()) != initializer_names.end())
272 continue;
273
274 // input data node is not a weight node (has no initializer)
275 const onnx::ValueInfoProto &valueinfoproto = graph.input(i);
276 std::string input_name = valueinfoproto.name();
277
278 ETensorType type = static_cast<ETensorType>(valueinfoproto.type().tensor_type().elem_type());
280 throw std::runtime_error("TMVA::SOFIE Data type in input tensor " + input_name + " not supported!\n");
281 }
282
283 std::vector<Dim> fShape;
284 bool existParam = false;
285 if (!valueinfoproto.type().tensor_type().has_shape())
286 throw std::runtime_error("TMVA::SOFIE datanode with no shape restrictions is not supported yet");
287 for (int j = 0; j < valueinfoproto.type().tensor_type().shape().dim_size(); j++) {
288 Dim dim;
289 if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
290 onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
291 dim.dim = valueinfoproto.type().tensor_type().shape().dim(j).dim_value();
292 } else if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
293 onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
294 dim.isParam = true;
295 existParam = true;
296 dim.param = valueinfoproto.type().tensor_type().shape().dim(j).dim_param();
297 } else {
298 throw std::runtime_error("TMVA::SOFIE ONNX file error: Valueinfoproto " + input_name +
299 " has neither dim_value nor dim_param! \n");
300 }
301 fShape.push_back(dim);
302 }
303 if (valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
304 Dim dim;
305 dim.dim = 1;
306 fShape.push_back(dim);
307 } // in case this TensorShapeProto has no dimension message: ONNX IR defines this to be a scalar
308
309 if (!existParam) {
310 std::vector<size_t> fShape_sizet;
311 for (auto &j : fShape) {
312 fShape_sizet.push_back(j.dim);
313 }
314
315 rmodel.AddInputTensorInfo(input_name, type, fShape_sizet);
316 } else {
317 rmodel.AddInputTensorInfo(input_name, type, fShape);
318 }
319 rmodel.AddInputTensorName(input_name); // store also names in given order
320 }
321
322 std::map<std::string, int> allInitializedTensors;
323
324 if (verbose)
325 std::cout << "\nParsing graph initializer list and fill model initialized tensors" << std::endl;
326
327 for (int i = 0; i < graph.initializer_size(); i++) {
328 onnx::TensorProto *tensorproto = const_cast<onnx::TensorProto *>(&graph.initializer(i));
329 std::vector<std::size_t> shape;
330 std::size_t fLength = 1;
331 for (int j = 0; j < tensorproto->dims_size(); j++) {
332 shape.push_back(tensorproto->dims(j));
333 fLength *= tensorproto->dims(j);
334 }
335 // in case of scalars keep an empty shape but with length =1
336
337 std::string input_name = graph.initializer(i).name();
338
339 if (verbose)
340 std::cout << "\t initializer " << i << " name " << input_name << " type " << graph.initializer(i).data_type()
341 << std::endl;
342
343 switch (static_cast<ETensorType>(graph.initializer(i).data_type())) {
344 case ETensorType::FLOAT: {
345 std::shared_ptr<void> data(malloc(fLength * sizeof(float)), free);
346
347 if (tensorproto->raw_data().empty() == false) {
348 auto raw_data_ptr = reinterpret_cast<float *>(const_cast<char *>(tensorproto->raw_data().c_str()));
349 std::memcpy(data.get(), raw_data_ptr, fLength * sizeof(float));
350 } else {
351 tensorproto->mutable_float_data()->ExtractSubrange(0, tensorproto->float_data_size(),
352 static_cast<float *>(data.get()));
353 }
354
355 if (verbose) std::cout << "add FLOAT initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
356 rmodel.AddInitializedTensor(input_name, ETensorType::FLOAT, shape, data);
357 allInitializedTensors[input_name] = i;
358 break;
359 }
360 case ETensorType::INT64: {
361 std::shared_ptr<void> data(malloc(fLength * sizeof(int64_t)), free);
362
363 if (tensorproto->raw_data().empty() == false) {
364 auto raw_data_ptr = reinterpret_cast<int64_t *>(const_cast<char *>(tensorproto->raw_data().c_str()));
365 std::memcpy(data.get(), raw_data_ptr, fLength * sizeof(int64_t));
366 } else {
367 tensorproto->mutable_int64_data()->ExtractSubrange(0, tensorproto->int64_data_size(),
368 static_cast<int64_t *>(data.get()));
369 }
370
371 if (verbose) std::cout << "add INT64 initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
372 rmodel.AddInitializedTensor(input_name, ETensorType::INT64, shape, data);
373 allInitializedTensors[input_name] = i;
374 break;
375 }
376 default:
377 throw std::runtime_error("Data type in weight tensor " + graph.initializer(i).name() + " not supported!\n");
378 }
379 }
380
381 // Initial operator order
382 if (verbose) {
383 std::cout << "\nGraph operator list (ONNX order)\n";
384 for (int i = 0; i < graph.node_size(); i++) {
385 std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).input_size()
386 << " inputs : {";
387 for (int j = 0; j < graph.node(i).input_size(); j++) {
388 std::cout << graph.node(i).input(j);
389 if (j < graph.node(i).input_size() - 1)
390 std::cout << ", ";
391 }
392 std::cout << " }" << std::endl;
393 }
394 }
395
396 // make order of nodes:
397 if (verbose)
398 std::cout << "\nRe-Order graph operator list\n";
399 std::vector<size_t> nodesOrder;
400 nodesOrder.reserve(graph.node_size());
401 std::vector<bool> foundNodes(graph.node_size());
402 // loop at graph inputs
403 std::map<std::string, int> allInputs;
404 for (int i = 0; i < graph.input_size(); i++) {
405 allInputs[graph.input(i).name()] = -1;
406 }
407 do {
408 auto psize = nodesOrder.size();
409 for (int i = 0; i < graph.node_size(); i++) {
410 if (foundNodes[i])
411 continue;
412 // check if all input exists add to list
413 bool existInputs = true;
414 int input_size = graph.node(i).input_size();
415 // special case for Reshape where shape is input and not a weight tensor
416 for (int j = 0; j < input_size; j++) {
417 std::string name = graph.node(i).input(j);
418 // skip empty names
419 if (!name.empty()) {
420 existInputs &= (allInputs.find(name) != allInputs.end() ||
421 allInitializedTensors.find(name) != allInitializedTensors.end());
422 if (fVerbose) {
423 std::cout << graph.node(i).op_type() << " input " << name << " "
424 << bool(allInputs.find(name) != allInputs.end()) << " " <<
425 bool(allInitializedTensors.find(name) != allInitializedTensors.end()) <<
426 existInputs << std::endl;
427 }
428 }
429 }
430 if (!existInputs) {
431 if (fVerbose) {
432 std::cout << "skip op " << graph.node(i).op_type() << " inputs are ";
433 for (int j = 0; j < input_size; j++) {
434 std::cout << graph.node(i).input(j) << " ";
435 }
436 std::cout << std::endl;
437 }
438 continue;
439 }
440 if (verbose)
441 std::cout << "\tadd node " << graph.node(i).op_type() << " order " << i << std::endl;
442
443 nodesOrder.push_back(i);
444 foundNodes[i] = true;
445 // register the outputs
446 for (int j = 0; j < graph.node(i).output_size(); j++) {
447 allInputs[graph.node(i).output(j)] = i;
448 }
449 }
450 // no increment in nodes - something wrong
451 if (nodesOrder.size() == psize) {
452 throw std::runtime_error("TMVA::SOFIE - cannot find a new node ");
453 }
454 } while ((int)nodesOrder.size() < graph.node_size());
455
456 // scan operators for orders
457 if (verbose) {
458 std::cout << "\nGraph operator list (re-ordered)\n";
459 for (int k = 0; k < graph.node_size(); k++) {
460 int i = nodesOrder[k];
461 std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).input_size()
462 << " inputs : {";
463 for (int j = 0; j < graph.node(i).input_size(); j++) {
464 std::cout << graph.node(i).input(j);
465 if (j < graph.node(i).input_size() - 1)
466 std::cout << ", ";
467 }
468 std::cout << " }" << std::endl;
469 }
470 }
471
472 // fill model with operators
473 if (verbose) {
474 std::cout << "Fill RModel with operators...\n";
475 }
476 for (int i = 0; i < graph.node_size(); i++) {
477 std::string op_type = graph.node(nodesOrder[i]).op_type();
478
479 if (verbose) {
480 std::cout << "\t" << i << " " << nodesOrder[i] << " parsing operator " << op_type << std::endl;
481 }
482
483 std::unique_ptr<ROperator> op = ParseOperator(i, graph, nodesOrder);
484 if (!op) {
485 if (verbose) {
486 std::cout << "\t\tskipping operator since it is fused with previous one" << std::endl;
487 }
488 // for skipping the fused nodes like Add after MatMul
489 continue;
490 }
491 rmodel.AddOperator(std::move(op));
492 }
493
494 std::vector<std::string> outputnames;
495 if (verbose)
496 std::cout << "\nParsing Graph output list\n";
497 for (int i = 0; i < graph.output_size(); i++) {
498 if (verbose)
499 std::cout << "\toutput " << i << " name " << graph.output(i).name() << std::endl;
500 outputnames.push_back(graph.output(i).name());
501 }
502 rmodel.AddOutputTensorNameList(outputnames);
503
504 return rmodel;
505}
506
507} // namespace SOFIE
508} // namespace Experimental
509} // namespace TMVA
Py_ssize_t * fShape
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char name[80]
Definition TGX11.cxx:110
#define free
Definition civetweb.c:1539
#define malloc
Definition civetweb.c:1536
void RegisterOperator(const std::string &name, ParserFuncSignature func)
bool IsRegisteredOperator(const std::string &name)
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string filename, bool verbose=false)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
void AddOutputTensorNameList(std::vector< std::string > outputtensornames)
Definition RModel.cxx:169
void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector< Dim > shape)
Definition RModel.cxx:108
void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:144
void AddInputTensorName(std::string name)
Definition RModel.cxx:127
void AddOperator(std::unique_ptr< ROperator > op, int order_execution=-1)
Definition RModel.cxx:131
std::string Clean_name(std::string input_tensor_name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuncSignature ParseSqrt
ParserFuncSignature ParseBatchNormalization
ParserFuncSignature ParseReshape
ParserFuseFuncSignature ParseFuseConvTransposeAdd
ParserFuncSignature ParseReduceMean
ParserFuseFuncSignature ParseFuseMatMulAdd
ParserFuncSignature ParseGather
ParserFuncSignature ParseNeg
ParserFuncSignature ParseLeakyRelu
ParserFuncSignature ParseExp
ParserFuncSignature ParsePool
Definition ParsePool.cxx:9
ParserFuncSignature ParseDiv
ParserFuncSignature ParseLayerNormalization
ParserFuncSignature ParseConcat
ParserFuncSignature ParseMax
ParserFuncSignature ParseIdentity
ParserFuncSignature ParseConvTranspose
ParserFuncSignature ParseReduceProd
ParserFuncSignature ParseSlice
Definition ParseSlice.cxx:9
ParserFuncSignature ParseTranspose
ParserFuncSignature ParseShape
Definition ParseShape.cxx:9
ParserFuncSignature ParseGRU
Definition ParseGRU.cxx:9
ParserFuncSignature ParseSub
ParserFuncSignature ParseReduceSumsquare
ParserFuncSignature ParseAdd
ParserFuncSignature ParseExpand
ParserFuncSignature ParseRNN
Definition ParseRNN.cxx:9
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseLSTM
Definition ParseLSTM.cxx:9
ParserFuncSignature ParseCast
Definition ParseCast.cxx:9
ParserFuncSignature ParseReciprocal
std::string ConvertShapeToString(std::vector< size_t > shape)
ParserFuncSignature ParseSigmoid
ParserFuseFuncSignature ParseFuseConvAdd
ParserFuncSignature ParseSoftmax
ParserFuncSignature ParseMean
ParserFuncSignature ParseSelu
Definition ParseSelu.cxx:9
ParserFuncSignature ParseSum
ParserFuncSignature ParseMin
ParserFuncSignature ParseRelu
Definition ParseRelu.cxx:9
ParserFuncSignature ParseConv
Definition ParseConv.cxx:9
ParserFuncSignature ParseGemm
Definition ParseGemm.cxx:9
ParserFuncSignature ParseMul
ParserFuncSignature ParsePow
ParserFuncSignature ParseTanh
Definition ParseTanh.cxx:9
create variable transformations
Definition graph.py:1
std::unordered_map< std::string, ParserFuncSignature > fOperatorsMap