Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_RMODEL
2#define TMVA_SOFIE_RMODEL
3
#include "TMVA/ROperator.hxx"

#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
7
8namespace TMVA {
9namespace Experimental {
10namespace SOFIE {
11
12class RModel final : public RModel_Base {
13
14private:
15 std::unordered_map<std::string, InputTensorInfo>
16 fInputTensorInfos; // input tensors where shape is not defined or other graph inputs?
17 std::unordered_map<std::string, TensorInfo> fReadyInputTensorInfos; // input tensors where shape is full defined
18 std::unordered_map<std::string, InitializedTensor> fInitializedTensors;
19 std::unordered_map<std::string, TensorInfo> fIntermediateTensorInfos;
20 std::unordered_map<std::string, DynamicTensorInfo> fDynamicTensorInfos;
21 std::unordered_map<std::string, std::string>
22 fShapeParams; // parameters defining the dynamic shape (e.g. batch size), store also its default value
23 std::vector<std::string> fOutputTensorNames;
24 std::vector<std::string> fInputTensorNames; // input tensor names using ONNX order
25
26 std::vector<std::unique_ptr<ROperator>> fOperators;
27
28 const std::string SP = " ";
29
30public:
31 // Rule of five: explicitly define move semantics, disallow copy
32 RModel(RModel &&other);
33 RModel &operator=(RModel &&other);
34 RModel(const RModel &other) = delete;
35 RModel &operator=(const RModel &other) = delete;
36 ~RModel() = default;
37
38 /**
39 Default constructor. Needed to allow serialization of ROOT objects. See
40 https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
41 */
42 RModel() = default;
43 RModel(std::string name, std::string parsedtime) : RModel_Base(name, parsedtime) {}
44
45 // For GNN Functions usage
46 RModel(std::string function_name) : RModel_Base(function_name) {}
47
48 const std::vector<size_t> &GetTensorShape(std::string name);
49 std::vector<Dim> GetDynamicTensorShape(std::string name);
50 const ETensorType &GetTensorType(std::string name);
51
52 bool CheckIfTensorAlreadyExist(std::string tensor_name);
53 void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<Dim> shape);
54 void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<size_t> shape);
55 void AddOperator(std::unique_ptr<ROperator> op, int order_execution = -1);
56 void AddOperatorReference(ROperator *op, int order_execution = -1)
57 {
58 std::unique_ptr<ROperator> tmp(op);
59 AddOperator(std::move(tmp), order_execution);
60 }
61 void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
62 std::shared_ptr<void> data);
63 void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
64 std::shared_ptr<void> data);
65
66
67 template <typename T>
68 void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape, T *raw_data)
69 {
70 int size = 1;
71 for (auto item : shape) {
72 size *= (int)item;
73 }
74 std::shared_ptr<void> data(malloc(size * sizeof(T)), free);
75 std::memcpy(data.get(), raw_data, size * sizeof(T));
76 AddInitializedTensor(tensor_name, type, shape, data);
77 }
78
79 // set a flag to indicate tensor does not need to be written in a weight file
80 // (e.g. shape tensors used as input to define a shape (in Reshape))
81 void SetNotWritableInitializedTensor(const std::string & tensor_name);
82
83 // Check if a tensor is initialized
84 bool IsInitializedTensor(const std::string &name) const;
85 bool IsDynamicTensor(const std::string &name) const;
86 bool IsInputTensor(const std::string &name) const;
87
88 // Add intermediate tensor
89 void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector<Dim> dim_shape);
90 void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape);
91 // Add an intermediate dynamic tensor
92 void AddDynamicTensor(std::string tensor_name, ETensorType type, std::vector<Dim> shape);
93
94 void AddInputTensorName(std::string name);
95 void AddOutputTensorNameList(std::vector<std::string> output_tensor_names);
96 void
97 UpdateOutputTensorList(std::vector<std::string> curr_output_tensor, std::vector<std::string> modify_output_tensor);
98 void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
99 std::shared_ptr<void> data);
100 std::shared_ptr<void> GetInitializedTensorData(std::string tensor_name);
101
102 void Initialize(int batchSize = -1, bool verbose = false);
106 void GenerateOutput();
107 void Generate(std::underlying_type_t<Options> options, int batchSize = -1, long pos = 0, bool verbose = false);
108 void Generate(Options options = Options::kDefault, int batchSize = -1, int pos = 0, bool verbose = false)
109 {
110 Generate(static_cast<std::underlying_type_t<Options>>(options), batchSize, pos, verbose);
111 }
112
113 const std::vector<std::string> &GetInputTensorNames() const { return fInputTensorNames; }
114 const std::vector<std::string> &GetOutputTensorNames() const { return fOutputTensorNames; }
115
117 long WriteInitializedTensorsToFile(std::string filename = "");
118
120 void PrintOutputTensors();
121 void OutputGenerated(std::string filename = "", bool append = false);
122 std::vector<std::string> GetOutputTensorNames() { return fOutputTensorNames; }
123 void SetFilename(std::string filename) { fName = filename; }
124
125 /*
126 template <typename T>
127 void AddInitializedTensor(std::string tensor_name, RTensor<T> new_tensor){
128 //a view only
129 T obj;
130 if (fInitializedTensors.find(tensor_name) != fInitializedTensors.end()){
131 throw std::runtime_error("TMVA-SOFIE: initialized tensor with name " + tensor_name + " already exists \n");
132 }
133 InitializedTensor new_tensor_ {GetTemplatedType(obj), new_tensor.GetShape() ,
134 static_cast<void>(new_tensor.GetData())}; fInitializedTensors[tensor_name] = new_tensor_;
135 }
136 */
137
140 void PrintDynamicTensors();
141 void HeadInitializedTensors(std::string name, int n_print = 50);
142
143 bool UseSession() const { return fUseSession; }
144
145 // Use the ClassDef macro to allow definition of custom streaming
147};
148
149} // namespace SOFIE
150} // namespace Experimental
151} // namespace TMVA
152
153#endif // TMVA_SOFIE_RMODEL
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char name[80]
Definition TGX11.cxx:110
#define free
Definition civetweb.c:1539
#define malloc
Definition civetweb.c:1536
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:91
std::unordered_map< std::string, DynamicTensorInfo > fDynamicTensorInfos
Definition RModel.hxx:20
bool IsDynamicTensor(const std::string &name) const
Definition RModel.cxx:186
RModel(const RModel &other)=delete
const std::vector< std::string > & GetOutputTensorNames() const
Definition RModel.hxx:114
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:196
std::vector< Dim > GetDynamicTensorShape(std::string name)
Definition RModel.cxx:79
void AddOperatorReference(ROperator *op, int order_execution=-1)
Definition RModel.hxx:56
RModel(std::string function_name)
Definition RModel.hxx:46
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:116
std::vector< std::unique_ptr< ROperator > > fOperators
Definition RModel.hxx:26
void OutputGenerated(std::string filename="", bool append=false)
Definition RModel.cxx:991
void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector< Dim > shape)
Definition RModel.cxx:125
std::unordered_map< std::string, TensorInfo > fIntermediateTensorInfos
Definition RModel.hxx:19
void AddOutputTensorNameList(std::vector< std::string > output_tensor_names)
Definition RModel.cxx:234
std::unordered_map< std::string, TensorInfo > fReadyInputTensorInfos
Definition RModel.hxx:17
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:171
void AddDynamicTensor(std::string tensor_name, ETensorType type, std::vector< Dim > shape)
Definition RModel.cxx:213
void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:161
RModel & operator=(RModel &&other)
Definition RModel.cxx:39
void AddInputTensorName(std::string name)
Definition RModel.cxx:144
std::vector< std::string > fOutputTensorNames
Definition RModel.hxx:23
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:181
void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, T *raw_data)
Definition RModel.hxx:68
void AddOperator(std::unique_ptr< ROperator > op, int order_execution=-1)
Definition RModel.cxx:148
RModel()=default
Default constructor.
void HeadInitializedTensors(std::string name, int n_print=50)
Definition RModel.cxx:955
void Initialize(int batchSize=-1, bool verbose=false)
Definition RModel.cxx:274
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
bool IsInputTensor(const std::string &name) const
Definition RModel.cxx:190
long WriteInitializedTensorsToFile(std::string filename="")
Definition RModel.cxx:757
void Generate(std::underlying_type_t< Options > options, int batchSize=-1, long pos=0, bool verbose=false)
Definition RModel.cxx:572
RModel & operator=(const RModel &other)=delete
std::unordered_map< std::string, InputTensorInfo > fInputTensorInfos
Definition RModel.hxx:16
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:257
void SetFilename(std::string filename)
Definition RModel.hxx:123
std::unordered_map< std::string, std::string > fShapeParams
Definition RModel.hxx:22
void SetNotWritableInitializedTensor(const std::string &tensor_name)
Definition RModel.cxx:266
std::vector< std::string > fInputTensorNames
Definition RModel.hxx:24
const std::vector< std::string > & GetInputTensorNames() const
Definition RModel.hxx:113
std::unordered_map< std::string, InitializedTensor > fInitializedTensors
Definition RModel.hxx:18
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:248
std::vector< std::string > GetOutputTensorNames()
Definition RModel.hxx:122
void UpdateOutputTensorList(std::vector< std::string > curr_output_tensor, std::vector< std::string > modify_output_tensor)
Definition RModel.cxx:241
RModel(std::string name, std::string parsedtime)
Definition RModel.hxx:43
void Generate(Options options=Options::kDefault, int batchSize=-1, int pos=0, bool verbose=false)
Definition RModel.hxx:108
create variable transformations