3#include "onnx_proto3.pb.h"
10#include <unordered_map>
15namespace Experimental {
100 static void Copy(onnx::TensorProto * tensor,
void *
data) {
101 tensor->mutable_float_data()->ExtractSubrange(0, tensor->float_data_size(),
102 static_cast<float *
>(
data));
107 static void Copy(onnx::TensorProto * tensor,
void *
data) {
108 tensor->mutable_double_data()->ExtractSubrange(0, tensor->double_data_size(),
109 static_cast<double *
>(
data));
114 static void Copy(onnx::TensorProto * tensor,
void *
data) {
115 tensor->mutable_int32_data()->ExtractSubrange(0, tensor->int32_data_size(),
116 static_cast<int32_t *
>(
data));
121 static void Copy(onnx::TensorProto * tensor,
void *
data) {
122 tensor->mutable_int64_data()->ExtractSubrange(0, tensor->int64_data_size(),
123 static_cast<int64_t *
>(
data));
130 if (!tensorproto->raw_data().empty()) {
132 std::memcpy(
data.get(), tensorproto->raw_data().c_str(),
length *
sizeof(T));
134 for (std::size_t k = 0; k <
length; ++k)
135 (
reinterpret_cast<typename
RByteSwap<sizeof(T)
>::value_type *>(
data.get()))[k] =
136 RByteSwap<
sizeof(T)>::bswap((
reinterpret_cast<const typename
RByteSwap<sizeof(T)
>::value_type *>(tensorproto->raw_data().c_str()))[k]);
233 std::vector<std::string> ops;
236 ops.emplace_back(it.first);
257std::unique_ptr<ROperator>
260 if (i >= nodes.size())
261 throw std::runtime_error(
"TMVA::SOFIE - Error in parsing ordered operators " + std::to_string(i) +
" is >= " + std::to_string(nodes.size()));
263 const auto &nodeproto = graphproto.node(idx);
264 const std::string op_type = nodeproto.op_type();
266 std::cout <<
"Parsing operator " << op_type << std::endl;
269 if (i < nodes.size() - 1) {
270 int idx2 = nodes[i+1];
271 if (op_type ==
"MatMul") {
273 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() ==
"Add") {
279 }
else if (nodeproto.op_type() ==
"Conv" || nodeproto.op_type() ==
"ConvTranspose") {
281 if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() ==
"Add") {
282 if (nodeproto.op_type() ==
"Conv") {
283 return ParseFuseConvAdd(*
this, graphproto.node(idx), graphproto.node(idx2));
292 if (idx > 0 && op_type ==
"Add") {
293 int idx0 = nodes[i - 1];
294 if (graphproto.node(idx0).op_type() ==
"MatMul")
296 else if (graphproto.node(idx0).op_type() ==
"ConvTranspose")
302 std::cout <<
"operator " << op_type <<
" is not supported" << std::endl;
303 throw std::runtime_error(
"TMVA::SOFIE Operator type " + op_type +
" is not yet supported");
306 std::cout <<
"\tCreating operator " << op_type << std::endl;
308 return it->second(*
this, nodeproto);
320 std::string filename_nodir =
filename;
321 if (isep != std::string::npos) {
326 GOOGLE_PROTOBUF_VERIFY_VERSION;
328 onnx::ModelProto model;
333 std::fstream
input(
filename, std::ios::in | std::ios::binary);
334 if (!model.ParseFromIstream(&
input)) {
335 throw std::runtime_error(
"TMVA::SOFIE - Failed to parse onnx file " +
filename);
338 const onnx::GraphProto &
graph = model.graph();
339 google::protobuf::ShutdownProtobufLibrary();
343 std::cout <<
"ONNX Version " << model.ir_version() << std::endl;
346 std::time_t ttime = std::time(0);
347 std::tm *gmt_time = std::gmtime(&ttime);
348 std::string parsetime(std::asctime(gmt_time));
350 RModel rmodel(filename_nodir, parsetime);
359 if (graphName.empty())
360 graphName =
graph.name();
363 std::cout <<
"\nParsing Graph - " << graphName << std::endl;
365 std::unordered_set<std::string> initializer_names;
366 for (
int i = 0; i <
graph.initializer_size(); i++) {
367 initializer_names.insert(
graph.initializer(i).name());
371 std::cout <<
"Parsing model inputs...." << std::endl;
373 for (
int i = 0; i <
graph.input_size(); i++) {
378 std::cout <<
"\tgraph input " << i <<
" name " <<
graph.input(i).name() <<
" type "
379 <<
graph.input(i).type().tensor_type().elem_type() << std::endl;
381 if (initializer_names.find(
graph.input(i).name()) != initializer_names.end())
385 const onnx::ValueInfoProto &valueinfoproto =
graph.input(i);
386 std::string input_name = valueinfoproto.name();
390 throw std::runtime_error(
"TMVA::SOFIE Data type in input tensor " + input_name +
" not supported!\n");
394 bool existParam =
false;
395 if (!valueinfoproto.type().tensor_type().has_shape())
396 throw std::runtime_error(
"TMVA::SOFIE data node with no shape restrictions is not supported yet");
397 for (
int j = 0; j < valueinfoproto.type().tensor_type().shape().dim_size(); j++) {
399 if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
400 onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
401 int dim_value = valueinfoproto.type().tensor_type().shape().
dim(j).dim_value();
409 }
else if (valueinfoproto.type().tensor_type().shape().dim(j).value_case() ==
410 onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
413 dim.
param = valueinfoproto.type().tensor_type().shape().dim(j).dim_param();
415 throw std::runtime_error(
"TMVA::SOFIE ONNX file error: Valueinfoproto " + input_name +
416 " has neither dim_value nor dim_param! \n");
420 if (valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
427 std::vector<size_t> fShape_sizet;
429 fShape_sizet.push_back(j.dim);
439 std::map<std::string, int> allInitializedTensors;
442 std::cout <<
"\nParsing graph initializer list and fill model initialized tensors" << std::endl;
444 for (
int i = 0; i <
graph.initializer_size(); i++) {
445 onnx::TensorProto *tensorproto =
const_cast<onnx::TensorProto *
>(&
graph.initializer(i));
446 std::vector<std::size_t> shape;
447 std::size_t fLength = 1;
448 for (
int j = 0; j < tensorproto->dims_size(); j++) {
449 shape.push_back(tensorproto->dims(j));
450 fLength *= tensorproto->dims(j);
454 std::string input_name =
graph.initializer(i).name();
457 std::cout <<
"\t initializer " << i <<
" name " << input_name <<
" type " <<
graph.initializer(i).data_type()
461 auto tensor_type =
static_cast<ETensorType>(
graph.initializer(i).data_type());
464 switch (tensor_type) {
466 std::shared_ptr<void>
data = GetInitializedTensorData<float>(tensorproto, fLength);
467 if (verbose) std::cout <<
"add FLOAT initialized tensor " << input_name <<
" shape " <<
ConvertShapeToString(shape) << std::endl;
469 allInitializedTensors[input_name] = i;
473 std::shared_ptr<void>
data = GetInitializedTensorData<double>(tensorproto, fLength);
474 if (verbose) std::cout <<
"add DOUBLE initialized tensor " << input_name <<
" shape " <<
ConvertShapeToString(shape) << std::endl;
476 allInitializedTensors[input_name] = i;
480 std::shared_ptr<void>
data = GetInitializedTensorData<int32_t>(tensorproto, fLength);
481 if (verbose) std::cout <<
"add INT32 initialized tensor " << input_name <<
" shape " <<
ConvertShapeToString(shape) << std::endl;
483 allInitializedTensors[input_name] = i;
487 std::shared_ptr<void>
data = GetInitializedTensorData<int64_t>(tensorproto, fLength);
488 if (verbose) std::cout <<
"add INT64 initialized tensor " << input_name <<
" shape " <<
ConvertShapeToString(shape) << std::endl;
490 allInitializedTensors[input_name] = i;
494 throw std::runtime_error(
"Data type in weight tensor " +
graph.initializer(i).name() +
" not supported!\n");
500 std::cout <<
"\nGraph operator list (ONNX order)\n";
501 for (
int i = 0; i <
graph.node_size(); i++) {
502 std::cout <<
"\tOperator " << i <<
" : " <<
graph.node(i).op_type() <<
" , " <<
graph.node(i).input_size()
504 for (
int j = 0; j <
graph.node(i).input_size(); j++) {
505 std::cout <<
graph.node(i).input(j);
506 if (j <
graph.node(i).input_size() - 1)
509 std::cout <<
" }" << std::endl;
515 std::cout <<
"\nRe-Order graph operator list\n";
516 std::vector<size_t> nodesOrder;
517 nodesOrder.reserve(
graph.node_size());
518 std::vector<bool> foundNodes(
graph.node_size());
521 for (
int i = 0; i <
graph.input_size(); i++) {
525 auto psize = nodesOrder.size();
526 for (
int i = 0; i <
graph.node_size(); i++) {
530 bool existInputs =
true;
531 int input_size =
graph.node(i).input_size();
534 std::cout <<
"Checking input of Node " << i <<
" : " <<
graph.node(i).name() << std::endl;
535 for (
int j = 0; j < input_size; j++) {
536 std::string
name =
graph.node(i).input(j);
540 allInitializedTensors.find(
name) != allInitializedTensors.end());
542 std::cout <<
"\t\t input " <<
name <<
" "
544 bool(allInitializedTensors.find(
name) != allInitializedTensors.end()) <<
545 existInputs << std::endl;
551 std::cout <<
"skip node " <<
graph.node(i).op_type() <<
" " <<
graph.node(i).name() <<
" inputs are not existing ";
552 for (
int j = 0; j < input_size; j++) {
553 std::cout <<
graph.node(i).input(j) <<
" ";
555 std::cout << std::endl;
560 std::cout <<
"===> New node " <<
graph.node(i).op_type() <<
" " <<
graph.node(i).name() <<
" order " << i << std::endl;
562 nodesOrder.push_back(i);
563 foundNodes[i] =
true;
565 for (
int j = 0; j <
graph.node(i).output_size(); j++) {
566 if (
fVerbose) std::cout <<
"\toutput : " <<
graph.node(i).output(j) << std::endl;
571 if (nodesOrder.size() == psize) {
572 int ilast = nodesOrder.back();
573 std::cout <<
"cannot find a new node after " <<
graph.node(ilast).op_type() <<
" " <<
graph.node(ilast).name() << std::endl;
574 throw std::runtime_error(
"TMVA::SOFIE - cannot find a new node ");
576 }
while ((
int)nodesOrder.size() <
graph.node_size());
580 std::cout <<
"\nGraph operator list (re-ordered)\n";
581 for (
int k = 0; k <
graph.node_size(); k++) {
582 int i = nodesOrder[k];
583 std::cout <<
"\tOperator " << i <<
" : " <<
graph.node(i).op_type() <<
" , " <<
graph.node(i).input_size()
585 for (
int j = 0; j <
graph.node(i).input_size(); j++) {
586 std::cout <<
graph.node(i).input(j);
587 if (j <
graph.node(i).input_size() - 1)
590 std::cout <<
" }" << std::endl;
596 std::cout <<
"Fill RModel with operators...\n";
598 for (
int i = 0; i <
graph.node_size(); i++) {
599 std::string op_type =
graph.node(nodesOrder[i]).op_type();
602 std::cout <<
"\t" << i <<
" " << nodesOrder[i] <<
" parsing operator " << op_type << std::endl;
608 std::cout <<
"\t\tskipping operator since it is fused with previous one" << std::endl;
616 std::vector<std::string> outputnames;
618 std::cout <<
"\nParsing Graph output list\n";
619 for (
int i = 0; i <
graph.output_size(); i++) {
621 std::cout <<
"\toutput " << i <<
" name " <<
graph.output(i).name() << std::endl;
622 outputnames.push_back(
graph.output(i).name());
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
void RegisterOperator(const std::string &name, ParserFuncSignature func)
bool IsRegisteredOperator(const std::string &name)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
RModelParser_ONNX() noexcept
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string filename, bool verbose=false)
bool IsRegisteredTensorType(const std::string &)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::map< std::string, int > allInputs
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector< Dim > shape)
void AddOutputTensorNameList(std::vector< std::string > output_tensor_names)
void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
void AddInputTensorName(std::string name)
void AddOperator(std::unique_ptr< ROperator > op, int order_execution=-1)
std::string Clean_name(std::string input_tensor_name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuncSignature ParseSqrt
ParserFuncSignature ParseBatchNormalization
ParserFuncSignature ParseGreater
ParserFuncSignature ParseReshape
ParserFuseFuncSignature ParseFuseConvTransposeAdd
ParserFuncSignature ParseReduceMean
ParserFuseFuncSignature ParseFuseMatMulAdd
ParserFuncSignature ParseGather
ParserFuncSignature ParseNeg
ParserFuncSignature ParseLog
ParserFuncSignature ParseLeakyRelu
ParserFuncSignature ParseExp
ParserFuncSignature ParsePool
ParserFuncSignature ParseDiv
ParserFuncSignature ParseLayerNormalization
ParserFuncSignature ParseConcat
ParserFuncSignature ParseTopK
ParserFuncSignature ParseMax
ParserFuncSignature ParseEq
ParserFuncSignature ParseIdentity
ParserFuncSignature ParseConvTranspose
ParserFuncSignature ParseReduceProd
ParserFuncSignature ParseSlice
ParserFuncSignature ParseTranspose
ParserFuncSignature ParseLess
ParserFuncSignature ParseShape
ParserFuncSignature ParseGRU
ParserFuncSignature ParseMatMul
ParserFuncSignature ParseErf
ParserFuncSignature ParseSub
ParserFuncSignature ParseAdd
std::shared_ptr< void > GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t length)
ParserFuncSignature ParseIf
ParserFuncSignature ParseRange
ParserFuncSignature ParseExpand
ParserFuncSignature ParseRNN
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseLSTM
ParserFuncSignature ParseCast
ParserFuncSignature ParseReciprocal
std::string ConvertShapeToString(std::vector< size_t > shape)
ParserFuncSignature ParseSigmoid
ParserFuseFuncSignature ParseFuseConvAdd
ParserFuncSignature ParseSoftmax
ParserFuncSignature ParseGreaterEq
ParserFuncSignature ParseMean
ParserFuncSignature ParseSplit
ParserFuncSignature ParseConstant
ParserFuncSignature ParseSelu
ParserFuncSignature ParseLessEq
ParserFuncSignature ParseSum
ParserFuncSignature ParseEyeLike
ParserFuncSignature ParseElu
ParserFuncSignature ParseMin
ParserFuncSignature ParseRelu
ParserFuncSignature ParseReduceSum
ParserFuncSignature ParseConv
ParserFuncSignature ParseGemm
ParserFuncSignature ParseTile
ParserFuncSignature ParseMul
ParserFuncSignature ParsePow
ParserFuncSignature ParseReduceSumSquare
ParserFuncSignature ParseTanh
create variable transformations
Helper templated class for swapping bytes; specializations for N={2,4,8} are provided below.
std::unordered_map< std::string, ParserFuncSignature > fOperatorsMap