Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_BasicNary.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_BASICNARY
2#define TMVA_SOFIE_ROPERATOR_BASICNARY
3
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
11
12namespace TMVA{
13namespace Experimental{
14namespace SOFIE{
15
17
18template<typename T, EBasicNaryOperator Op>
20
21template<typename T>
/// ONNX operator name handled by this trait.
/// Returned by value (not `const std::string`) so callers can move from it.
static std::string Name() { return "Max"; }

/// Emit C++ source that stores the element-wise maximum of `inputs` into `res`.
///
/// @param res     Target expression (e.g. "tensor_Y[id]"); it also serves as the
///                running accumulator for the std::max chain.
/// @param inputs  Source expressions, one per input tensor; must not be empty.
/// @return Generated statements, each indented by two tab characters.
/// @throws std::runtime_error if `inputs` is empty.
static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
   // inputs[0] below would be out of bounds on an empty list.
   if (inputs.empty()) {
      throw std::runtime_error("TMVA SOFIE BasicNary Max Op called with no inputs");
   }
   std::stringstream out;
   out << "\t" << "\t" << res << " = " << inputs[0] << ";\n";
   for (size_t i = 1; i < inputs.size(); i++) {
      out << "\t" << "\t" << res << " = std::max(" << res << ", " << inputs[i] << ");\n";
   }
   return out.str();
}
32};
33
34template<typename T>
/// ONNX operator name handled by this trait.
/// Returned by value (not `const std::string`) so callers can move from it.
static std::string Name() { return "Min"; }

/// Emit C++ source that stores the element-wise minimum of `inputs` into `res`.
///
/// @param res     Target expression (e.g. "tensor_Y[id]"); it also serves as the
///                running accumulator for the std::min chain.
/// @param inputs  Source expressions, one per input tensor; must not be empty.
/// @return Generated statements, each indented by two tab characters.
/// @throws std::runtime_error if `inputs` is empty.
static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
   // inputs[0] below would be out of bounds on an empty list.
   if (inputs.empty()) {
      throw std::runtime_error("TMVA SOFIE BasicNary Min Op called with no inputs");
   }
   std::stringstream out;
   out << "\t" << "\t" << res << " = " << inputs[0] << ";\n";
   for (size_t i = 1; i < inputs.size(); i++) {
      out << "\t" << "\t" << res << " = std::min(" << res << ", " << inputs[i] << ");\n";
   }
   return out.str();
}
45};
46
47template<typename T>
49
50template<>
/// ONNX operator name handled by this trait.
/// Returned by value (not `const std::string`) so callers can move from it.
static std::string Name() { return "Mean"; }

/// Emit a single C++ statement assigning the arithmetic mean of `inputs` to `res`.
/// The generated expression is "(in0 + in1 + ...) / float(N)".
///
/// @param res     Target expression (e.g. "tensor_Y[id]").
/// @param inputs  Source expressions, one per input tensor; must not be empty.
/// @return The generated statement, indented by two tab characters.
/// @throws std::runtime_error if `inputs` is empty (there is no mean of nothing).
static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
   // inputs[0] below would be out of bounds on an empty list.
   if (inputs.empty()) {
      throw std::runtime_error("TMVA SOFIE BasicNary Mean Op called with no inputs");
   }
   std::stringstream out;
   out << "\t" << "\t" << res << " = (" << inputs[0];
   for (size_t i = 1; i < inputs.size(); i++) {
      out << " + " << inputs[i];
   }
   out << ") / float(" << inputs.size() << ");\n";
   return out.str();
}
62};
63
64template<typename T>
/// ONNX operator name handled by this trait.
/// Returned by value (not `const std::string`) so callers can move from it.
static std::string Name() { return "Sum"; }

