Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Custom.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_Custom
2#define TMVA_SOFIE_ROPERATOR_Custom
3
4
6#include "TMVA/ROperator.hxx"
7#include "TMVA/RModel.hxx"
8
9namespace TMVA{
10namespace Experimental{
11namespace SOFIE{
12
13
14template<typename T>
16{
17
18private:
19 std::string fOpName; ///< Name placed in front of `::Compute(` in the generated code and quoted in error messages
20 std::vector<std::string> fInputNames; ///< Input tensor names after UTILITY::Clean_name
21 std::vector<std::string> fOutputNames; ///< Output tensor names after UTILITY::Clean_name
22 std::vector<std::vector<std::size_t>> fOutputShapes; ///< One shape per output name (count checked in Initialize); turned into a length in Generate
23 std::vector<std::size_t> fInputSizes; ///< Per-input span lengths emitted by Generate
24 std::string fHeaderName; ///< Header name printed in the verbose message of Initialize
26
27public:
29 ROperator_Custom(std::string OpName, std::vector<std::string>Inputs, std::vector<std::string>Outputs, std::vector<std::vector<std::size_t>> OutputShapes, std::string HeaderName){
33 for(auto& it:Inputs){
34 fInputNames.emplace_back(UTILITY::Clean_name(it));
35 fInputTensorNames.emplace_back(fInputNames.back());
36 }
37 for(auto& it:Outputs){
38 fOutputNames.emplace_back(UTILITY::Clean_name(it));
39 fOutputTensorNames.emplace_back(fOutputNames.back());
40 }
41 }
42
/// Placeholder shape inference: the argument is ignored and the result always
/// holds exactly one empty shape.
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) {
   std::vector<std::vector<size_t>> placeholder;
   placeholder.emplace_back(); // a single, empty shape
   return placeholder;
}
44 std::vector<ETensorType> TypeInference(std::vector<ETensorType>){ return {};};
45
46 void Initialize(RModel& model) override {
49
50 for(auto& it:fInputNames){
51 if (model.CheckIfTensorAlreadyExist(it) == false){
52 throw std::runtime_error("TMVA SOFIE Custom " + fOpName + " Op Input Tensor " + it + " is not found in model");
53 }
55 }
56
57 if(fOutputNames.size() != fOutputShapes.size()){
58 throw std::runtime_error("TMVA SOFIE Custom "+ fOpName + " Op was not intialized with the names/shapes of all the output tensors");
59 }
60
61 for(long unsigned int i=0; i<fOutputNames.size(); ++i){
63 }
64
65
67
68 if (model.Verbose()) {
69 std::cout << "Custom operator using " << fHeaderName;
70 for (auto & i : fInputNames) std::cout << " " << i;
71 std::cout << " ---> ";
72 for (auto & i : fOutputNames) std::cout << " " << i;
73 std::cout << "\n";
74 }
75 model.AddNeededCustomHeader("ROOT/RSpan.hxx");
76 }
77
78 std::string Generate(std::string OpName){
79 OpName = "op_" + OpName;
80 std::stringstream out;
81 out << "\n//------ "<<fOpName<<" \n";
82 std::string args;
83 for(long unsigned int i = 0; i<fInputNames.size(); ++i){
84 args+="std::span<const "+ConvertTypeToString(fInputType)+">(tensor_"+std::string(fInputNames[i])+", "+fInputSizes[i]+"),";
85 }
86
87 for(long unsigned int i = 0; i<fOutputNames.size(); ++i){
88 args+="std::span<"+TensorType<T>::Name()+">(tensor_"+std::string(fOutputNames[i])+", "+ConvertShapeToLength(fOutputShapes[i])+"),";
89 }
90 args.pop_back();
91 out << SP << fOpName<<"::Compute("+args+");\n";
92 return out.str();
93 }
94
95};
96
97
98}//SOFIE
99}//Experimental
100}//TMVA
101
102
103#endif //TMVA_SOFIE_ROPERATOR_Custom
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void AddNeededCustomHeader(std::string filename)
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:94
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:227
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
void UpdateOutputTensorList(std::vector< std::string > curr_output_tensor, std::vector< std::string > modify_output_tensor)
Definition RModel.cxx:272
ROperator_Custom(std::string OpName, std::vector< std::string >Inputs, std::vector< std::string >Outputs, std::vector< std::vector< std::size_t > > OutputShapes, std::string HeaderName)
std::string Generate(std::string OpName)
std::vector< ETensorType > TypeInference(std::vector< ETensorType >)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >)
std::vector< std::vector< std::size_t > > fOutputShapes
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:46
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:42
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:47
std::string Clean_name(std::string input_tensor_name)
std::string ConvertTypeToString(ETensorType type)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations