Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_ONNX.cxx
Go to the documentation of this file.
1#include "Byteswap.h"
3#include "onnx_proto3.pb.h"
4
5#include <stdexcept>
6#include <string>
7#include <memory>
8#include <cassert>
9#include <iostream>
10#include <unordered_map>
11#include <functional>
12#include "TMVA/SOFIE_common.hxx"
13
14namespace TMVA {
15namespace Experimental {
16namespace SOFIE {
17
18// Declaration of operators
19// Unary operators
29// Binary operators
36// Nary operators
// Comparison Operators
47//Is Operators
51// Reduce operators
56// Others
98// Declaration of fused operators
104
105// Definition of RModelParser_ONNX::OperatorsMap
107 // Registered operators
108 std::unordered_map<std::string, ParserFuncSignature> fOperatorsMap;
109};
110
// helper function to get initialized tensor data
// NOTE(review): the body of the primary template ExtractDataFromTP<T> was lost
// in extraction -- only the trailing brace survives here. Confirm against the
// original file.
template<typename T>
};
115// trait function to extract data from TensorProto
116template<>
117struct ExtractDataFromTP<float> {
118 static void Copy(onnx::TensorProto * tensor, void * data) {
119 tensor->mutable_float_data()->ExtractSubrange(0, tensor->float_data_size(),
120 static_cast<float *>(data));
121 }
122};
123template<>
125 static void Copy(onnx::TensorProto * tensor, void * data) {
126 tensor->mutable_double_data()->ExtractSubrange(0, tensor->double_data_size(),
127 static_cast<double *>(data));
128 }
129};
130template<>
131struct ExtractDataFromTP<int32_t> {
132 static void Copy(onnx::TensorProto * tensor, void * data) {
133 tensor->mutable_int32_data()->ExtractSubrange(0, tensor->int32_data_size(),
134 static_cast<int32_t *>(data));
135 }
136};
137template<>
138struct ExtractDataFromTP<int64_t> {
139 static void Copy(onnx::TensorProto * tensor, void * data) {
140 tensor->mutable_int64_data()->ExtractSubrange(0, tensor->int64_data_size(),
141 static_cast<int64_t *>(data));
142 }
143};
144template<typename T>
145std::shared_ptr<void> GetInitializedTensorData(onnx::TensorProto * tensorproto, size_t length) {
146 std::shared_ptr<void> data(malloc(length * sizeof(T)), free);
147
148 if (!tensorproto->raw_data().empty()) {
149#ifdef R__BYTESWAP
150 std::memcpy(data.get(), tensorproto->raw_data().c_str(), length * sizeof(T));
151#else
152 for (std::size_t k = 0; k < length; ++k)
153 (reinterpret_cast<typename RByteSwap<sizeof(T)>::value_type *>(data.get()))[k] =
154 RByteSwap<sizeof(T)>::bswap((reinterpret_cast<const typename RByteSwap<sizeof(T)>::value_type *>(tensorproto->raw_data().c_str()))[k]);
155#endif
156 } else {
158 }
159 return data;
160}
161
// Constructor of the parser
// Fills the operator-name -> parser-function dispatch table consulted by
// ParseOperator. NOTE(review): several RegisterOperator calls present in the
// original file were lost in extraction (the source line numbering jumps);
// the surviving calls are kept verbatim -- confirm the full list against the
// original file.
RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_unique<OperatorsMapImpl>()) {
   // Register operators
   // Unary operators
   RegisterOperator("Reciprocal", ParseReciprocal);
   RegisterOperator("Softplus", ParseSoftplus);
   // Binary operators
   // Nary operators
   // Comparison Operators
   RegisterOperator("Equal", ParseEq);
   RegisterOperator("LessOrEqual", ParseLessEq);
   RegisterOperator("Greater", ParseGreater);
   RegisterOperator("GreaterOrEqual", ParseGreaterEq);
   // Is If operators
   // Reduce operators
   RegisterOperator("ReduceMean", ParseReduceMean);
   RegisterOperator("ReduceSum", ParseReduceSum);
   RegisterOperator("ReduceSumSquare", ParseReduceSumSquare);
   RegisterOperator("ReduceProd", ParseReduceProd);
   // Others
   RegisterOperator("BatchNormalization", ParseBatchNormalization);
   RegisterOperator("Constant", ParseConstant);
   // ConstantOfShape shares the Constant parser
   RegisterOperator("ConstantOfShape", ParseConstant);
   RegisterOperator("Concat", ParseConcat);
   RegisterOperator("ConvTranspose", ParseConvTranspose);
   RegisterOperator("Identity", ParseIdentity);
   RegisterOperator("LeakyRelu", ParseLeakyRelu);
   // all pooling variants share a single parser
   RegisterOperator("AveragePool", ParsePool);
   RegisterOperator("GlobalAveragePool", ParsePool);
   RegisterOperator("MaxPool", ParsePool);
   // shape-manipulation operators share the Reshape parser
   RegisterOperator("Reshape", ParseReshape);
   RegisterOperator("Flatten", ParseReshape);
   RegisterOperator("Squeeze", ParseReshape);
   RegisterOperator("Unsqueeze", ParseReshape);
   RegisterOperator("Sigmoid", ParseSigmoid);
   // Softmax and LogSoftmax share one parser
   RegisterOperator("Softmax", ParseSoftmax);
   RegisterOperator("LogSoftmax", ParseSoftmax);
   RegisterOperator("Transpose", ParseTranspose);
   RegisterOperator("MatMul", ParseMatMul);
   RegisterOperator("LayerNormalization", ParseLayerNormalization);
   RegisterOperator("Expand", ParseExpand);
   RegisterOperator("Gather", ParseGather);
   RegisterOperator("GatherND", ParseGatherND);
   RegisterOperator("EyeLike", ParseEyeLike);
   RegisterOperator("Einsum", ParseEinsum);
   // the four Random* variants share a single parser
   RegisterOperator("RandomNormal", ParseRandom);
   RegisterOperator("RandomNormalLike", ParseRandom);
   RegisterOperator("RandomUniform", ParseRandom);
   RegisterOperator("RandomUniformLike", ParseRandom);
   RegisterOperator("ScatterElements", ParseScatterElements);
   RegisterOperator("NonZero", ParseNonZero);
}
255
256// Destructor of the parser
258
260{
261 fOperatorsMapImpl->fOperatorsMap[name] = func;
262}
263
265{
266 return fOperatorsMapImpl->fOperatorsMap.find(name) != fOperatorsMapImpl->fOperatorsMap.end();
267}
268
270{
271 std::vector<std::string> ops;
272 ops.reserve(fOperatorsMapImpl->fOperatorsMap.size());
273 for (auto &it : fOperatorsMapImpl->fOperatorsMap) {
274 ops.emplace_back(it.first);
275 }
276 // return sorted list in alphabetical order
277 std::sort(ops.begin(), ops.end());
278 return ops;
279}
280
// NOTE(review): the definitions of RegisterTensorType and GetTensorType that
// belong here (per the class declaration) were lost in extraction -- only
// brace fragments survive below. Confirm against the original file.

{
}

// Parse an operator
// Build the ROperator for the i-th entry of the topologically ordered node
// list `nodes` of `graphproto`. When the node has exactly one child, known
// pairs are fused (MatMul+Add, Conv/ConvTranspose+Add, Gemm+Relu,
// BatchNormalization+Relu); the consumed child is flagged in fFusedOperators
// so it is skipped on its own turn. Returns nullptr for a node that was
// already absorbed by an earlier fusion. Throws if `i` is out of range or the
// operator type has no registered parser.
std::unique_ptr<ROperator>
RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphproto, const std::vector<size_t> &nodes, const std::vector<int> & children)
{
   if (i >= nodes.size())
      throw std::runtime_error("TMVA::SOFIE - Error in parsing ordered operators " + std::to_string(i) + " is >= " + std::to_string(nodes.size()));
   int idx = nodes[i];
   const auto &nodeproto = graphproto.node(idx);
   const std::string op_type = nodeproto.op_type();
   if (fVerbose)
      std::cout << "Parsing operator " << op_type << std::endl;

   // skip already fused operators
   if (fFusedOperators[idx]) return nullptr;

   // try to fuse with following operator in case it is not last one
   if (children.size() == 1) {
      int idx2 = children.front();
      if (op_type == "MatMul") {
         // Fuse MatMul and Add
         if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") {
            fFusedOperators[idx2] = true;
            return ParseFuseMatMulAdd(*this, graphproto.node(idx), graphproto.node(idx2));
         }
         else {
            return ParseMatMul(*this, graphproto.node(idx));
         }
      } else if (nodeproto.op_type() == "Conv" || nodeproto.op_type() == "ConvTranspose") {
         // Fuse Conv or ConvTranspose without bias and Add
         if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") {
            if (nodeproto.op_type() == "Conv") {
               fFusedOperators[idx2] = true;
               return ParseFuseConvAdd(*this, graphproto.node(idx), graphproto.node(idx2));
            } else {
               fFusedOperators[idx2] = true;
               return ParseFuseConvTransposeAdd(*this, graphproto.node(idx), graphproto.node(idx2));
            }
         }
      } else if (nodeproto.op_type() == "Gemm") {
         // Fuse Gemm with activation operators
         if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") {
            fFusedOperators[idx2] = true;
            return ParseFuseGemmRelu(*this, graphproto.node(idx), graphproto.node(idx2));
         }
      } else if (nodeproto.op_type() == "BatchNormalization") {
         if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") {
            fFusedOperators[idx2] = true;
            return ParseFuseBatchnormRelu(*this, graphproto.node(idx), graphproto.node(idx2));
         }
      }
   }

   // no fusion applied: dispatch to the registered single-node parser
   auto it = fOperatorsMapImpl->fOperatorsMap.find(op_type);
   if (it == fOperatorsMapImpl->fOperatorsMap.end()) {
      std::cout << "operator " << op_type << " is not supported" << std::endl;
      throw std::runtime_error("TMVA::SOFIE Operator type " + op_type + " is not yet supported");
   }
   if (fVerbose) {
      std::cout << "\tCreating operator " << op_type << std::endl;
   }
   return it->second(*this, nodeproto);
}
360
// Parse a model
// Load an ONNX model from `filename` and convert it into an RModel whose name
// is the file name stripped of its directory part; the parse timestamp (UTC)
// is recorded in the model.
RModel RModelParser_ONNX::Parse(std::string const &filename, bool verbose)
{
   fVerbose = verbose;

   fTensorTypeMap.clear();

   auto model = LoadModel(filename);
   if (!model)
      throw std::runtime_error("TMVA::SOFIE - Failed to load onnx file " + filename);

   const onnx::GraphProto &graph = model->graph(); // not a memory leak. model freed automatically at the end.


   std::time_t ttime = std::time(0);
   std::tm *gmt_time = std::gmtime(&ttime);
   std::string parsetime(std::asctime(gmt_time));

   // get name of model (filename without directory name)
   char sep = '/';
#ifdef _WIN32
   sep = '\\';
#endif
   size_t isep = filename.rfind(sep, filename.length());
   std::string filename_nodir = filename;
   if (isep != std::string::npos) {
      filename_nodir = (filename.substr(isep + 1, filename.length() - isep));
   }

   // NOTE(review): the lines constructing `rmodel` (presumably from
   // filename_nodir and parsetime) and the ParseONNXGraph call were lost in
   // extraction -- confirm against the original file.
   return rmodel;
}
394
// Parse an ONNX model read from `input` and convert it into an RModel named
// `name`; the parse timestamp (UTC) is recorded in the model.
RModel RModelParser_ONNX::Parse(std::istream &input, std::string const &name, bool verbose)
{
   fVerbose = verbose;

   fTensorTypeMap.clear();

   auto model = LoadModel(input);
   if (!model)
      throw std::runtime_error("TMVA::SOFIE - Failed to parse ONNX model from input stream");

   const onnx::GraphProto &graph = model->graph(); // not a memory leak. model freed automatically at the end.

   std::time_t ttime = std::time(0);
   std::tm *gmt_time = std::gmtime(&ttime);
   std::string parsetime(std::asctime(gmt_time));

   // NOTE(review): the line constructing `rmodel` (presumably from name and
   // parsetime) was lost in extraction -- confirm against the original file.
   ParseONNXGraph(rmodel, graph, name);
   return rmodel;
}
415
416std::unique_ptr<onnx::ModelProto> RModelParser_ONNX::LoadModel(const std::string &filename) {
417 std::fstream input(filename, std::ios::in | std::ios::binary);
418 if (!input) {
419 std::cerr << "TMVA::SOFIE - Failed to open onnx file " << filename << std::endl;
420 return {};
421 }
422
423 return LoadModel(input);
424}
425
// Deserialize an onnx::ModelProto from `input`. Prints an error and returns an
// empty pointer when protobuf parsing fails.
std::unique_ptr<onnx::ModelProto> RModelParser_ONNX::LoadModel(std::istream &input) {
   auto model = std::make_unique<onnx::ModelProto>();

   if (!model->ParseFromIstream(&input)) {
      std::cerr << "TMVA::SOFIE - Failed to parse ONNX model from input stream" << std::endl;
      return {};
   }

   // ONNX version is ir_version() - model_version() returns 0
   if (fVerbose) {
      std::cout << "ONNX Version " << model->ir_version() << std::endl;
   }
   // NOTE(review): ShutdownProtobufLibrary() tears down protobuf's global
   // state on every load while the returned ModelProto is still in use --
   // looks suspicious; confirm intent against the original file/protobuf docs.
   google::protobuf::ShutdownProtobufLibrary();
   return model;

}
443
// Walk `graph` (at nesting depth `level`) and record in `missingOperators`
// every operator type lacking a registered parser; recurses into sub-graphs
// held in node attributes, incrementing `level` for each.
void RModelParser_ONNX::CheckGraph(const onnx::GraphProto & graph, int & level, std::map<std::string, int> & missingOperators) {
   if (fVerbose)
      std::cout << "\n" << graph.name() << " Graph operator list\n";
   for (int i = 0; i < graph.node_size(); i++) {
      const auto & node = graph.node(i);
      const std::string opType = node.op_type();
      if (fVerbose) {
         std::cout << "\tOperator " << i << " : " << opType << " (" << node.name() << "), " << graph.node(i).input_size()
                   << " inputs : {";
         for (int j = 0; j < graph.node(i).input_size(); j++) {
            std::cout << graph.node(i).input(j);
            if (j < graph.node(i).input_size() - 1)
               std::cout << ", ";
         }
         std::cout << " }" << std::endl;
      }
      // check if operator exists
      // NOTE(review): the guard line (presumably
      // `if (!IsRegisteredOperator(opType))`) was lost in extraction; as
      // written every operator would be recorded as missing -- confirm
      // against the original file.
      missingOperators[opType] = level;
      // see if sub-graph exists as node attributes
      for (int j = 0; j < node.attribute_size(); j++) {
         const auto & attribute = node.attribute(j);
         if (attribute.has_g()) {
            const auto & subGraph = attribute.g();
            level += 1;
            // NOTE(review): the recursive CheckGraph(subGraph, level,
            // missingOperators) call was lost in extraction -- confirm
            // against the original file.
         }
      }
   }
}
474
475bool RModelParser_ONNX::CheckModel(std::string filename, bool verbose) {
476
477 fVerbose = verbose;
478 auto model = LoadModel(filename);
479 if (!model) return false;
480
481 const onnx::GraphProto &graph = model->graph();
482 // Initial operator order
483 if (fVerbose)
484 std::cout << "\nModel operator list " << model->producer_name() << "\n";
485
486 std::map<std::string, int> missingOperators;
487 int level = 1;
488 CheckGraph(graph, level, missingOperators);
489
490 if (!missingOperators.empty()) {
491 std::cout << "List of missing operators for model loaded from file " << filename << std::endl;
492 for (auto & op : missingOperators) {
493 std::cout << op.first << " " << op.second << std::endl;
494 }
495 return false;
496 }
497 std::cout << "All operators in the loaded model are supported!\n";
498 return true;
499}
500
// Convert an onnx::GraphProto into the given RModel:
//  1. register input tensors (with parametric shapes for symbolic/-1 dims),
//  2. copy all initializer (weight) tensors into the model,
//  3. topologically order the graph nodes by input availability,
//  4. compute each node's children (used for operator fusion),
//  5. parse every node via ParseOperator and add it to the model,
//  6. record the graph's output tensor names.
// NOTE(review): a few interior lines of this function were lost in extraction
// (flagged inline below) -- confirm the marked spots against the original file.
void RModelParser_ONNX::ParseONNXGraph(RModel & rmodel, const onnx::GraphProto & graph, std::string graphName)
{
   bool verbose = fVerbose;

   if (graphName.empty())
      graphName = graph.name();

   if (verbose)
      std::cout << "\nParsing Graph - " << graphName << std::endl;

   std::unordered_set<std::string> initializer_names;
   for (int i = 0; i < graph.initializer_size(); i++) {
      initializer_names.insert(graph.initializer(i).name());
   }

   if (verbose)
      std::cout << "Parsing model inputs...." << std::endl;
   /// Loop on model inputs
   for (int i = 0; i < graph.input_size(); i++) {
      RegisterTensorType(graph.input(i).name(),
                         static_cast<ETensorType>(graph.input(i).type().tensor_type().elem_type()));

      if (verbose)
         std::cout << "\tgraph input " << i << " name " << graph.input(i).name() << " type "
                   << graph.input(i).type().tensor_type().elem_type() << std::endl;

      // inputs that also appear as initializers are weights, not data inputs
      if (initializer_names.find(graph.input(i).name()) != initializer_names.end())
         continue;

      // input data node is not a weight node (has no initializer)
      const onnx::ValueInfoProto &valueinfoproto = graph.input(i);
      std::string input_name = valueinfoproto.name();

      ETensorType type = static_cast<ETensorType>(valueinfoproto.type().tensor_type().elem_type());

      std::vector<Dim> fShape;
      bool existParam = false;
      if (!valueinfoproto.type().tensor_type().has_shape())
         throw std::runtime_error("TMVA::SOFIE data node with no shape restrictions is not supported yet");
      for (int j = 0; j < valueinfoproto.type().tensor_type().shape().dim_size(); j++) {
         Dim dim;
         if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
             onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
            int dim_value = valueinfoproto.type().tensor_type().shape().dim(j).dim_value();
            dim.dim = dim_value;
            // case input dim is -1 - set a parametric shape
            if (dim_value < 0) {
               dim.isParam = true;
               existParam = true;
               dim.param = UTILITY::Clean_name(input_name) + "_size";
            }
         } else if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
                    onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
            // symbolic dimension: keep the ONNX-provided parameter name
            dim.isParam = true;
            existParam = true;
            dim.param = valueinfoproto.type().tensor_type().shape().dim(j).dim_param();
         } else {
            throw std::runtime_error("TMVA::SOFIE ONNX file error: Valueinfoproto " + input_name +
                                     " has neither dim_value nor dim_param! \n");
         }
         fShape.push_back(dim);
      }
      if (valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
         Dim dim;
         dim.dim = 1;
         fShape.push_back(dim);
      } // in case this TensorShapeProto has no dimension message: ONNX IR defines this to be a scalar

      if (!existParam) {
         // fully static shape: register with plain size_t dimensions
         std::vector<size_t> fShape_sizet;
         for (auto &j : fShape) {
            fShape_sizet.push_back(j.dim);
         }

         rmodel.AddInputTensorInfo(input_name, type, fShape_sizet);
      } else {
         rmodel.AddInputTensorInfo(input_name, type, fShape);
      }
      rmodel.AddInputTensorName(input_name); // store also names in given order
   }

   std::map<std::string, int> allInitializedTensors;

   if (verbose)
      std::cout << "\nParsing graph initializer list and fill model initialized tensors" << std::endl;

   for (int i = 0; i < graph.initializer_size(); i++) {
      onnx::TensorProto *tensorproto = const_cast<onnx::TensorProto *>(&graph.initializer(i));
      std::vector<std::size_t> shape;
      std::size_t fLength = 1;
      for (int j = 0; j < tensorproto->dims_size(); j++) {
         shape.push_back(tensorproto->dims(j));
         fLength *= tensorproto->dims(j);
      }
      // in case of scalars keep an empty shape but with length =1

      std::string input_name = graph.initializer(i).name();

      if (verbose)
         std::cout << "\t initializer " << i << " name " << input_name << " type " << graph.initializer(i).data_type()
                   << std::endl;

      // register also the initialized tensors
      auto tensor_type = static_cast<ETensorType>(graph.initializer(i).data_type());
      // NOTE(review): one line was lost here in extraction (possibly
      // recording the tensor in allInitializedTensors, which is otherwise
      // unused below) -- confirm against the original file. The same applies
      // after each AddInitializedTensor call in the switch below.

      switch (tensor_type) {
      case ETensorType::FLOAT: {
         std::shared_ptr<void> data = GetInitializedTensorData<float>(tensorproto, fLength);
         if (verbose) std::cout << "add FLOAT initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::FLOAT, shape, data);
         break;
      }
      case ETensorType::DOUBLE: {
         std::shared_ptr<void> data = GetInitializedTensorData<double>(tensorproto, fLength);
         if (verbose) std::cout << "add DOUBLE initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::DOUBLE, shape, data);
         break;
      }
      case ETensorType::INT32: {
         std::shared_ptr<void> data = GetInitializedTensorData<int32_t>(tensorproto, fLength);
         if (verbose) std::cout << "add INT32 initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::INT32, shape, data);
         break;
      }
      case ETensorType::INT64: {
         std::shared_ptr<void> data = GetInitializedTensorData<int64_t>(tensorproto, fLength);
         if (verbose) std::cout << "add INT64 initialized tensor " << input_name << " shape " << ConvertShapeToString(shape) << std::endl;
         rmodel.AddInitializedTensor(input_name, ETensorType::INT64, shape, data);
         break;
      }
      default:
         throw std::runtime_error("Data type in weight tensor " + graph.initializer(i).name() + " not supported!\n");
      }
   }

   // Initial operator order
   if (verbose) {
      std::cout << "\nGraph operator list (ONNX order)\n";
      for (int i = 0; i < graph.node_size(); i++) {
         std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).input_size()
                   << " inputs : {";
         for (int j = 0; j < graph.node(i).input_size(); j++) {
            std::cout << graph.node(i).input(j);
            if (j < graph.node(i).input_size() - 1)
               std::cout << ", ";
         }
         std::cout << " }" << std::endl;
      }
   }

   // make order of nodes:
   if (verbose)
      std::cout << "\n***********************\nRe-Order graph operator list\n*************************\n";
   std::vector<size_t> nodesOrder;
   nodesOrder.reserve(graph.node_size());
   std::vector<bool> foundNodes(graph.node_size());

   // loop at graph inputs
   std::map<std::string, int> allInputs;
   for (int i = 0; i < graph.input_size(); i++) {
      allInputs[graph.input(i).name()] = -1;
   }
   // repeatedly sweep the node list, appending every node whose inputs are
   // all available, until all nodes are ordered
   do {
      auto psize = nodesOrder.size();
      for (int i = 0; i < graph.node_size(); i++) {
         if (foundNodes[i])
            continue;
         // check if all input exists add to list
         bool existInputs = true;
         int input_size = graph.node(i).input_size();
         // special case for Reshape where shape is input and not a weight tensor
         if (fVerbose)
            std::cout << "Checking input of Node " << i << " : " << graph.node(i).name() << std::endl;
         for (int j = 0; j < input_size; j++) {
            std::string name = graph.node(i).input(j);
            // skip empty names
            if (!name.empty()) {
               // NOTE(review): the second half of this condition (likely a
               // lookup against the initialized/registered tensors) and part
               // of the verbose printout were lost in extraction -- confirm
               // against the original file.
               existInputs &= (allInputs.find(name) != allInputs.end() ||
               if (fVerbose) {
                  std::cout << "\t\t input " << name << " "
                            << bool(allInputs.find(name) != allInputs.end()) << " " <<
                     existInputs << std::endl;
               }
            }
         }
         if (!existInputs) {
            if (fVerbose) {
               std::cout << "skip node " << graph.node(i).op_type() << " " << graph.node(i).name() << " inputs are not existing ";
               for (int j = 0; j < input_size; j++) {
                  std::cout << graph.node(i).input(j) << " ";
               }
               std::cout << std::endl;
            }
            continue;
         }

         // adding node to the currently ordered list
         if (verbose)
            std::cout << "===> New node " << graph.node(i).op_type() << " " << graph.node(i).name() << " order " << i << std::endl;

         nodesOrder.push_back(i);
         foundNodes[i] = true;
         // register the outputs
         for (int j = 0; j < graph.node(i).output_size(); j++) {
            if (fVerbose) std::cout << "\toutput : " << graph.node(i).output(j) << std::endl;
            allInputs[graph.node(i).output(j)] = i;
         }
      }
      // no increment in nodes - something wrong
      if (nodesOrder.size() == psize) {
         int ilast = nodesOrder.back();
         std::cout << "cannot find a new node after " << graph.node(ilast).op_type() << " " << graph.node(ilast).name() << std::endl;
         throw std::runtime_error("TMVA::SOFIE - cannot find a new node ");
      }
   } while ((int)nodesOrder.size() < graph.node_size());


   // find list of children for each operator (used for fusing operators)
   std::vector<std::vector<int>> nodesChildren(graph.node_size());

   for (int k = 0; k < graph.node_size(); k++) {
      int i = nodesOrder[k];
      // compute the number of output for the operators
      if (graph.node(i).output_size() > 0) nodesChildren[i].reserve(graph.node(i).output_size());
      for (const auto& output_name : graph.node(i).output()) {
         // loop on all nodes
         for (int l = k; l < graph.node_size(); l++) {
            int j = nodesOrder[l];
            for (const auto& input_name : graph.node(j).input()) {
               if (input_name == output_name)
                  nodesChildren[i].push_back(j);
            }
         }
      }
   }

   // print list of ordered operators with list of inputs and list of children nodes
   if (verbose) {
      std::cout << "\nGraph operator list (re-ordered)\n";
      for (int k = 0; k < graph.node_size(); k++) {
         int i = nodesOrder[k];
         std::cout << "\tOperator " << i << " : " << graph.node(i).op_type() << " , " << graph.node(i).name() << " input tensors : {";
         for (int j = 0; j < graph.node(i).input_size(); j++) {
            std::cout << graph.node(i).input(j);
            if (j < graph.node(i).input_size() - 1)
               std::cout << ", ";
         }
         std::cout << " } ";
         std::cout << " children : {";
         for ( const auto & ichild : nodesChildren[i]) {
            std::cout << " [ " << ichild << " " << graph.node(ichild).op_type() << " , " << graph.node(ichild).name() << "]";
         }
         std::cout << "}" << std::endl;
      }
   }

   // fill model with operators
   if (verbose) {
      std::cout << "Fill RModel with operators...\n";
   }

   // we have to record order of node execution separately to
   // account for fused operators
   size_t node_order_exec = 0;
   fFusedOperators = std::vector<bool>(graph.node_size(), false);
   for (int i = 0; i < graph.node_size(); i++) {
      std::string op_type = graph.node(nodesOrder[i]).op_type();

      if (verbose) {
         std::cout << "\t" << i << " " << nodesOrder[i] << " parsing operator " << op_type << std::endl;
      }

      std::unique_ptr<ROperator> op = ParseOperator(i, graph, nodesOrder, nodesChildren[i]);
      if (!op) {
         if (verbose) {
            std::cout << "\t\tskipping operator since it is fused with previous one" << std::endl;
         }
         // for skipping the fused nodes like Add after MatMul
         continue;
      }
      rmodel.AddOperator(std::move(op), node_order_exec++);
   }

   std::vector<std::string> outputnames;
   if (verbose)
      std::cout << "\nParsing Graph output list\n";
   for (int i = 0; i < graph.output_size(); i++) {
      if (verbose)
         std::cout << "\toutput " << i << " name " << graph.output(i).name() << std::endl;
      outputnames.push_back(graph.output(i).name());
   }
   rmodel.AddOutputTensorNameList(outputnames);

   return;
}
803
804} // namespace SOFIE
805} // namespace Experimental
806} // namespace TMVA
dims_t fShape
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char name[80]
Definition TGX11.cxx:110
#define malloc
Definition civetweb.c:1575
const_iterator begin() const
const_iterator end() const
void RegisterOperator(const std::string &name, ParserFuncSignature func)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &, const std::vector< int > &)
bool IsRegisteredOperator(const std::string &name)
void CheckGraph(const onnx::GraphProto &g, int &level, std::map< std::string, int > &missingOperators)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string const &filename, bool verbose=false)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< onnx::ModelProto > LoadModel(const std::string &filename)
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
bool CheckModel(std::string filename, bool verbose=false)
std::string Clean_name(std::string input_tensor_name)
ParserFuncSignature ParseIsNaN
ParserFuncSignature ParseSqrt
ParserFuncSignature ParseBatchNormalization
ParserFuncSignature ParseGreater
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuncSignature ParseReshape
ParserFuseFuncSignature ParseFuseConvTransposeAdd
ParserFuncSignature ParseReduceMean
ParserFuseFuncSignature ParseFuseMatMulAdd
ParserFuncSignature ParseGather
ParserFuncSignature ParseNeg
ParserFuncSignature ParseWhere
Definition ParseWhere.cxx:9
ParserFuncSignature ParseCos
ParserFuncSignature ParseLog
ParserFuncSignature ParseLeakyRelu
ParserFuncSignature ParseExp
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseEinsum
ParserFuncSignature ParsePool
Definition ParsePool.cxx:9
ParserFuncSignature ParseDiv
ParserFuncSignature ParseLayerNormalization
ParserFuncSignature ParseConcat
ParserFuncSignature ParseTopK
Definition ParseTopK.cxx:9
ParserFuncSignature ParseMax
ParserFuncSignature ParseEq
ParserFuncSignature ParseIdentity
ParserFuncSignature ParseConvTranspose
ParserFuncSignature ParseReduceProd
ParserFuncSignature ParseNot
Definition ParseNot.cxx:9
ParserFuncSignature ParseSlice
Definition ParseSlice.cxx:9
ParserFuncSignature ParseRandom
ParserFuncSignature ParseTranspose
ParserFuncSignature ParseLess
ParserFuncSignature ParseShape
Definition ParseShape.cxx:9
ParserFuncSignature ParseGRU
Definition ParseGRU.cxx:9
ParserFuncSignature ParseMatMul
ParserFuncSignature ParseErf
Definition ParseErf.cxx:9
ParserFuncSignature ParseSub
ParserFuncSignature ParseAdd
ParserFuncSignature ParseNonZero
std::shared_ptr< void > GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t length)
ParserFuncSignature ParseIf
Definition ParseIf.cxx:9
ParserFuncSignature ParseRange
Definition ParseRange.cxx:9
ParserFuncSignature ParseSoftplus
ParserFuncSignature ParseExpand
ParserFuncSignature ParseRNN
Definition ParseRNN.cxx:9
ParserFuncSignature ParseLSTM
Definition ParseLSTM.cxx:9
ParserFuncSignature ParseCast
Definition ParseCast.cxx:9
ParserFuncSignature ParseReciprocal
ParserFuncSignature ParseSigmoid
ParserFuseFuncSignature ParseFuseConvAdd
ParserFuseFuncSignature ParseFuseBatchnormRelu
ParserFuncSignature ParseIsInf
ParserFuncSignature ParseSoftmax
ParserFuncSignature ParseGreaterEq
ParserFuncSignature ParseMod
ParserFuncSignature ParseMean
ParserFuncSignature ParseSplit
Definition ParseSplit.cxx:9
ParserFuncSignature ParseConstant
ParserFuncSignature ParseSelu
Definition ParseSelu.cxx:9
ParserFuncSignature ParseLessEq
ParserFuncSignature ParseGatherND
ParserFuncSignature ParseSum
ParserFuncSignature ParseEyeLike
ParserFuncSignature ParsePad
Definition ParsePad.cxx:9
ParserFuncSignature ParseElu
Definition ParseElu.cxx:9
std::string ConvertShapeToString(const std::vector< size_t > &shape)
ParserFuncSignature ParseMin
ParserFuncSignature ParseRelu
Definition ParseRelu.cxx:9
ParserFuncSignature ParseReduceSum
ParserFuncSignature ParseConv
Definition ParseConv.cxx:9
ParserFuncSignature ParseScatterElements
ParserFuncSignature ParseGemm
Definition ParseGemm.cxx:9
ParserFuncSignature ParseTile
Definition ParseTile.cxx:9
ParserFuncSignature ParseMul
ParserFuseFuncSignature ParseFuseGemmRelu
ParserFuncSignature ParsePow
ParserFuncSignature ParseAbs
ParserFuncSignature ParseSin
ParserFuncSignature ParseReduceSumSquare
ParserFuncSignature ParseTanh
Definition ParseTanh.cxx:9
create variable transformations
Helper templated class for swapping bytes; specializations for N={2,4,8} are provided below.
Definition Byteswap.h:124
static void Copy(onnx::TensorProto *tensor, void *data)
static void Copy(onnx::TensorProto *tensor, void *data)
static void Copy(onnx::TensorProto *tensor, void *data)
static void Copy(onnx::TensorProto *tensor, void *data)
std::unordered_map< std::string, ParserFuncSignature > fOperatorsMap
TLine l
Definition textangle.C:4