1#ifndef TMVA_SOFIE_ROperator_Where
2#define TMVA_SOFIE_ROperator_Where
51 const std::string &nameX,
52 const std::string &nameY,
53 const std::string &nameZ)
82 if (!model.CheckIfTensorAlreadyExist(
fNC))
83 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op: condition tensor ") +
fNC +
" not found in model");
84 if (!model.CheckIfTensorAlreadyExist(
fNX))
85 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op: X tensor ") +
fNX +
" not found in model");
86 if (!model.CheckIfTensorAlreadyExist(
fNY))
87 throw std::runtime_error(std::string(
"TMVA SOFIE Where Op: Y tensor ") +
fNY +
" not found in model");
90 if (model.IsReadyInputTensor(
fNC))
96 int dynamicInputs = 0;
98 if (model.IsDynamicTensor(
fNC)) {
105 if (model.IsDynamicTensor(
fNX)) {
112 if (model.IsDynamicTensor(
fNY)) {
120 if (model.Verbose()) {
121 if (dynamicInputs & 1)
123 if (dynamicInputs & 2)
125 if (dynamicInputs & 4)
132 if (dynamicInputs == 0) {
143 bool allConstant = model.IsInitializedTensor(
fNC) &&
144 model.IsInitializedTensor(
fNX) &&
145 model.IsInitializedTensor(
fNY);
151 auto broadcastIfNeeded = [&](
const std::string &
name,
152 const std::vector<size_t> &shape,
154 const std::string &prefix) {
156 bcName = prefix +
name +
"to" +
fNZ;
157 auto data = model.GetInitializedTensorData(
name);
158 std::shared_ptr<void> bcData(
160 std::default_delete<T[]>());
161 model.AddConstantTensor(bcName, model.GetTensorType(
name),
fShapeZ, bcData);
173 auto dataC =
static_cast<bool *
>(model.GetInitializedTensorData(nameC).get());
174 auto dataX =
static_cast<T *
> (model.GetInitializedTensorData(nameX).get());
175 auto dataY =
static_cast<T *
> (model.GetInitializedTensorData(nameY).get());
178 std::vector<T> dataZ(
len);
179 for (
size_t i = 0; i <
len; ++i)
180 dataZ[i] = dataC[i] ? dataX[i] : dataY[i];
182 model.AddConstantTensor<T>(
fNZ,
fShapeZ, dataZ.data());
183 model.SetNotWritableInitializedTensor(nameC);
184 model.SetNotWritableInitializedTensor(nameX);
185 model.SetNotWritableInitializedTensor(nameY);
198 model.AddIntermediateTensor(
fNZ, model.GetTensorType(
fNX),
fShapeZ);
221 auto IsInputDimParam = [&](
const std::string &p) {
222 for (
auto &
input : model.GetInputTensorNames())
223 for (
auto &s : model.GetDimTensorShape(
input))
224 if (s.isParam && s.param == p)
return true;
227 for (
size_t i = 0; i <
fDimShapeZ.size(); i++) {
229 if (s.isParam && s.param.find(
"std::max") != std::string::npos) {
252 std::stringstream out;
260 opName =
"op_" + opName;
263 throw std::runtime_error(
"TMVA SOFIE Where Op called to Generate without being initialized first");
266 std::stringstream out;
276 out <<
SP <<
"if (" << lengthX <<
" != " << lengthY <<
" || "
277 << lengthX <<
" != " << lengthC <<
") {\n";
278 for (
size_t i = 0; i <
fDimShapeZ.size(); i++) {
284 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast X dim " << i <<
" in " << opName <<
"\");\n";
291 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast Y dim " << i <<
" in " << opName <<
"\");\n";
298 <<
"throw std::runtime_error(\"SOFIE Where: cannot broadcast C dim " << i <<
" in " << opName <<
"\");\n";
314 auto buildIdxExpr = [&](
const std::vector<Dim> &dimShape,
315 const std::vector<Dim> &strides,
316 size_t rankZ) -> std::string {
317 if (dimShape.empty() ||
318 std::all_of(dimShape.begin(), dimShape.end(),
319 [](
Dim d) { return d.dim == 1 || d.GetVal() ==
"1"; }))
322 size_t offset = rankZ - dimShape.size();
323 for (
size_t i = 0; i < dimShape.size(); ++i) {
324 if (dimShape[i].dim == 1 || dimShape[i].GetVal() ==
"1")
continue;
325 expr +=
"idx_" + std::to_string(i +
offset);
326 if (strides[i].GetVal() !=
"1")
327 expr +=
" * " + strides[i].GetVal();
330 if (expr.size() >= 3)
331 for (
int j = 0; j < 3; j++) expr.pop_back();
332 return expr.empty() ?
"0" : expr;
344 [](
Dim d) { return d.dim == 1 || d.GetVal() ==
"1"; })) {
347 for (
size_t i = 0; i <
fDimShapeZ.size(); ++i) {
350 for (
int j = 0; j < nloop; j++) out <<
SP;
351 out <<
"for (size_t idx_" << i <<
" = 0; idx_" << i
352 <<
" < " <<
fDimShapeZ[i] <<
"; ++idx_" << i <<
") {\n";
353 idxZ +=
"idx_" + std::to_string(i);
354 if (stridesZ[i].GetVal() !=
"1")
355 idxZ +=
" * " + stridesZ[i].GetVal();
359 if (idxZ.size() >= 3)
360 for (
int j = 0; j < 3; j++) idxZ.pop_back();
364 for (
int j = 0; j < nloop + 1; j++) out <<
SP;
365 out <<
"tensor_" <<
fNZ <<
"[" << idxZ <<
"] = "
366 <<
"tensor_" <<
fNC <<
"[" << idxC <<
"] ? "
367 <<
"tensor_" <<
fNX <<
"[" << idxX <<
"] : "
368 <<
"tensor_" <<
fNY <<
"[" << idxY <<
"];\n";
371 for (
int i = nloop; i > 0; i--) {
372 for (
int j = 0; j < i; j++) out <<
SP;
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
std::vector< Dim > fDimShapeX
std::vector< size_t > fShapeY
std::vector< size_t > fShapeX
std::vector< size_t > fShapeC
std::string fNBroadcastedX
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
std::string Generate(std::string opName) override
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< Dim > fDimShapeY
ROperator_Where(const std::string &nameC, const std::string &nameX, const std::string &nameY, const std::string &nameZ)
void Initialize(RModel &model) override
std::vector< Dim > fDimShapeZ
std::string fNBroadcastedC
std::string fNBroadcastedY
std::vector< Dim > fDimShapeC
std::vector< size_t > fShapeZ
std::string GenerateInitCode() override
std::vector< std::string_view > fInputTensorNames
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
std::vector< size_t > MultidirectionalBroadcastShape(std::vector< std::vector< size_t > >)
T * UnidirectionalBroadcast(const T *data, const std::vector< size_t > &shape, const std::vector< size_t > &targetShape)
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::size_t ConvertShapeToLength(const std::vector< size_t > &shape)
std::string ConvertValuesToString(size_t n, const T *data, size_t maxprint=-1)
std::vector< Dim > ConvertShapeToDim(const std::vector< size_t > &shape)
Convert shape from integer format to dynamic one (based on Dim).
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
std::string ConvertShapeToString(const std::vector< size_t > &shape)
create variable transformations