2#ifndef TMVA_SOFIE_ROperator_Comparision
3#define TMVA_SOFIE_ROperator_Comparision
17template <
typename T, EComparisionOperator Op1>
22 static const std::string
Name() {
return "Equal"; }
23 static std::string
Op(
const std::string &
t1,
const std::string t2) {
return t1 +
" == " + t2; }
29 static const std::string
Name() {
return "Less"; }
30 static std::string
Op(
const std::string &
t1,
const std::string t2) {
return t1 +
" < " + t2; }
36 static const std::string
Name() {
return "LessOrEqual"; }
37 static std::string
Op(
const std::string &
t1,
const std::string t2) {
return t1 +
" <= " + t2; }
43 static const std::string
Name() {
return "Greater"; }
44 static std::string
Op(
const std::string &
t1,
const std::string t2) {
return t1 +
" > " + t2; }
50 static const std::string
Name() {
return "GreaterOrEqual"; }
51 static std::string
Op(
const std::string &
t1,
const std::string t2) {
return t1 +
" >= " + t2 ; }
55template<
typename T, EComparisionOperator Op>
96 if (!model.CheckIfTensorAlreadyExist(
fNX1)){
97 throw std::runtime_error(std::string(
"TMVA SOFIE Comparision Op Input Tensor ") +
fNX1 +
"is not found in model");
99 if (!model.CheckIfTensorAlreadyExist(
fNX2)) {
100 throw std::runtime_error(std::string(
"TMVA SOFIE Comparision Op Input Tensor ") +
fNX2 +
"is not found in model");
102 if (model.IsDynamicTensor(
fNX1))
108 if (model.IsDynamicTensor(
fNX2))
118 bool broadcastX1 =
false;
119 bool broadcastX2 =
false;
135 std::unique_ptr<T> broadcastedData1;
136 std::unique_ptr<T> broadcastedData2;
138 std::vector<Dim> shapeData1;
139 std::vector<Dim> shapeData2;
141 bool *outData =
new bool[
length];
142 if (model.IsInitializedTensor(
fNX1)) {
143 data1 =
static_cast<T *
>(model.GetInitializedTensorData(
fNX1).get());
145 broadcastedData1 = std::unique_ptr<T>(
147 data1 = broadcastedData1.get();
150 }
else if (model.IsShapeTensor(
fNX1)) {
151 shapeData1 = model.GetShapeTensorValues(
fNX1);
153 if (model.IsInitializedTensor(
fNX2)) {
154 data2 =
static_cast<T *
>(model.GetInitializedTensorData(
fNX2).get());
156 broadcastedData2 = std::unique_ptr<T>(
158 data2 = broadcastedData2.get();
160 }
else if (model.IsShapeTensor(
fNX2)) {
161 shapeData2 = model.GetShapeTensorValues(
fNX2);
163 if (data1 && data2) {
165 for (
size_t i = 0; i <
length; i++)
167 model.AddConstantTensor(
fNY,
fShapeY, outData);
169 std::cout << ComparisionTrait<T, Op>::Name() <<
" op ---> " <<
fNY <<
" "
172 }
else if ((data1 || !shapeData1.empty()) && (data2 || !shapeData2.empty())) {
174 if (data1 && !data2) {
176 for (
size_t i = 0; i <
length; i++) {
177 if (shapeData2[i].isParam) {
178 if (shapeData2[i].dim ==
size_t(-1) || data1[i] > 0) {
183 shapeData2[i].dim = 0;
188 }
else if (!data1 && data2) {
190 for (
size_t i = 0; i <
length; i++) {
191 if (shapeData1[i].isParam) {
192 if (shapeData1[i].dim ==
size_t(-1) || data2[i] > 0) {
197 shapeData1[i].dim = 0;
202 }
else if (!shapeData1.empty() && !shapeData2.empty()) {
204 for (
size_t i = 0; i <
length; i++) {
205 if (!shapeData1[i].isParam && !shapeData2[i].isParam) {
207 }
else if (shapeData1[i].isParam && shapeData2[i].isParam) {
208 if (shapeData1[i].param == shapeData2[i].param)
221 model.AddConstantTensor(
fNY,
fShapeY, outData);
223 std::cout << ComparisionTrait<T, Op>::Name() <<
" op ---> " <<
fNY <<
" "
225 <<
" (constant) " << std::endl;
234 std::cout << ComparisionTrait<T, Op>::Name() <<
" op ---> " <<
fNY <<
" "
248 auto IsInputDimParam = [&](
const std::string &p) {
249 auto inputNames = model.GetInputTensorNames();
250 for (
auto &
input : inputNames) {
251 for (
auto &i_s : model.GetDimTensorShape(
input)) {
252 if (i_s.isParam && i_s.param == p)
258 for (
size_t i = 0; i <
fDimShapeY.size(); i++) {
260 if (s.isParam && s.param.find(
"std::max") != std::string::npos) {
267 }
else if (IsInputDimParam(
fDimShapeX2[i].param)) {
278 if (model.Verbose()) {
282 model.PrintIntermediateTensors();
287 std::string
Generate(std::string opName)
override {
289 opName =
"op_" + opName;
292 throw std::runtime_error(
"TMVA SOFIE Comparision Op called to Generate without being initialized first");
294 std::stringstream out;
305 std::string compute_idx_X1, compute_idx_X2, compute_idx_Y;
308 compute_idx_X1 =
"0";
314 if (stridesA[i].GetVal() !=
"1")
315 compute_idx_X1 +=
" * " + stridesA[i].GetVal();
316 compute_idx_X1 +=
" + ";
319 for (
int j = 0; j < 3; j++)
320 compute_idx_X1.pop_back();
324 compute_idx_X2 =
"0";
330 if (stridesB[i].GetVal() !=
"1")
331 compute_idx_X2 +=
" * " + stridesB[i].GetVal();
332 compute_idx_X2 +=
" + ";
335 for (
int j = 0; j < 3; j++)
336 compute_idx_X2.pop_back();
343 for (
size_t i = 0; i <
fDimShapeY.size(); ++i) {
346 for (
int j = 0; j < nloop; j++) out <<
SP;
347 out <<
"for (size_t idx_" << i <<
" = 0; idx_" << i <<
" < " <<
fDimShapeY[i]
348 <<
"; ++idx_" << i <<
"){\n";
349 compute_idx_Y +=
"idx_" + std::to_string(i);
350 if (stridesY[i].GetVal() !=
"1")
351 compute_idx_Y +=
" * " + stridesY[i].GetVal();
352 compute_idx_Y +=
" + ";
356 for (
int j = 0; j < 3; j++)
357 compute_idx_Y.pop_back();
359 for (
int j = 0; j < nloop + 1; j++) out <<
SP;
360 out <<
"tensor_" <<
fNY <<
"[" << compute_idx_Y <<
"] = "
362 "tensor_" +
fNX2 +
"[" + compute_idx_X2 +
"]") <<
" ;\n";
365 for (
int i = nloop; i > 0; i--) {
366 for (
int j = 0; j < i; j++) out <<
SP;
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< size_t > fShapeX2
std::string Generate(std::string opName) override
std::vector< size_t > fShapeX1
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
std::vector< size_t > fShapeY
std::vector< Dim > fDimShapeY
ROperator_Comparision(const std::string &nameX1, const std::string &nameX2, const std::string &nameY)
void Initialize(RModel &model) override
std::vector< Dim > fDimShapeX2
std::vector< Dim > fDimShapeX1
std::vector< std::string_view > fInputTensorNames
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
bool AreSameShape(const std::vector< size_t > &, const std::vector< size_t > &)
std::vector< size_t > UnidirectionalBroadcastShape(std::vector< size_t > &, std::vector< size_t > &)
std::vector< size_t > MultidirectionalBroadcastShape(std::vector< std::vector< size_t > >)
T * UnidirectionalBroadcast(const T *data, const std::vector< size_t > &shape, const std::vector< size_t > &targetShape)
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::size_t ConvertShapeToLength(const std::vector< size_t > &shape)
std::string ConvertValuesToString(size_t n, const T *data, size_t maxprint=-1)
std::vector< Dim > ConvertShapeToDim(const std::vector< size_t > &shape)
Convert shape from integer format to dynamic one (based on Dim).
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations
static std::string Op(const std::string &t1, const std::string t2)
static bool Result(T v1, T v2)
static const std::string Name()
static const std::string Name()
static std::string Op(const std::string &t1, const std::string t2)
static bool Result(T v1, T v2)
static const std::string Name()
static bool Result(T v1, T v2)
static std::string Op(const std::string &t1, const std::string t2)
static bool Result(T v1, T v2)
static const std::string Name()
static std::string Op(const std::string &t1, const std::string t2)
static const std::string Name()
static bool Result(T v1, T v2)
static std::string Op(const std::string &t1, const std::string t2)