3#include "onnx_proto3.pb.h"
10#include <unordered_map>
15namespace Experimental {
119 tensor->mutable_float_data()->ExtractSubrange(0,
tensor->float_data_size(),
120 static_cast<float *
>(
data));
126 tensor->mutable_double_data()->ExtractSubrange(0,
tensor->double_data_size(),
127 static_cast<double *
>(
data));
133 tensor->mutable_int32_data()->ExtractSubrange(0,
tensor->int32_data_size(),
134 static_cast<int32_t *
>(
data));
140 tensor->mutable_int64_data()->ExtractSubrange(0,
tensor->int64_data_size(),
141 static_cast<int64_t *
>(
data));
152 for (std::size_t k = 0; k <
length; ++k)
271 std::vector<std::string>
ops;
274 ops.emplace_back(it.first);
297std::unique_ptr<ROperator>
300 if (i >= nodes.size())
301 throw std::runtime_error(
"TMVA::SOFIE - Error in parsing ordered operators " + std::to_string(i) +
" is >= " + std::to_string(nodes.size()));
306 std::cout <<
"Parsing operator " <<
op_type << std::endl;
312 if (children.size() == 1) {
313 int idx2 = children.front();
334 }
else if (
nodeproto.op_type() ==
"Gemm") {
340 }
else if (
nodeproto.op_type() ==
"BatchNormalization") {
352 std::cout <<
"operator " <<
op_type <<
" is not supported" << std::endl;
353 throw std::runtime_error(
"TMVA::SOFIE Operator type " +
op_type +
" is not yet supported");
356 std::cout <<
"\tCreating operator " <<
op_type << std::endl;
370 throw std::runtime_error(
"TMVA::SOFIE - Failed to load onnx file " +
filename);
372 const onnx::GraphProto &graph = model->graph();
375 std::time_t
ttime = std::time(0);
386 if (
isep != std::string::npos) {
403 throw std::runtime_error(
"TMVA::SOFIE - Failed to parse ONNX model from input stream");
405 const onnx::GraphProto &graph = model->graph();
407 std::time_t
ttime = std::time(0);
417 std::fstream
input(
filename, std::ios::in | std::ios::binary);
419 std::cerr <<
"TMVA::SOFIE - Failed to open onnx file " <<
filename << std::endl;
428 auto model = std::make_unique<onnx::ModelProto>();
430 if (!model->ParseFromIstream(&
input)) {
431 std::cerr <<
"TMVA::SOFIE - Failed to parse ONNX model from input stream" << std::endl;
437 std::cout <<
"ONNX Version " << model->ir_version() << std::endl;
439 google::protobuf::ShutdownProtobufLibrary();
446 std::cout <<
"\n" << graph.name() <<
" Graph operator list\n";
447 for (
int i = 0; i < graph.node_size(); i++) {
448 const auto & node = graph.node(i);
449 const std::string
opType = node.op_type();
451 std::cout <<
"\tOperator " << i <<
" : " <<
opType <<
" (" << node.name() <<
"), " << graph.node(i).input_size()
454 std::cout << graph.node(i).input(
j);
455 if (
j < graph.node(i).input_size() - 1)
458 std::cout <<
" }" << std::endl;
464 for (
int j = 0;
j < node.attribute_size();
j++) {
479 if (!model)
return false;
481 const onnx::GraphProto &graph = model->graph();
484 std::cout <<
"\nModel operator list " << model->producer_name() <<
"\n";
491 std::cout <<
"List of missing operators for model loaded from file " <<
filename << std::endl;
493 std::cout <<
op.first <<
" " <<
op.second << std::endl;
497 std::cout <<
"All operators in the loaded model are supported!\n";
509 std::cout <<
"\nParsing Graph - " <<
graphName << std::endl;
512 for (
int i = 0; i < graph.initializer_size(); i++) {
517 std::cout <<
"Parsing model inputs...." << std::endl;
519 for (
int i = 0; i < graph.input_size(); i++) {
521 static_cast<ETensorType>(graph.input(i).type().tensor_type().elem_type()));
524 std::cout <<
"\tgraph input " << i <<
" name " << graph.input(i).name() <<
" type "
525 << graph.input(i).type().tensor_type().elem_type() << std::endl;
539 throw std::runtime_error(
"TMVA::SOFIE data node with no shape restrictions is not supported yet");
540 for (
int j = 0;
j <
valueinfoproto.type().tensor_type().shape().dim_size();
j++) {
543 onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
552 }
else if (
valueinfoproto.type().tensor_type().shape().dim(
j).value_case() ==
553 onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
558 throw std::runtime_error(
"TMVA::SOFIE ONNX file error: Valueinfoproto " +
input_name +
559 " has neither dim_value nor dim_param! \n");
563 if (
valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
585 std::cout <<
"\nParsing graph initializer list and fill model initialized tensors" << std::endl;
587 for (
int i = 0; i < graph.initializer_size(); i++) {
588 onnx::TensorProto *
tensorproto =
const_cast<onnx::TensorProto *
>(&graph.initializer(i));
589 std::vector<std::size_t> shape;
590 std::size_t fLength = 1;
597 std::string
input_name = graph.initializer(i).name();
600 std::cout <<
"\t initializer " << i <<
" name " <<
input_name <<
" type " << graph.initializer(i).data_type()
637 throw std::runtime_error(
"Data type in weight tensor " + graph.initializer(i).name() +
" not supported!\n");
643 std::cout <<
"\nGraph operator list (ONNX order)\n";
644 for (
int i = 0; i < graph.node_size(); i++) {
645 std::cout <<
"\tOperator " << i <<
" : " << graph.node(i).op_type() <<
" , " << graph.node(i).input_size()
648 std::cout << graph.node(i).input(
j);
649 if (
j < graph.node(i).input_size() - 1)
652 std::cout <<
" }" << std::endl;
658 std::cout <<
"\n***********************\nRe-Order graph operator list\n*************************\n";
661 std::vector<bool>
foundNodes(graph.node_size());
665 for (
int i = 0; i < graph.input_size(); i++) {
670 for (
int i = 0; i < graph.node_size(); i++) {
678 std::cout <<
"Checking input of Node " << i <<
" : " << graph.node(i).name() << std::endl;
680 std::string
name = graph.node(i).input(
j);
686 std::cout <<
"\t\t input " <<
name <<
" "
695 std::cout <<
"skip node " << graph.node(i).op_type() <<
" " << graph.node(i).name() <<
" inputs are not existing ";
697 std::cout << graph.node(i).input(
j) <<
" ";
699 std::cout << std::endl;
706 std::cout <<
"===> New node " << graph.node(i).op_type() <<
" " << graph.node(i).name() <<
" order " << i << std::endl;
712 if (
fVerbose) std::cout <<
"\toutput : " << graph.node(i).output(
j) << std::endl;
719 std::cout <<
"cannot find a new node after " << graph.node(
ilast).op_type() <<
" " << graph.node(
ilast).name() << std::endl;
720 throw std::runtime_error(
"TMVA::SOFIE - cannot find a new node ");
722 }
while ((
int)
nodesOrder.size() < graph.node_size());
726 std::vector<std::vector<int>>
nodesChildren(graph.node_size());
728 for (
int k = 0; k < graph.node_size(); k++) {
731 if (graph.node(i).output_size() > 0)
nodesChildren[i].reserve(graph.node(i).output_size());
732 for (
const auto&
output_name : graph.node(i).output()) {
734 for (
int l = k;
l < graph.node_size();
l++) {
736 for (
const auto&
input_name : graph.node(
j).input()) {
746 std::cout <<
"\nGraph operator list (re-ordered)\n";
747 for (
int k = 0; k < graph.node_size(); k++) {
749 std::cout <<
"\tOperator " << i <<
" : " << graph.node(i).op_type() <<
" , " << graph.node(i).name() <<
" input tensors : {";
750 for (
int j = 0;
j < graph.node(i).input_size();
j++) {
751 std::cout << graph.node(i).input(
j);
752 if (
j < graph.node(i).input_size() - 1)
756 std::cout <<
" children : {";
758 std::cout <<
" [ " <<
ichild <<
" " << graph.node(
ichild).op_type() <<
" , " << graph.node(
ichild).name() <<
"]";
760 std::cout <<
"}" << std::endl;
766 std::cout <<
"Fill RModel with operators...\n";
773 for (
int i = 0; i < graph.node_size(); i++) {
777 std::cout <<
"\t" << i <<
" " <<
nodesOrder[i] <<
" parsing operator " <<
op_type << std::endl;
783 std::cout <<
"\t\tskipping operator since it is fused with previous one" << std::endl;
793 std::cout <<
"\nParsing Graph output list\n";
794 for (
int i = 0; i < graph.output_size(); i++) {
796 std::cout <<
"\toutput " << i <<
" name " << graph.output(i).name() << std::endl;
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
const_iterator begin() const
const_iterator end() const
void RegisterOperator(const std::string &name, ParserFuncSignature func)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &, const std::vector< int > &)
bool IsRegisteredOperator(const std::string &name)
void CheckGraph(const onnx::GraphProto &g, int &level, std::map< std::string, int > &missingOperators)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
RModelParser_ONNX() noexcept
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string const &filename, bool verbose=false)
bool IsRegisteredTensorType(const std::string &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< onnx::ModelProto > LoadModel(const std::string &filename)
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
bool CheckModel(std::string filename, bool verbose=false)
std::vector< bool > fFusedOperators
std::string Clean_name(std::string input_tensor_name)
ParserFuncSignature ParseIsNaN
ParserFuncSignature ParseSqrt
ParserFuncSignature ParseBatchNormalization
ParserFuncSignature ParseGreater
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuncSignature ParseReshape
ParserFuseFuncSignature ParseFuseConvTransposeAdd
ParserFuncSignature ParseReduceMean
ParserFuseFuncSignature ParseFuseMatMulAdd
ParserFuncSignature ParseGather
ParserFuncSignature ParseNeg
ParserFuncSignature ParseWhere
ParserFuncSignature ParseCos
ParserFuncSignature ParseLog
ParserFuncSignature ParseLeakyRelu
ParserFuncSignature ParseExp
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseEinsum
ParserFuncSignature ParsePool
ParserFuncSignature ParseDiv
ParserFuncSignature ParseLayerNormalization
ParserFuncSignature ParseConcat
ParserFuncSignature ParseTopK
ParserFuncSignature ParseMax
ParserFuncSignature ParseEq
ParserFuncSignature ParseIdentity
ParserFuncSignature ParseConvTranspose
ParserFuncSignature ParseReduceProd
ParserFuncSignature ParseNot
ParserFuncSignature ParseSlice
ParserFuncSignature ParseRandom
ParserFuncSignature ParseTranspose
ParserFuncSignature ParseLess
ParserFuncSignature ParseShape
ParserFuncSignature ParseGRU
ParserFuncSignature ParseMatMul
ParserFuncSignature ParseErf
ParserFuncSignature ParseSub
ParserFuncSignature ParseAdd
ParserFuncSignature ParseNonZero
std::shared_ptr< void > GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t length)
ParserFuncSignature ParseIf
ParserFuncSignature ParseRange
ParserFuncSignature ParseSoftplus
ParserFuncSignature ParseExpand
ParserFuncSignature ParseRNN
ParserFuncSignature ParseLSTM
ParserFuncSignature ParseCast
ParserFuncSignature ParseReciprocal
ParserFuncSignature ParseSigmoid
ParserFuseFuncSignature ParseFuseConvAdd
ParserFuseFuncSignature ParseFuseBatchnormRelu
ParserFuncSignature ParseIsInf
ParserFuncSignature ParseSoftmax
ParserFuncSignature ParseGreaterEq
ParserFuncSignature ParseMod
ParserFuncSignature ParseMean
ParserFuncSignature ParseSplit
ParserFuncSignature ParseConstant
ParserFuncSignature ParseSelu
ParserFuncSignature ParseLessEq
ParserFuncSignature ParseGatherND
ParserFuncSignature ParseSum
ParserFuncSignature ParseEyeLike
ParserFuncSignature ParsePad
ParserFuncSignature ParseElu
std::string ConvertShapeToString(const std::vector< size_t > &shape)
ParserFuncSignature ParseMin
ParserFuncSignature ParseRelu
ParserFuncSignature ParseReduceSum
ParserFuncSignature ParseConv
ParserFuncSignature ParseScatterElements
ParserFuncSignature ParseGemm
ParserFuncSignature ParseTile
ParserFuncSignature ParseMul
ParserFuseFuncSignature ParseFuseGemmRelu
ParserFuncSignature ParsePow
ParserFuncSignature ParseAbs
ParserFuncSignature ParseSin
ParserFuncSignature ParseReduceSumSquare
ParserFuncSignature ParseTanh
create variable transformations
Helper templated class for swapping bytes; specializations for N={2,4,8} are provided below.
std::unordered_map< std::string, ParserFuncSignature > fOperatorsMap