Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Comparision.hxx
Go to the documentation of this file.
1
2#ifndef TMVA_SOFIE_ROperator_Comparision
3#define TMVA_SOFIE_ROperator_Comparision
4
6#include "TMVA/ROperator.hxx"
7#include "TMVA/RModel.hxx"
8
9#include <sstream>
10
11namespace TMVA{
12namespace Experimental{
13namespace SOFIE{
14
16
17template <typename T, EComparisionOperator Op1>
19
20template <typename T>
21struct ComparisionTrait<T, Eq> {
22 static const std::string Name() { return "Equal"; }
23 static std::string Op(const std::string & t1, const std::string t2) { return t1 + "==" + t2 + "? true : false "; }
24};
25
26template <typename T>
28 static const std::string Name() { return "Less"; }
29 static std::string Op(const std::string & t1, const std::string t2) { return t1 + "<" + t2 + "? true : false "; }
30};
31
32template <typename T>
34 static const std::string Name() { return "LessOrEqual"; }
35 static std::string Op(const std::string & t1, const std::string t2) { return t1 + "<=" + t2 + "? true : false "; }
36};
37
38template <typename T>
40 static const std::string Name() { return "Greater"; }
41 static std::string Op(const std::string & t1, const std::string t2) { return t1 + ">" + t2 + "? true : false "; }
42};
43
44template <typename T>
46 static const std::string Name() { return "GreaterOrEqual"; }
47 static std::string Op(const std::string & t1, const std::string t2) { return t1 + ">=" + t2 + "? true : false " ; }
48};
49
50template<typename T, EComparisionOperator Op>
51class ROperator_Comparision final : public ROperator{
52private:
53
54 std::string fNX1;
55 std::string fNX2;
56 std::string fNY;
57 std::vector<size_t> fShapeX1;
58 std::vector<size_t> fShapeX2;
59 std::vector<size_t> fShapeY;
60 std::string fNBroadcastedX1;
61 std::string fNBroadcastedX2;
62 bool fBroadcast = false;
63
64
65public:
67 ROperator_Comparision(std::string nameX1, std::string nameX2, std::string nameY):
68 fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
69
70 // type of output given input
71 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
72 return input;
73 }
74
75 // shape of output tensors given input tensors
76 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
77 auto ret = input; // return vector size 1 with first input
78 return ret;
79 }
80
81 void Initialize(RModel& model) override {
82 // input must be a graph input, or already initialized intermediate tensor
84 throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX1 + "is not found in model");
85 }
86 if (!model.CheckIfTensorAlreadyExist(fNX2)) {
87 throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX2 + "is not found in model");
88 }
91 bool broadcast = !UTILITY::AreSameShape(fShapeX1, fShapeX2);
92 if (broadcast) {
93 // Y is the common shape of A and B
95 bool broadcastX1 = !UTILITY::AreSameShape(fShapeX1, fShapeY);
96 bool broadcastX2 = !UTILITY::AreSameShape(fShapeX2, fShapeY);
97 // Broadcast A to Y
98 if (broadcastX1) {
99 if (model.IsInitializedTensor(fNX1)) {
100 auto data = model.GetInitializedTensorData(fNX1);
101 std::shared_ptr<void> broadcastedData(
102 UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeX1, fShapeY),
103 std::default_delete<float[]>());
104 // Update the data and the shape of A
105 model.UpdateInitializedTensor(fNX1, model.GetTensorType(fNX1), fShapeY, broadcastedData);
107 } else {
108 // Add an intermediate tensor for broadcasting A
109 fNBroadcastedX1 = "Broadcasted" + fNX1;
111 }
112 }
113 // Broadcast B to Y
114 if (broadcastX2) {
115 if (model.IsInitializedTensor(fNX2)) {
116 auto data = model.GetInitializedTensorData(fNX2);
117 std::shared_ptr<void> broadcastedData(
118 UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeX2, fShapeY),
119 std::default_delete<float[]>());
120 // Update the data and the shape of B
121 model.UpdateInitializedTensor(fNX2, model.GetTensorType(fNX2), fShapeY, broadcastedData);
123 } else {
124 // Add an intermediate tensor for broadcasting B
125 fNBroadcastedX2 = "Broadcasted" + fNX2;
127 }
128 }
129 } else {
131 }
133 }
134
135 std::string Generate(std::string OpName) override {
136 OpName = "op_" + OpName;
137
138 if (fShapeY.empty()) {
139 throw std::runtime_error("TMVA SOFIE Comparision Op called to Generate without being initialized first");
140 }
141 std::stringstream out;
142 out << SP << "\n//------ " << ComparisionTrait<T,Op>::Name() << "\n";
144 // Broadcast A if it's uninitialized
145 if (!fNBroadcastedX1.empty()) {
146 out << SP << "// Broadcasting uninitialized tensor " << fNX1 << "\n";
147 out << SP << "{\n";
148 out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNX1 << ", " << ConvertShapeToString(fShapeX1) << ", " << ConvertShapeToString(fShapeY) << ");\n";
149 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX1 << ");\n";
150 out << SP << SP << "delete[] data;\n";
151 out << SP << "}\n";
152 }
153 // Broadcast B if it's uninitialized
154 if (!fNBroadcastedX2.empty()) {
155 out << SP << "// Broadcasting uninitialized tensor " << fNX2 << "\n";
156 out << SP << "{\n";
157 out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNX2 << ", " << ConvertShapeToString(fShapeX2) << ", " << ConvertShapeToString(fShapeY) << ");\n";
158 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX2 << ");\n";
159 out << SP << SP << "delete[] data;\n";
160 out << SP << "}\n";
161 }
162 const std::string& nameX1 = fNBroadcastedX1.empty()? fNX1 : fNBroadcastedX1;
163 const std::string& nameX2 = fNBroadcastedX2.empty()? fNX2 : fNBroadcastedX2;
164
165 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
166 out << SP << SP << "fTensor_" << fNY << "[id] = " << ComparisionTrait<T,Op>::Op( "tensor_" + nameX1 + "[id]" , "tensor_" + nameX2 + "[id]") << " ;\n";
167 out << SP << "}\n";
168
169 return out.str();
170 }
171
172};
173
174}//SOFIE
175}//Experimental
176}//TMVA
177
178
179#endif //TMVA_SOFIE_ROperator_Comparision
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:91
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:196
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:116
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:181
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:257
void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:248
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
ROperator_Comparision(std::string nameX1, std::string nameX2, std::string nameY)
std::string Generate(std::string OpName) override
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:41
bool AreSameShape(const std::vector< size_t > &, const std::vector< size_t > &)
std::vector< size_t > UnidirectionalBroadcastShape(std::vector< size_t >, std::vector< size_t >)
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations
static std::string Op(const std::string &t1, const std::string t2)
static std::string Op(const std::string &t1, const std::string t2)
static std::string Op(const std::string &t1, const std::string t2)
static std::string Op(const std::string &t1, const std::string t2)
static std::string Op(const std::string &t1, const std::string t2)
auto * t1
Definition textangle.C:20