3#include "onnx_proto3.pb.h"
10#include <unordered_map>
15namespace Experimental {
171 std::vector<std::string>
ops;
174 ops.emplace_back(it.first);
195std::unique_ptr<ROperator>
198 if (i >= nodes.size())
199 throw std::runtime_error(
"TMVA::SOFIE - Error in parsing ordered operators " + std::to_string(i) +
" is >= " + std::to_string(nodes.size()));
204 std::cout <<
"Parsing an operator " <<
op_type << std::endl;
207 if (i < nodes.size() - 1) {
208 int idx2 = nodes[i+1];
230 if (idx > 0 &&
op_type ==
"Add") {
231 int idx0 = nodes[i - 1];
240 throw std::runtime_error(
"TMVA::SOFIE Operator type " +
op_type +
" is not yet supported");
243 std::cout <<
"\tCreating operator " <<
op_type << std::endl;
258 if (
isep != std::string::npos) {
262 std::time_t
ttime = std::time(0);
268 onnx::ModelProto model;
273 std::fstream
input(
filename, std::ios::in | std::ios::binary);
274 if (!model.ParseFromIstream(&
input)) {
275 throw std::runtime_error(
"TMVA::SOFIE - Failed to parse onnx file " +
filename);
278 const onnx::GraphProto &
graph = model.graph();
279 google::protobuf::ShutdownProtobufLibrary();
283 std::cout <<
"ONNX Version " << model.ir_version() << std::endl;
287 for (
int i = 0; i <
graph.initializer_size(); i++) {
292 std::cout <<
"Parsing model inputs...." << std::endl;
294 for (
int i = 0; i <
graph.input_size(); i++) {
299 std::cout <<
"\tgraph input " << i <<
" name " <<
graph.input(i).name() <<
" type "
300 <<
graph.input(i).type().tensor_type().elem_type() << std::endl;
311 throw std::runtime_error(
"TMVA::SOFIE Data type in input tensor " +
input_name +
" not supported!\n");
317 throw std::runtime_error(
"TMVA::SOFIE datanode with no shape restrictions is not supported yet");
321 onnx::TensorShapeProto_Dimension::ValueCase::kDimValue) {
323 }
else if (
valueinfoproto.type().tensor_type().shape().dim(
j).value_case() ==
324 onnx::TensorShapeProto_Dimension::ValueCase::kDimParam) {
329 throw std::runtime_error(
"TMVA::SOFIE ONNX file error: Valueinfoproto " +
input_name +
330 " has neither dim_value nor dim_param! \n");
334 if (
valueinfoproto.type().tensor_type().shape().dim_size() == 0) {
356 std::cout <<
"\nParsing graph initializer list and fill model initialized tensors" << std::endl;
358 for (
int i = 0; i <
graph.initializer_size(); i++) {
359 onnx::TensorProto *
tensorproto =
const_cast<onnx::TensorProto *
>(&
graph.initializer(i));
360 std::vector<std::size_t> shape;
361 std::size_t fLength = 1;
371 std::cout <<
"\t initializer " << i <<
" name " <<
input_name <<
" type " <<
graph.initializer(i).data_type()
376 std::shared_ptr<void>
data(
malloc(fLength *
sizeof(
float)), free);
380 std::memcpy(
data.get(),
tensorproto->raw_data().c_str(), fLength *
sizeof(
float));
382 for (std::size_t k = 0; k < fLength; ++k)
383 (
reinterpret_cast<uint32_t *
>(
data.get()))[k] =
388 static_cast<float *
>(
data.get()));
397 std::shared_ptr<void>
data(
malloc(fLength *
sizeof(int64_t)), free);
401 std::memcpy(
data.get(),
tensorproto->raw_data().c_str(), fLength *
sizeof(int64_t));
403 for (std::size_t k = 0; k < fLength; ++k)
404 (
reinterpret_cast<uint64_t *
>(
data.get()))[k] =
409 static_cast<int64_t *
>(
data.get()));
418 throw std::runtime_error(
"Data type in weight tensor " +
graph.initializer(i).name() +
" not supported!\n");
424 std::cout <<
"\nGraph operator list (ONNX order)\n";
425 for (
int i = 0; i <
graph.node_size(); i++) {
426 std::cout <<
"\tOperator " << i <<
" : " <<
graph.node(i).op_type() <<
" , " <<
graph.node(i).input_size()
429 std::cout <<
graph.node(i).input(
j);
430 if (
j <
graph.node(i).input_size() - 1)
433 std::cout <<
" }" << std::endl;
439 std::cout <<
"\nRe-Order graph operator list\n";
445 for (
int i = 0; i <
graph.input_size(); i++) {
450 for (
int i = 0; i <
graph.node_size(); i++) {
464 std::cout <<
graph.node(i).op_type() <<
" input " <<
name <<
" "
473 std::cout <<
"skip op " <<
graph.node(i).op_type() <<
" inputs are ";
475 std::cout <<
graph.node(i).input(
j) <<
" ";
477 std::cout << std::endl;
482 std::cout <<
"\tadd node " <<
graph.node(i).op_type() <<
" order " << i << std::endl;
493 throw std::runtime_error(
"TMVA::SOFIE - cannot find a new node ");
499 std::cout <<
"\nGraph operator list (re-ordered)\n";
500 for (
int k = 0; k <
graph.node_size(); k++) {
502 std::cout <<
"\tOperator " << i <<
" : " <<
graph.node(i).op_type() <<
" , " <<
graph.node(i).input_size()
505 std::cout <<
graph.node(i).input(
j);
506 if (
j <
graph.node(i).input_size() - 1)
509 std::cout <<
" }" << std::endl;
515 std::cout <<
"Fill RModel with operators...\n";
517 for (
int i = 0; i <
graph.node_size(); i++) {
521 std::cout <<
"\t" << i <<
" " <<
nodesOrder[i] <<
" parsing operator " <<
op_type << std::endl;
527 std::cout <<
"\t\tskipping operator since it is fused with previous one" << std::endl;
537 std::cout <<
"\nParsing Graph output list\n";
538 for (
int i = 0; i <
graph.output_size(); i++) {
540 std::cout <<
"\toutput " << i <<
" name " <<
graph.output(i).name() << std::endl;
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
const_iterator end() const
void RegisterOperator(const std::string &name, ParserFuncSignature func)
bool IsRegisteredOperator(const std::string &name)
RModelParser_ONNX() noexcept
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string filename, bool verbose=false)
bool IsRegisteredTensorType(const std::string &)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
std::string Clean_name(std::string input_tensor_name)
ParserFuncSignature ParseSqrt
ParserFuncSignature ParseBatchNormalization
ParserFuncSignature ParseGreater
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
ParserFuncSignature ParseReshape
ParserFuseFuncSignature ParseFuseConvTransposeAdd
ParserFuncSignature ParseReduceMean
ParserFuseFuncSignature ParseFuseMatMulAdd
ParserFuncSignature ParseGather
ParserFuncSignature ParseNeg
ParserFuncSignature ParseLog
ParserFuncSignature ParseLeakyRelu
ParserFuncSignature ParseExp
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParsePool
ParserFuncSignature ParseDiv
ParserFuncSignature ParseLayerNormalization
ParserFuncSignature ParseConcat
ParserFuncSignature ParseMax
ParserFuncSignature ParseEq
ParserFuncSignature ParseIdentity
ParserFuncSignature ParseConvTranspose
ParserFuncSignature ParseReduceProd
ParserFuncSignature ParseSlice
ParserFuncSignature ParseTranspose
ParserFuncSignature ParseLess
ParserFuncSignature ParseShape
ParserFuncSignature ParseGRU
ParserFuncSignature ParseMatMul
ParserFuncSignature ParseErf
ParserFuncSignature ParseSub
ParserFuncSignature ParseReduceSumsquare
ParserFuncSignature ParseAdd
ParserFuncSignature ParseRange
ParserFuncSignature ParseExpand
ParserFuncSignature ParseRNN
ParserFuncSignature ParseLSTM
ParserFuncSignature ParseCast
ParserFuncSignature ParseReciprocal
std::string ConvertShapeToString(std::vector< size_t > shape)
ParserFuncSignature ParseSigmoid
ParserFuseFuncSignature ParseFuseConvAdd
ParserFuncSignature ParseSoftmax
ParserFuncSignature ParseGreaterEq
ParserFuncSignature ParseMean
ParserFuncSignature ParseSelu
ParserFuncSignature ParseLessEq
ParserFuncSignature ParseSum
ParserFuncSignature ParseEyeLike
ParserFuncSignature ParseElu
ParserFuncSignature ParseMin
ParserFuncSignature ParseRelu
ParserFuncSignature ParseConv
ParserFuncSignature ParseGemm
ParserFuncSignature ParseMul
ParserFuncSignature ParsePow
ParserFuncSignature ParseTanh
create variable transformations
std::unordered_map< std::string, ParserFuncSignature > fOperatorsMap