Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_ONNX.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_RMODELPARSER_ONNX
2#define TMVA_SOFIE_RMODELPARSER_ONNX
3
4#include "TMVA/RModel.hxx"
5
6#include <memory>
7#include <functional>
8#include <unordered_map>
9#include <fstream>
10
11// forward declaration
12namespace onnx {
13class NodeProto;
14class GraphProto;
15class ModelProto;
16class TensorProto;
17} // namespace onnx
18
19namespace TMVA {
20namespace Experimental {
21namespace SOFIE {
22
23class RModelParser_ONNX;
24
26 std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & /*parser*/, const onnx::NodeProto & /*nodeproto*/)>;
28 std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& /*parser*/, const onnx::NodeProto& /*firstnode*/, const onnx::NodeProto& /*secondnode*/)>;
29
31public:
32 struct OperatorsMapImpl;
33
35
36private:
37
38 bool fVerbose = false;
39 // Registered operators
40 std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
41 // Type of the tensors
42 std::unordered_map<std::string, ETensorType> fTensorTypeMap;
43
44 // Map of fused operators: the key is the second operator and the value is a pair of (fusion type, parent operator)
45 std::map<int, std::pair<EFusedOp, int>> fFusedOperators;
46
47 // weight data file
48 std::ifstream fDataFile;
49 // filename of model
50 std::string fDataFileName;
51
52
53public:
54 // Register an ONNX operator
55 void RegisterOperator(const std::string &name, ParserFuncSignature func);
56
57 // Check if the operator is registered
58 bool IsRegisteredOperator(const std::string &name);
59
60 // List of registered operators (in alphabetical order)
61 std::vector<std::string> GetRegisteredOperators();
62
63 // Set the type of the tensor
64 void RegisterTensorType(const std::string & /*name*/, ETensorType /*type*/);
65
66 // Check if the type of the tensor is registered
67 bool IsRegisteredTensorType(const std::string & /*name*/);
68
69 // Return the verbosity flag
70 bool Verbose() const {
71 return fVerbose;
72 }
73
74 // Get the type of the tensor
75 ETensorType GetTensorType(const std::string &name);
76
77 // Parse the index'th node from the ONNX graph
78 std::unique_ptr<ROperator> ParseOperator(const size_t /*index*/, const onnx::GraphProto & /*graphproto*/,
79 const std::vector<size_t> & /*nodes*/, const std::vector<int> & /* children */);
80
81 // check a graph for missing operators
82 void CheckGraph(const onnx::GraphProto & g, int & level, std::map<std::string, int> & missingOperators);
83
84 // parse the ONNX graph
85 void ParseONNXGraph(RModel & model, const onnx::GraphProto & g, std::string name = "");
86
87 std::unique_ptr<onnx::ModelProto> LoadModel(const std::string &filename);
88 std::unique_ptr<onnx::ModelProto> LoadModel(std::istream &input);
89
90 std::shared_ptr<void> GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t tensor_length, ETensorType type );
91
92public:
93
95
96 RModel Parse(std::string const &filename, bool verbose = false);
97 RModel Parse(std::istream &input, std::string const &name, bool verbose = false);
98
99 // check the model for missing operators - return false in case some operator implementation is missing
100 bool CheckModel(std::string filename, bool verbose = false);
101
102 // Set the full path of the external data file (needed if the external data is not stored in the default modelName.onnx.data).
103 // Call this function before parsing.
107
109};
110
111} // namespace SOFIE
112} // namespace Experimental
113} // namespace TMVA
114
115#endif // TMVA_SOFIE_RMODELPARSER_ONNX
#define g(i)
Definition RSha256.hxx:105
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char name[80]
Definition TGX11.cxx:148
void RegisterOperator(const std::string &name, ParserFuncSignature func)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &, const std::vector< int > &)
void SetExternalDataFile(const std::string &dataFileName)
bool IsRegisteredOperator(const std::string &name)
void CheckGraph(const onnx::GraphProto &g, int &level, std::map< std::string, int > &missingOperators)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string const &filename, bool verbose=false)
std::shared_ptr< void > GetInitializedTensorData(onnx::TensorProto *tensorproto, size_t tensor_length, ETensorType type)
std::map< int, std::pair< EFusedOp, int > > fFusedOperators
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< onnx::ModelProto > LoadModel(const std::string &filename)
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
bool CheckModel(std::string filename, bool verbose=false)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
create variable transformations