/// Emit a single C++ statement assigning the sum of `inputs` to `res`.
/// The generated expression is "in0 + in1 + ...".
///
/// @param res     Target expression (e.g. "tensor_Y[id]").
/// @param inputs  Source expressions, one per input tensor; must not be empty.
/// @return The generated statement, indented by two tab characters.
/// @throws std::runtime_error if `inputs` is empty.
static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
   // inputs[0] below would be out of bounds on an empty list.
   if (inputs.empty()) {
      throw std::runtime_error("TMVA SOFIE BasicNary Sum Op called with no inputs");
   }
   std::stringstream out;
   out << "\t" << "\t" << res << " = " << inputs[0];
   for (size_t i = 1; i < inputs.size(); i++) {
      out << " + " << inputs[i];
   }
   out << ";\n";
   return out.str();
}
76};
77
78template <typename T, EBasicNaryOperator Op>
80{
81
82private:
83
84 std::vector<std::string> fNInputs;
85 std::string fNY;
86 std::vector<std::vector<size_t>> fShapeInputs;
87
88 std::vector<std::string> fNBroadcastedInputs;
89 std::vector<size_t> fShapeY;
90
91 bool fBroadcast = false;
92
93 std::string fType;
94
95public:
97
98 ROperator_BasicNary( const std::vector<std::string> & inputNames, const std::string& nameY):
99 fNY(UTILITY::Clean_name(nameY)){
100 fNInputs.reserve(inputNames.size());
101 for (auto & name : inputNames)
103
104 fInputTensorNames.resize(fNInputs.size());
105 std::transform(fNInputs.begin(), fNInputs.end(), fInputTensorNames.begin(),
106 [](const std::string& s) -> std::string_view { return s; });
108 }
109
110 // type of output given input
111 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
112 return input;
113 }
114
// Shape of the (single) output tensor given the input tensor shapes.
// Applies multidirectional (NumPy-style) broadcasting, the same rule the operator
// uses when it broadcasts inputs to a common shape:
//   - shapes are right-aligned and padded on the left with 1s up to the largest rank;
//   - in each dimension all sizes must be equal, or 1 (which stretches to the other size).
// When all inputs share one shape, the result is that shape.
// Throws std::runtime_error on an empty input list or incompatible shapes.
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
   if (input.empty()) {
      throw std::runtime_error("TMVA SOFIE BasicNary Op ShapeInference called without input shapes");
   }
   size_t rank = 0;
   for (const auto& shape : input) {
      rank = std::max(rank, shape.size());
   }
   std::vector<size_t> outShape(rank, 1);
   for (const auto& shape : input) {
      // Right-align this shape against the output shape.
      const size_t offset = rank - shape.size();
      for (size_t i = 0; i < shape.size(); i++) {
         size_t& dim = outShape[offset + i];
         if (shape[i] == dim || shape[i] == 1) {
            continue;
         }
         if (dim == 1) {
            dim = shape[i];
            continue;
         }
         throw std::runtime_error("TMVA SOFIE BasicNary Op ShapeInference: input shapes cannot be broadcast together");
      }
   }
   return std::vector<std::vector<size_t>>(1, outShape);
}
120
121 void Initialize(RModel& model) override {
122 for (auto &it : fNInputs) {
123 if (!model.CheckIfTensorAlreadyExist(it)) {
124 throw std::runtime_error("TMVA SOFIE BasicNary Op Input Tensor " + it + " is not found in model");
125 }
126 fShapeInputs.push_back(model.GetTensorShape(it));
127 }
128 // Find the common shape of the input tensors
131 // Broadcasting
132 size_t N = fNInputs.size();
133 fNBroadcastedInputs.reserve(N);
134 for (size_t i = 0; i < N; i++) {
136 fBroadcast = true;
137 std::string name = "Broadcasted" + fNInputs[i];
139 fNBroadcastedInputs.emplace_back("tensor_" + name);
140 } else {
141 fNBroadcastedInputs.emplace_back("tensor_" + fNInputs[i]);
142 }
143 }
145 }
146
147 std::string Generate(std::string OpName){
148 OpName = "op_" + OpName;
149 if (fShapeY.empty()) {
150 throw std::runtime_error("TMVA SOFIE BasicNary called to Generate without being initialized first");
151 }
152 std::stringstream out;
154 out << SP << "\n//------ BasicNary operator\n";
155 if (fBroadcast) {
156 for (size_t i = 0; i < fNInputs.size(); i++) {
157 if (fNBroadcastedInputs[i] != fNInputs[i]) {
158 out << SP << SP << "// Broadcasting " << fNInputs[i] << " to " << ConvertShapeToString(fShapeY) << "\n";
159 out << SP << SP << "{\n";
160 out << SP << SP << SP << fType << "* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << fType << ">(tensor_" + fNInputs[i] << ", " << ConvertShapeToString(fShapeInputs[i]);
161 out << ", " << ConvertShapeToString(fShapeY) << ");\n";
162 out << SP << SP << SP << "std::copy(data, data + " << length << ", " << fNBroadcastedInputs[i] << ");\n";
163 out << SP << SP << SP << "delete[] data;\n";
164 out << SP << SP << "}\n";
165 }
166 }
167 }
168
169 if (fNInputs.size() == 1) {
170 out << SP << "std::copy(tensor_" << fNInputs[0] << ", tensor_" << fNInputs[0] << " + ";
171 out << length << ", tensor_" << fNY << ");\n";
172 } else {
173 std::vector<std::string> inputs(fNBroadcastedInputs.size());
174 for (size_t i = 0; i < fNBroadcastedInputs.size(); i++) {
175 inputs[i] = fNBroadcastedInputs[i] + "[id]";
176 }
177 out << SP << "for (size_t id = 0; id < " << length << "; id++) {\n";
178 out << NaryOperatorTraits<T,Op>::Op("tensor_" + fNY + "[id]", inputs);
179 out << SP << "}\n";
180 }
181 return out.str();
182 }
183
// Standard headers required by the generated inference code.
// The emitted statements call std::max / std::min (Max/Min traits) and std::copy
// (broadcast and single-input paths), all declared in <algorithm>; "cmath" is kept
// for backward compatibility.
std::vector<std::string> GetStdLibs() { return { std::string("cmath"), std::string("algorithm") }; }
185};
186
187}//SOFIE
188}//Experimental
189}//TMVA
190
191
192#endif //TMVA_SOFIE_ROPERATOR_BasicNary
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
#define N
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
char name[80]
Definition TGX11.cxx:110
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:94
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:227
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
std::vector< std::vector< size_t > > fShapeInputs
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
ROperator_BasicNary(const std::vector< std::string > &inputNames, const std::string &nameY)
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:46
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:42
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:47
bool AreSameShape(const std::vector< size_t > &, const std::vector< size_t > &)
std::string Clean_name(std::string input_tensor_name)
std::vector< size_t > MultidirectionalBroadcastShape(std::vector< std::vector< size_t > >)
std::string ConvertShapeToString(std::vector< size_t > shape)
std::string ConvertTypeToString(ETensorType type)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations
static std::string Op(const std::string &res, std::vector< std::string > &inputs)
static std::string Op(const std::string &res, std::vector< std::string > &inputs)
static std::string Op(const std::string &res, std::vector< std::string > &inputs)
static std::string Op(const std::string &res, std::vector< std::string > &inputs)