Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_ONNX.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_RMODELPARSER_ONNX
2#define TMVA_SOFIE_RMODELPARSER_ONNX
3
4#include "TMVA/RModel.hxx"
5
6#include <memory>
7#include <functional>
8#include <unordered_map>
9
10// forward declaration
11namespace onnx {
12class NodeProto;
13class GraphProto;
14} // namespace onnx
15
16namespace TMVA {
17namespace Experimental {
18namespace SOFIE {
19
20class RModelParser_ONNX;
21
22using ParserFuncSignature =
23 std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & /*parser*/, const onnx::NodeProto & /*nodeproto*/)>;
24using ParserFuseFuncSignature =
25 std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& /*parser*/, const onnx::NodeProto& /*firstnode*/, const onnx::NodeProto& /*secondnode*/)>;
26
27class RModelParser_ONNX {
28public:
29 struct OperatorsMapImpl;
30
31private:
32
// Verbosity flag, off by default; its value is what Verbose() returns.
33 bool fVerbose = false;
34 // Registered operators, held through an owning pointer to OperatorsMapImpl,
// which is only forward-declared in this header (its definition lives elsewhere).
35 std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
36 // Type of the tensors, keyed by a string (presumably the tensor name - confirm)
37 std::unordered_map<std::string, ETensorType> fTensorTypeMap;
38
39 // All model inputs, keyed by a string; the meaning of the int value is not
// visible in this header - TODO confirm against the implementation.
// NOTE(review): std::map, std::string and std::vector are used in this header
// without a direct #include of <map>/<string>/<vector>; presumably they arrive
// transitively through TMVA/RModel.hxx - confirm, or include them explicitly.
// NOTE(review): this member lacks the `f` prefix used by the other data members.
40 std::map<std::string, int> allInputs;
41
42
43public:
44 // Register an ONNX operator
// `name` is the string the operator is registered under and `func` is the
// ParserFuncSignature callable associated with it (passed by value).
// Presumably `name` is the ONNX op type string - confirm in the implementation.
45 void RegisterOperator(const std::string &name, ParserFuncSignature func);
46
47 // Check if the operator is registered
// Takes the operator `name` and reports the answer as a bool.
// NOTE(review): declared non-const although it reads like a pure query - confirm
// whether it can be made const.
48 bool IsRegisteredOperator(const std::string &name);
49
50 // List of registered operators
// Returned by value as a vector of strings; the ordering of the entries is not
// visible from this header - TODO confirm if callers depend on it.
51 std::vector<std::string> GetRegisteredOperators();
52
53 // Set the type of the tensor
// First argument is the tensor name, second is the ETensorType to record for it.
// Presumably the pair is stored in fTensorTypeMap - confirm in the implementation.
54 void RegisterTensorType(const std::string & /*name*/, ETensorType /*type*/);
55
56 // Check if the type of the tensor is registered
// Takes the tensor name and reports the answer as a bool.
57 bool IsRegisteredTensorType(const std::string & /*name*/);
58
59 // Check verbosity: returns the current value of fVerbose.
// Const accessor; does not modify the parser.
60 bool Verbose() const {
61 return fVerbose;
62 }
63
64 // Get the type of the tensor
// `name` identifies the tensor; the result is its ETensorType.
// The behaviour for a name that was never registered is not visible in this
// header - TODO confirm (callers can test with IsRegisteredTensorType first).
65 ETensorType GetTensorType(const std::string &name);
66
67 // Parse the index'th node from the ONNX graph
// Parameters, in order: the index, the ONNX graph to read from, and a vector of
// node indices (presumably the order in which nodes are visited - confirm).
// Returns an owning pointer to the resulting ROperator.
68 std::unique_ptr<ROperator> ParseOperator(const size_t /*index*/, const onnx::GraphProto & /*graphproto*/,
69 const std::vector<size_t> & /*nodes*/);
70
71 // Parse the ONNX graph `g`; `model` is taken by non-const reference, so it is
// presumably the RModel that gets filled - confirm in the implementation.
// `name` defaults to an empty string; its role is not visible from this header.
72 void ParseONNXGraph(RModel & model, const onnx::GraphProto & g, std::string name = "");
73
74public:
75
// Default constructor, declared noexcept; it is defined out of line.
76 RModelParser_ONNX() noexcept;
77
// Parse `filename` and return the result as an RModel by value.
// `verbose` defaults to false. Presumably `filename` is the path of an ONNX model
// file and `verbose` is stored in fVerbose - confirm in the implementation.
78 RModel Parse(std::string filename, bool verbose = false);
79
80
82};
83
84} // namespace SOFIE
85} // namespace Experimental
86} // namespace TMVA
87
88#endif // TMVA_SOFIE_RMODELPARSER_ONNX
#define g(i)
Definition RSha256.hxx:105
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
void RegisterOperator(const std::string &name, ParserFuncSignature func)
bool IsRegisteredOperator(const std::string &name)
void ParseONNXGraph(RModel &model, const onnx::GraphProto &g, std::string name="")
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string filename, bool verbose=false)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
create variable transformations