3#include "onnx_proto3.pb.h"
10#include <unordered_map>
15namespace Experimental {
111 tensor->mutable_float_data()->ExtractSubrange(0,
tensor->float_data_size(),
112 static_cast<float *
>(
data));
118 tensor->mutable_double_data()->ExtractSubrange(0,
tensor->double_data_size(),
119 static_cast<double *
>(
data));
125 tensor->mutable_int32_data()->ExtractSubrange(0,
tensor->int32_data_size(),
126 static_cast<int32_t *
>(
data));
132 tensor->mutable_int64_data()->ExtractSubrange(0,
tensor->int64_data_size(),
133 static_cast<int64_t *
>(
data));
144 for (std::size_t k = 0; k <
length; ++k)
254 std::vector<std::string>
ops;
257 ops.emplace_back(it.first);
280std::unique_ptr<ROperator>
283 if (i >= nodes.size())
284 throw std::runtime_error(
"TMVA::SOFIE - Error in parsing ordered operators " + std::to_string(i) +
" is >= " + std::to_string(nodes.size()));
289 std::cout <<
"Parsing operator " <<
op_type << std::endl;
295 if (children.size() == 1) {
296 int idx2 = children.front();
317 }
else if (
nodeproto.op_type() ==
"Gemm") {
323 }
else if (
nodeproto.op_type() ==
"BatchNormalization") {
335 std::cout <<
"operator " <<
op_type <<
" is not supported" << std::endl;
336 throw std::runtime_error(
"TMVA::SOFIE Operator type " +
op_type +
" is not yet supported");
339 std::cout <<
"\tCreating operator " <<
op_type << std::endl;
353 throw std::runtime_error(
"TMVA::SOFIE - Failed to load onnx file " +
filename);
355 const onnx::GraphProto &graph = model->graph();
358 std::time_t
ttime = std::time(0);
369 if (
isep != std::string::npos) {
381 auto model = std::make_unique<onnx::ModelProto>();
383 std::fstream
input(
filename, std::ios::in | std::ios::binary);
384 if (!model->ParseFromIstream(&
input)) {
385 std::cerr <<
"TMVA::SOFIE - Failed to open onnx file " <<
filename << std::endl;
386 return std::unique_ptr<onnx::ModelProto>();
391 std::cout <<
"ONNX Version " << model->ir_version() << std::endl;
393 google::protobuf::ShutdownProtobufLibrary();
400 std::cout <<
"\n" << graph.name() <<
" Graph operator list\n";
401 for (
int i = 0; i < graph.node_size(); i++) {
402 const auto & node = graph.node(i);
403 const std::string
opType = node.op_type();
405 std::cout <<
"\tOperator " << i <<
" : " <<
opType <<
" (" << node.name() <<
"), " << graph.node(i).input_size()
408 std::cout << graph.node(i).input(
j);
409 if (
j < graph.node(i).input_size() - 1)
412 std::cout <<
" }" << std::endl;
418 for (
int j = 0;
j < node.attribute_size();
j++) {
433 if (!model)
return false;
435 const onnx::GraphProto &graph = model->graph();
438 std::cout <<
"\nModel operator list " << model->producer_name() <<
"\n";
445 std::cout <<
"List of missing operators for model loaded from file " <<
filename << std::endl;
447 std::cout <<
op.first <<
" " <<
op.second << std::endl;
451 std::cout <<
"All operators in the loaded model are supported!\n";
463 std::cout <<
"\nParsing Graph - " <<
graphName << std::endl;
466 for (
int i = 0; i < graph.initializer_size(); i++) {
471 std::cout <<
"Parsing model inputs...." << std::endl;
473 for (
int i = 0; i < graph.input_size(); i++) {
475 static_cast<ETensorType>(graph.input(i).type().tensor_type().elem_type()));
478 std::cout <<
"\tgraph input " << i <<
" name " << graph.input(i).name() <<
" type "
479 << graph.input(i).type().tensor_type().elem_type() << std::endl;
493 throw std::runtime_error(
"TMVA::SOFIE data node with no shape restrictions is not supported yet");
494 for (
int j = 0;
j <
valueinfoproto.type().tensor_type().shape().dim_size();
j++) {
497 onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
506 }
else if (
valueinfoproto.type().tensor_type().shape().dim(
j).value_case() ==
507 onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
512 throw std::runtime_error(
"TMVA::SOFIE ONNX file error: Valueinfoproto " +
input_name +
513 " has neither dim_value nor dim_param! \n");
517 if (
valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
539 std::cout <<
"\nParsing graph initializer list and fill model initialized tensors" << std::endl;
541 for (
int i = 0; i < graph.initializer_size(); i++) {
542 onnx::TensorProto *
tensorproto =
const_cast<onnx::TensorProto *
>(&graph.initializer(i));
543 std::vector<std::size_t> shape;
544 std::size_t fLength = 1;
551 std::string
input_name = graph.initializer(i).name();
554 std::cout <<
"\t initializer " << i <<
" name " <<
input_name <<
" type " << graph.initializer(i).data_type()
591 throw std::runtime_error(
"Data type in weight tensor " + graph.initializer(i).name() +
" not supported!\n");
597 std::cout <<
"\nGraph operator list (ONNX order)\n";
598 for (
int i = 0; i < graph.node_size(); i++) {
599 std::cout <<
"\tOperator " << i <<
" : " << graph.node(i).op_type() <<
" , " << graph.node(i).input_size()
602 std::cout << graph.node(i).input(
j);
603 if (
j < graph.node(i).input_size() - 1)
606 std::cout <<
" }" << std::endl;
612 std::cout <<
"\n***********************\nRe-Order graph operator list\n*************************\n";
615 std::vector<bool>
foundNodes(graph.node_size());
619 for (
int i = 0; i < graph.input_size(); i++) {
624 for (
int i = 0; i < graph.node_size(); i++) {
632 std::cout <<
"Checking input of Node " << i <<
" : " << graph.node(i).name() << std::endl;
634 std::string
name = graph.node(i).input(
j);
640 std::cout <<
"\t\t input " <<
name <<
" "
649 std::cout <<
"skip node " << graph.node(i).op_type() <<
" " << graph.node(i).name() <<
" inputs are not existing ";
651 std::cout << graph.node(i).input(
j) <<
" ";
653 std::cout << std::endl;
660 std::cout <<
"===> New node " << graph.node(i).op_type() <<
" " << graph.node(i).name() <<
" order " << i << std::endl;
666 if (
fVerbose) std::cout <<
"\toutput : " << graph.node(i).output(
j) << std::endl;
673 std::cout <<
"cannot find a new node after " << graph.node(
ilast).op_type() <<
" " << graph.node(
ilast).name() << std::endl;
674 throw std::runtime_error(
"TMVA::SOFIE - cannot find a new node ");
676 }
while ((
int)
nodesOrder.size() < graph.node_size());
680 std::vector<std::vector<int>>
nodesChildren(graph.node_size());
682 for (
int k = 0; k < graph.node_size(); k++) {
685 if (graph.node(i).output_size() > 0)
nodesChildren[i].reserve(graph.node(i).output_size());
686 for (
const auto&
output_name : graph.node(i).output()) {
688 for (
int l = k;
l < graph.node_size();
l++) {
690 for (
const auto&
input_name : graph.node(
j).input()) {
700 std::cout <<
"\nGraph operator list (re-ordered)\n";
701 for (
int k = 0; k < graph.node_size(); k++) {
703 std::cout <<
"\tOperator " << i <<
" : " << graph.node(i).op_type() <<
" , " << graph.node(i).name() <<
" input tensors : {";
704 for (
int j = 0;
j < graph.node(i).input_size();
j++) {
705 std::cout << graph.node(i).input(
j);
706 if (
j < graph.node(i).input_size() - 1)
710 std::cout <<
" children : {";
712 std::cout <<
" [ " <<
ichild <<
" " << graph.node(
ichild).op_type() <<
" , " << graph.node(
ichild).name() <<
"]";
714 std::cout <<
"}" << std::endl;
720 std::cout <<
"Fill RModel with operators...\n";
727 for (
int i = 0; i < graph.node_size(); i++) {
731 std::cout <<
"\t" << i <<
" " <<
nodesOrder[i] <<
" parsing operator " <<
op_type << std::endl;
737 std::cout <<
"\t\tskipping operator since it is fused with previous one" << std::endl;
747 std::cout <<
"\nParsing Graph output list\n";
748 for (
int i = 0; i < graph.output_size(); i++) {
750 std::cout <<
"\toutput " << i <<
" name " << graph.output(i).name() << std::endl;
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
const_iterator begin() const
const_iterator end() const
void RegisterOperator(const std::string &name, ParserFuncSignature func)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &, const std::vector< int > &)
bool IsRegisteredOperator(const std::string &name)
void CheckGraph(const onnx::GraphProto &g, int &level, std::map< std::string, int > &missingOperators)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
RModelParser_ONNX() noexcept
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string filename, bool verbose=false)
bool IsRegisteredTensorType(const std::string &)
void RegisterTensorType(const std::string &, ETensorType)
std::unique_ptr< onnx::ModelProto > LoadModel(std::string filename)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
bool CheckModel(std::string filename, bool verbose=false)
std::vector< bool > fFusedOperators
std::string Clean_name(std::string input_tensor_name)
ParserFuncSignature ParseSqrt
ParserFuncSignature ParseBatchNormalization
ParserFuncSignature ParseGreater
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuncSignature ParseReshape
ParserFuseFuncSignature ParseFuseConvTransposeAdd
ParserFuncSignature ParseReduceMean
ParserFuseFuncSignature ParseFuseMatMulAdd
ParserFuncSignature ParseGather
ParserFuncSignature ParseNeg
ParserFuncSignature ParseWhere
ParserFuncSignature ParseCos
ParserFuncSignature ParseLog
ParserFuncSignature ParseLeakyRelu
ParserFuncSignature ParseExp
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseEinsum
ParserFuncSignature ParsePool
ParserFuncSignature ParseDiv
ParserFuncSignature ParseLayerNormalization
ParserFuncSignature ParseConcat
ParserFuncSignature ParseTopK
ParserFuncSignature ParseMax
ParserFuncSignature ParseEq
ParserFuncSignature ParseIdentity
ParserFuncSignature ParseConvTranspose
ParserFuncSignature ParseReduceProd
ParserFuncSignature ParseSlice
ParserFuncSignature ParseRandom
ParserFuncSignature ParseTranspose
ParserFuncSignature ParseLess
ParserFuncSignature ParseShape
ParserFuncSignature ParseGRU
ParserFuncSignature ParseMatMul
ParserFuncSignature ParseErf
ParserFuncSignature ParseSub
ParserFuncSignature ParseAdd
std::shared_ptr< void > GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t length)
ParserFuncSignature ParseIf
ParserFuncSignature ParseRange
ParserFuncSignature ParseExpand
ParserFuncSignature ParseRNN
ParserFuncSignature ParseLSTM
ParserFuncSignature ParseCast
ParserFuncSignature ParseReciprocal
std::string ConvertShapeToString(std::vector< size_t > shape)
ParserFuncSignature ParseSigmoid
ParserFuseFuncSignature ParseFuseConvAdd
ParserFuseFuncSignature ParseFuseBatchnormRelu
ParserFuncSignature ParseSoftmax
ParserFuncSignature ParseGreaterEq
ParserFuncSignature ParseMean
ParserFuncSignature ParseSplit
ParserFuncSignature ParseConstant
ParserFuncSignature ParseSelu
ParserFuncSignature ParseLessEq
ParserFuncSignature ParseSum
ParserFuncSignature ParseEyeLike
ParserFuncSignature ParsePad
ParserFuncSignature ParseElu
ParserFuncSignature ParseMin
ParserFuncSignature ParseRelu
ParserFuncSignature ParseReduceSum
ParserFuncSignature ParseConv
ParserFuncSignature ParseScatterElements
ParserFuncSignature ParseGemm
ParserFuncSignature ParseTile
ParserFuncSignature ParseMul
ParserFuseFuncSignature ParseFuseGemmRelu
ParserFuncSignature ParsePow
ParserFuncSignature ParseAbs
ParserFuncSignature ParseSin
ParserFuncSignature ParseReduceSumSquare
ParserFuncSignature ParseTanh
create variable transformations
Helper templated class for swapping bytes; specializations for N={2,4,8} are provided below.
std::unordered_map< std::string, ParserFuncSignature > fOperatorsMap