1#ifndef TMVA_SOFIE_ROperator_Where
2#define TMVA_SOFIE_ROperator_Where
54 ROperator_Where(
const std::string & nameC,
const std::string & nameX,
const std::string & nameY,
const std::string & nameZ):
68 auto ret = std::vector<std::vector<size_t>>(1,
input[0]);
75 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op Input Tensor ") +
fNX +
"is not found in model");
78 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op Input Tensor ") +
fNY +
"is not found in model");
81 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op Input Tensor ") +
fNC +
"is not found in model");
90 int dynamicInputs = 0;
116 if (dynamicInputs & 1)
118 if (dynamicInputs & 2)
120 if (dynamicInputs & 4)
127 if (dynamicInputs == 0) {
135 bool broadcastX =
false, broadcastY =
false, broadcastC =
false;
136 if (lengthX >= lengthY && lengthX >= lengthC) {
139 broadcastY = (lengthY != lengthX);
140 broadcastC = (lengthC != lengthX);
141 }
else if (lengthY >= lengthX && lengthY >= lengthC) {
144 broadcastX = (lengthX != lengthY);
145 broadcastC = (lengthC != lengthY);
146 }
else if (lengthC >= lengthX && lengthC >= lengthY) {
149 broadcastX = (lengthX != lengthC);
150 broadcastY = (lengthY != lengthC);
158 std::shared_ptr<void> broadcastedData(
160 std::default_delete<T[]>());
177 std::shared_ptr<void> broadcastedData(
179 std::default_delete<T[]>());
197 std::shared_ptr<void> broadcastedData(
199 std::default_delete<T[]>());
221 std::vector<Dim> shapeDataX;
222 std::vector<Dim> shapeDataY;
238 std::vector<T> dataZ;
239 std::vector<Dim> shapeDataZ;
244 bool isOutputConstantTensor =
true;
245 if (dataX && dataY) {
247 for (
size_t i = 0; i < dataZ.size(); i++)
248 dataZ[i] = (dataC[i]) ? dataX[i] : dataY[i];
251 }
else if (dataX && shapeDataY.size() > 0) {
253 for (
size_t i = 0; i < shapeDataZ.size(); i++) {
254 shapeDataZ[i] = (dataC[i]) ?
Dim{size_t(dataX[i])} : shapeDataY[i];
255 isOutputConstantTensor &= !shapeDataZ[i].isParam;
259 << isOutputConstantTensor << std::endl;
260 }
else if (dataY && shapeDataX.size() > 0) {
262 for (
size_t i = 0; i < shapeDataZ.size(); i++) {
263 shapeDataZ[i] = (dataC[i]) ? shapeDataY[i] :
Dim{size_t(dataY[i])};
264 isOutputConstantTensor &= !shapeDataZ[i].isParam;
268 << isOutputConstantTensor << std::endl;
269 }
else if (shapeDataY.size() > 0 && shapeDataX.size() > 0) {
271 for (
size_t i = 0; i < shapeDataZ.size(); i++) {
272 shapeDataZ[i] = (dataC[i]) ? shapeDataX[i] : shapeDataY[i];
273 isOutputConstantTensor &= !shapeDataZ[i].isParam;
281 if (dataZ.size() > 0)
283 else if (shapeDataZ.size() > 0)
291 << ((dataZ.size() > 0) ?
" (constant)" :
" (shape)") << std::endl;
321 auto IsInputDimParam = [&](
const std::string &p) {
324 if (s.isParam && s.param == p)
return true;
327 for (
size_t i = 0; i <
fDimShapeZ.size(); i++) {
329 if (s.isParam && s.param.find(
"std::max") != std::string::npos) {
364 std::stringstream out;
368 std::string
Generate(std::string opName)
override {
370 opName =
"op_" + opName;
371 std::stringstream out;
383 out <<
SP <<
"if (" << lengthX <<
" != " << lengthY <<
" || "
384 << lengthX <<
" != " << lengthC <<
") {\n";
385 for (
size_t i = 0; i <
fDimShapeZ.size(); i++) {
391 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast A dim " << i <<
" in " << opName <<
"\");\n";
398 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast B dim " << i <<
" in " << opName <<
"\");\n";
405 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast C dim " << i <<
" in " << opName <<
"\");\n";
419 auto buildIdxExpr = [&](
const std::vector<Dim> &dimShape,
420 const std::vector<Dim> &strides,
421 size_t rankZ) -> std::string {
422 if (dimShape.empty() ||
423 std::all_of(dimShape.begin(), dimShape.end(),
424 [](
Dim d) { return d.dim == 1 || d.GetVal() ==
"1"; }))
427 size_t offset = rankZ - dimShape.size();
428 for (
size_t i = 0; i < dimShape.size(); ++i) {
429 if (dimShape[i].dim == 1 || dimShape[i].GetVal() ==
"1")
continue;
430 expr +=
"idx_" + std::to_string(i +
offset);
431 if (strides[i].GetVal() !=
"1")
432 expr +=
" * " + strides[i].GetVal();
435 if (expr.size() >= 3)
436 for (
int j = 0; j < 3; j++) expr.pop_back();
437 return expr.empty() ?
"0" : expr;
450 [](
Dim d) { return d.dim == 1 || d.GetVal() ==
"1"; })) {
453 for (
size_t i = 0; i <
fDimShapeZ.size(); ++i) {
456 for (
int j = 0; j < nloop; j++) out <<
SP;
457 out <<
"for (size_t idx_" << i <<
" = 0; idx_" << i
458 <<
" < " <<
fDimShapeZ[i] <<
"; ++idx_" << i <<
") {\n";
459 idxZ +=
"idx_" + std::to_string(i);
460 if (stridesZ[i].GetVal() !=
"1")
461 idxZ +=
" * " + stridesZ[i].GetVal();
465 if (idxZ.size() >= 3)
466 for (
int j = 0; j < 3; j++) idxZ.pop_back();
470 for (
int j = 0; j < nloop + 1; j++) out <<
SP;
471 out <<
"tensor_" <<
fNZ <<
"[" << idxZ <<
"] = "
472 <<
"tensor_" <<
fNC <<
"[" << idxC <<
"] ? "
473 <<
"tensor_" <<
fNX <<
"[" << idxX <<
"] : "
474 <<
"tensor_" <<
fNY <<
"[" << idxY <<
"];\n";
477 for (
int i = nloop; i > 0; i--) {
478 for (
int j = 0; j < i; j++) out <<
SP;
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
std::vector< size_t > GetTensorShape(const std::string &name) const
std::vector< Dim > GetDimTensorShape(const std::string &name) const
bool IsDynamicTensor(const std::string &name) const
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
bool CheckIfTensorAlreadyExist(std::string tensor_name)
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
bool IsShapeTensor(const std::string &name) const
check if a tensor is a shape tensor
bool IsInitializedTensor(const std::string &name) const
std::vector< Dim > GetDynamicTensorShape(const std::string &name) const
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
void SetNotWritableInitializedTensor(const std::string &tensor_name)
ETensorType GetTensorType(std::string name) const
const std::vector< std::string > & GetInputTensorNames() const
const std::vector< Dim > & GetShapeTensorValues(const std::string &tensor_name) const
bool IsReadyInputTensor(const std::string &name) const
void AddShapeTensor(const std::string &name, const std::vector< Dim > &shapeValues, bool scalar=false)
std::vector< Dim > fDimShapeX
std::vector< size_t > fShapeY
std::vector< size_t > fShapeX
std::vector< size_t > fShapeC
std::string fNBroadcastedX
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
std::string Generate(std::string opName) override
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< Dim > fDimShapeY
ROperator_Where(const std::string &nameC, const std::string &nameX, const std::string &nameY, const std::string &nameZ)
void Initialize(RModel &model) override
std::vector< Dim > fDimShapeZ
std::string fNBroadcastedC
std::string fNBroadcastedY
std::vector< Dim > fDimShapeC
std::vector< size_t > fShapeZ
std::string GenerateInitCode() override
std::vector< std::string_view > fInputTensorNames
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
bool AreSameShape(const std::vector< size_t > &, const std::vector< size_t > &)
std::vector< size_t > MultidirectionalBroadcastShape(std::vector< std::vector< size_t > >)
T * UnidirectionalBroadcast(const T *data, const std::vector< size_t > &shape, const std::vector< size_t > &targetShape)
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::size_t ConvertShapeToLength(const std::vector< size_t > &shape)
std::string ConvertValuesToString(size_t n, const T *data, size_t maxprint=-1)
std::vector< Dim > ConvertShapeToDim(const std::vector< size_t > &shape)
Convert shape from integer format to dynamic one (based on Dim)
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations