ROOT Reference Guide
 
Loading...
Searching...
No Matches
RModelParser_ONNX.hxx
Go to the documentation of this file.
#ifndef TMVA_SOFIE_RMODELPARSER_ONNX
#define TMVA_SOFIE_RMODELPARSER_ONNX

#include "TMVA/RModel.hxx"

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Forward declarations of the ONNX protobuf message types used in the parser's
// interface. They appear here only by reference, so the (generated) protobuf
// headers do not have to be included by users of this header.
namespace onnx {
class NodeProto;
class GraphProto;
} // namespace onnx

16namespace TMVA {
17namespace Experimental {
18namespace SOFIE {
19
20class RModelParser_ONNX;
21
23 std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & /*parser*/, const onnx::NodeProto & /*nodeproto*/)>;
25 std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& /*parser*/, const onnx::NodeProto& /*firstnode*/, const onnx::NodeProto& /*secondnode*/)>;
26
28public:
29 struct OperatorsMapImpl;
30
31private:
32 bool fVerbose = false;
33 // Registered operators
34 std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
35 // Type of the tensors
36 std::unordered_map<std::string, ETensorType> fTensorTypeMap;
37
38public:
39 // Register an ONNX operator
40 void RegisterOperator(const std::string &name, ParserFuncSignature func);
41
42 // Check if the operator is registered
43 bool IsRegisteredOperator(const std::string &name);
44
45 // List of registered operators
46 std::vector<std::string> GetRegisteredOperators();
47
48 // Set the type of the tensor
49 void RegisterTensorType(const std::string & /*name*/, ETensorType /*type*/);
50
51 // Check if the type of the tensor is registered
52 bool IsRegisteredTensorType(const std::string & /*name*/);
53
54 // Get the type of the tensor
55 ETensorType GetTensorType(const std::string &name);
56
57 // Parse the index'th node from the ONNX graph
58 std::unique_ptr<ROperator> ParseOperator(const size_t /*index*/, const onnx::GraphProto & /*graphproto*/,
59 const std::vector<size_t> & /*nodes*/);
60
61public:
62 RModelParser_ONNX() noexcept;
63
64 RModel Parse(std::string filename, bool verbose = false);
65
67};
68
69} // namespace SOFIE
70} // namespace Experimental
71} // namespace TMVA
72
73#endif // TMVA_SOFIE_RMODELPARSER_ONNX
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
void RegisterOperator(const std::string &name, ParserFuncSignature func)
bool IsRegisteredOperator(const std::string &name)
std::unordered_map< std::string, ETensorType > fTensorTypeMap
RModel Parse(std::string filename, bool verbose=false)
std::unique_ptr< ROperator > ParseOperator(const size_t, const onnx::GraphProto &, const std::vector< size_t > &)
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::vector< std::string > GetRegisteredOperators()
std::unique_ptr< OperatorsMapImpl > fOperatorsMapImpl
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &, const onnx::NodeProto &)> ParserFuseFuncSignature
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
create variable transformations