1#ifndef TMVA_SOFIE_ROPERATOR_ScatterND
2#define TMVA_SOFIE_ROPERATOR_ScatterND
13namespace Experimental{
52 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterND Op Input Tensor ") +
fNX +
"is not found in model");
55 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterND Op Input Tensor ") +
fNI +
"is not found in model");
58 throw std::runtime_error(std::string(
"TMVA SOFIE ScatterND Op Input Tensor ") +
fNU +
"is not found in model");
71 if (!(
fShapeI.back().isParam) ) {
72 const size_t k =
fShapeI.back().dim;
75 throw std::invalid_argument(
76 "ScatterND: last dim of indices (" + std::to_string(k) +
77 ") must be <= rank of data (" + std::to_string(
r) +
")");
82 throw std::invalid_argument(
"ScatterND: updates rank mismatch");
85 throw std::runtime_error(
"TMVA SOFIE ScatterND : Index_shape(-1) is not known. This case is not supported");
103 return "//---------------------------------------\n";
106 std::stringstream out;
130 out <<
SP <<
"// Step 1: copy input data to output\n";
131 out <<
SP <<
"std::copy(tensor_" <<
fNX <<
", tensor_" <<
fNX <<
" + " <<
data_length <<
", tensor_" <<
fNY <<
");\n";
134 out <<
SP <<
"// Step 2: data strides (row-major)\n";
136 out <<
SP <<
"size_t " <<
opName <<
"_data_strides[" <<
r <<
"] = {";
137 for (
size_t i = 0; i <
r; ++i)
138 out <<
stridesX[i] << (i + 1 <
r ?
", " :
"");
142 out <<
SP <<
"// Step 3: scatter updates into output\n";
146 out <<
SP <<
SP <<
"int64_t data_offset = 0;\n";
147 for (
size_t dim = 0; dim < k; ++dim) {
148 out <<
SP <<
SP <<
"{\n";
149 out <<
SP <<
SP <<
SP <<
"int64_t coord = tensor_" <<
fNI
150 <<
"[idx * " << k <<
" + " << dim <<
"];\n";
152 out <<
SP <<
SP <<
SP <<
"if (coord < 0) coord += " <<
fShapeX[dim] <<
";\n";
153 out <<
SP <<
SP <<
SP <<
"data_offset += coord * "
154 <<
opName <<
"_data_strides[" << dim <<
"];\n";
155 out <<
SP <<
SP <<
"}\n";
159 out <<
SP <<
SP <<
"for (int64_t s = 0; s < " <<
slice_size <<
"; s++) {\n";
160 out <<
SP <<
SP <<
SP <<
"auto upd = tensor_" <<
fNU
164 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY <<
"[data_offset + s] = upd;\n";
166 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY<<
"[data_offset + s] += upd;\n";
168 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY <<
"[data_offset + s] *= upd;\n";
170 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY<<
"[data_offset + s] = "
171 <<
"std::min(tensor_" <<
fNY <<
"[data_offset + s], upd);\n";
173 out <<
SP <<
SP <<
SP <<
"tensor_" <<
fNY <<
"[data_offset + s] = "
174 <<
"std::max(tensor_" <<
fNY <<
"[data_offset + s], upd);\n";
176 throw std::runtime_error(
177 "TMVA SOFIE ScatterND: unsupported reduction '" +
fReduction +
"'");
180 out <<
SP <<
SP <<
"}\n";
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
std::vector< Dim > GetDimTensorShape(const std::string &name) const
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
bool CheckIfTensorAlreadyExist(std::string tensor_name)
ETensorType GetTensorType(std::string name) const
std::vector< Dim > fShapeI
std::string Generate(std::string opName) override
std::vector< Dim > fShapeX
ROperator_ScatterND(const std::string &nameX, const std::string &nameI, const std::string &nameU, const std::string &nameY, std::string reduction)
void Initialize(RModel &model) override
std::vector< int64_t > fIndices
std::vector< Dim > fShapeY
std::vector< std::string_view > fInputTensorNames
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
const std::string SP
space used to correctly indent the generated C++ code
std::vector< std::string_view > fOutputTensorNames
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
create variable transformations