Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR
2#define TMVA_SOFIE_ROPERATOR
3
#include <memory>
#include <string>
#include <string_view>
#include <vector>
6
8//#include "RModel.hxx"
9
10
11
12namespace TMVA{
13namespace Experimental{
14namespace SOFIE{
15
16class RModel;
17
19
20
21public:
22 virtual std::vector<std::string> GetBlasRoutines() { return {}; }
23 virtual std::vector<std::string> GetStdLibs() { return {}; }
24 virtual std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) { return {}; };
25 virtual std::vector<ETensorType> TypeInference(std::vector<ETensorType>) { return {}; };
26 virtual void Initialize(RModel&) = 0;
27 virtual std::string Generate(std::string OpName) = 0; //expect unique opName for each operator within the same RModel
28 // generate code for Session constructor before tensor allocation
29 virtual std::string GenerateSessionCtorCode() { return "";}
30 // generate initialization code for session constructor after tensor allocations
31 virtual std::string GenerateInitCode() { return "";}
32 // generate some specific declaration code for Session
33 virtual std::string GenerateDeclCode() { return "";}
34 // generate session data members specific to operator
35 virtual std::string GenerateSessionMembersCode(std::string /*opName*/) { return ""; }
36 virtual std::string Header() { return "";}
37
38 //virtual void Forward_reference() = 0;
39 //virtual void Forward_blas() = 0;
40 virtual ~ROperator(){}
41
42protected:
43
44 const std::string SP = " "; ///< space used to correctly indent the generated C++ code
45 bool fUseSession = false; ///< flag to identify if using the session class
46 bool fIsOutputConstant = false; ///< flag to identify if operator has a constant output (no need to generate code)
47 bool fIsOutputParamShape = false; ///< flag to identify of the output represents a parametric shape (can be knwon at compile time)
48
49 mutable std::vector<std::string_view> fInputTensorNames;
50 mutable std::vector<std::string_view> fOutputTensorNames;
51
52public:
53 std::span<const std::string_view> GetOpInputTensors() const {
54 return fInputTensorNames;
55 }
56
57 std::span<const std::string_view> GetOpOutputTensors() const {
58 return fOutputTensorNames;
59 }
60
61};
62
63
64
65}//SOFIE
66}//Experimental
67}//TMVA
68
69
70#endif //TMVA_SOFIE_OPERATOR
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:49
virtual std::vector< std::string > GetBlasRoutines()
Definition ROperator.hxx:22
virtual void Initialize(RModel &)=0
bool fIsOutputParamShape
flag to identify if the output represents a parametric shape (can be known at compile time)
Definition ROperator.hxx:47
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:46
virtual std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >)
Definition ROperator.hxx:24
virtual std::string GenerateInitCode()
Definition ROperator.hxx:31
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:44
virtual std::vector< ETensorType > TypeInference(std::vector< ETensorType >)
Definition ROperator.hxx:25
virtual std::string GenerateSessionMembersCode(std::string)
Definition ROperator.hxx:35
std::span< const std::string_view > GetOpInputTensors() const
Definition ROperator.hxx:53
bool fUseSession
flag to identify if using the session class
Definition ROperator.hxx:45
virtual std::string Generate(std::string OpName)=0
std::span< const std::string_view > GetOpOutputTensors() const
Definition ROperator.hxx:57
virtual std::string GenerateDeclCode()
Definition ROperator.hxx:33
virtual std::vector< std::string > GetStdLibs()
Definition ROperator.hxx:23
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:50
virtual std::string GenerateSessionCtorCode()
Definition ROperator.hxx:29
create variable transformations