Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Gather.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_GATHER
2#define TMVA_SOFIE_ROPERATOR_GATHER
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9#include <stdexcept>
10#include <string>
11
12namespace TMVA{
13namespace Experimental{
14namespace SOFIE{
15
17{
18private:
19
20 int64_t fAttrAxis = 0;
21
22 std::string fNX;
23 std::string fNIndices;
24 std::string fNY;
25
26 std::vector<size_t> fShapeX;
27 std::vector<size_t> fShapeIndices;
28 std::vector<size_t> fShapeY;
29
30 std::vector<int64_t> fIndices; // indices vector in case they are known at initialization
31
32 std::string fType;
33
34public:
36 ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY):
37 fAttrAxis(attrAxis), fNX(UTILITY::Clean_name(nameX)), fNIndices(UTILITY::Clean_name(nameIndices)), fNY(UTILITY::Clean_name(nameY)) {
40 }
41
42 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
43 return input;
44 }
45
46 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
47 auto ret = input;
48 return ret;
49 }
50
51 void Initialize(RModel& model) override {
52 if (!model.CheckIfTensorAlreadyExist(fNX)) {
53 throw std::runtime_error("TMVA SOFIE Gather Op Input Tensor " + fNX + " is not found in model");
54 }
55 fShapeX = model.GetTensorShape(fNX);
57 size_t q = fShapeIndices.size();
58 // Axis in range [0, r) where r=rank(X)
59 size_t r = fShapeX.size();
60 // Set the axis
61 if (fAttrAxis < 0) {
62 fAttrAxis = fAttrAxis + int64_t(r);
63 }
64 // empty fShapeIndices is a scalar value for the indices
66
67 // case indices tensor is initialized
68 if (model.IsInitializedTensor(fNIndices)) {
69 int64_t* indicesData = static_cast<int64_t*>(model.GetInitializedTensorData(fNIndices).get());
70 //flag index tensor as not writable (not sure this is needed since index tensor might be used in generated code)
72 // update indices data in case of negative dim values
73 for (size_t i = 0; i < indicesLength; i++) {
74 if (indicesData[i] < 0) {
76 }
77 }
78 // Save in a vector gather Indices of size q
79 fIndices = std::vector<int64_t>(indicesData, indicesData + indicesLength);
80 }
81 // Output shape
82 if (model.Verbose())
83 std::cout << "Gather: q and r " << q << " " << r << " shape indices " << ConvertShapeToString(fShapeIndices) << std::endl;
84
85 if (fShapeY.empty()) {
86 fShapeY.resize(q + r - 1);
87 if (fAttrAxis > 0) {
88 // Copy shape of X[0, ..., axis) to Shape of Y[0, ..., axis)
89 std::copy(fShapeX.begin(), fShapeX.begin() + fAttrAxis, fShapeY.begin());
90 }
91 // Set shape of Y[axis, ..., axis + q)
92 for (size_t i = 0; i < q; i++) {
94 }
95 // Copy shape of X[axis + 1, ..., axis + r) to shape of Y[axis + q, ... q + r - 1)
96 std::copy(fShapeX.begin() + fAttrAxis + 1, fShapeX.end(), fShapeY.begin() + fAttrAxis + q);
97 }
98 // case input is known (type is an integer) and input indices is a scalar (or vector of size 1)
99 if (model.IsInitializedTensor(fNX) && q <= 1 && r == 1 && fIndices.size() > 0) {
100 if (model.GetTensorType(fNX) == ETensorType::INT64) {
101 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNX).get());
102 // if q <=1 and r = 1 output length = 1 (it is a scalar)
103 std::vector<int64_t> outputData(ConvertShapeToLength(fShapeY));
105 model.AddConstantTensor(fNY, fShapeY, outputData.data());
106 if (model.Verbose())
107 std::cout << "Gather: " << fNX << " " << ConvertShapeToString(fShapeX) << " -> " << fNY << " with shape " << ConvertShapeToString(fShapeY)
108 << " and values " << ConvertValuesToString(outputData) << " (constant) " << std::endl;
109 fIsOutputConstant = true;
110 }
111 }
112 if (!fIsOutputConstant) {
113 // Add output tensor
116 if (model.Verbose())
117 std::cout << "Gather: " << fNX << " " << ConvertShapeToString(fShapeX) << " -> " << fNY << " with shape " << ConvertShapeToString(fShapeY)
118 << std::endl;
119 }
120 }
121
122 std::string Generate(std::string OpName) override {
123 if (fIsOutputConstant) {
124 // no code to generate here for constant output. Tensor output is defined in Session constructor
125 return "//---------------------------------------\n";
126 }
127 OpName = "op_" + OpName;
128 std::stringstream out;
129 out << "//--------- Gather operator \n";
130 // The shape of the output is q + r - 1
131 size_t r = fShapeX.size();
132 // Indices of shape q
133 size_t q = fShapeIndices.size();
134 // Strides
135 std::vector<size_t> stridesX = UTILITY::ComputeStrideFromShape(fShapeX);
136 std::vector<size_t> stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
138
139 // case fIndices is not known we need to correct for negative axis indices at run-time
140 if (fIndices.empty()) {
142 out << SP << "// correct in case of negative gather indices\n";
143 out << SP << "for (size_t i = 0; i < " << indicesLength << "; i++){\n";
144 out << SP << SP << "if (tensor_" << fNIndices << "[i] < 0)\n";
145 out << SP << SP << SP << "tensor_" << fNIndices << "[i] += " << fShapeX[fAttrAxis] << ";\n";
146 out << SP << "}\n";
147 }
148
149
150 // Fill the output Y[j_0, j_1, ..., j_{axis - 1}, i_0, i_1, ..., i_{q - 1}, j_{axis + 1}, ..., j_{r - 1}]
151 // [0 ... axis) [axis ... axis + q) [axis + q ... q + r - 1)
152 // iterate in [0 ... axis) [0 ... q) [axis ... r - 1)
153 // for j_0, j_1, ..., j_{axis-1}
154 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
155 std::string index = "j_" + std::to_string(j);
156 out << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[j] << "; " << index << "++) {\n";
157 }
158 // for i_0, i_1, ..., i_{q - 1}
159 if (q == 0)
160 out << SP << SP << "{\n"; // add a scope for local variables
161 for (size_t i = 0; i < q; i++) {
162 std::string index = "i_" + std::to_string(i);
163 out << SP << SP << "for (size_t " << index << " = " << 0 << "; " << index << " < " << fShapeIndices[i] << "; " << index << "++) {\n";
164 }
165 // for j_axis, j_{axis + 1}, ..., j_{r - 1}
166 for (size_t j = fAttrAxis; j + 1 < r; j++) {
167 std::string index = "j_" + std::to_string(j);
168 out << SP << SP << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[q + j] << "; " << index << "++) {\n";
169 }
170
171 out << SP << SP << SP << "size_t y_index = 0;\n";
172 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
173 out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[j] << ";\n";
174 }
175 for (size_t i = 0; i < q; i++) {
176 out << SP << SP << SP << "y_index += i_" + std::to_string(i) + " * " << stridesY[fAttrAxis + i] << ";\n";
177 }
178 for (size_t j = fAttrAxis; j + 1 < r; j++) {
179 out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[q + j] << ";\n";
180 }
181 // Indices
182 out << SP << SP << SP << "size_t i_index = 0;\n";
183 for (size_t i = 0; i < q; i++) {
184 out << SP << SP << SP << "i_index += i_" + std::to_string(i) + " * " << stridesIndices[i] << ";\n";
185 }
186 // K
187 out << SP << SP << SP << "size_t k = static_cast<size_t>(" << "tensor_" << fNIndices << "[i_index]" << ");\n";
188 // Input
189 out << SP << SP << SP << "size_t x_index = k * " << stridesX[fAttrAxis] << ";\n";
190 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
191 out << SP << SP << SP << "x_index += j_" + std::to_string(j) + " * " << stridesX[j] << ";\n";
192 }
193 for (size_t j = fAttrAxis + 1; j < r; j++) {
194 out << SP << SP << SP << "x_index += j_" + std::to_string(j - 1) + " * " << stridesX[j] << ";\n";
195 }
196 out << SP << SP << SP << "tensor_" << fNY << "[y_index] = tensor_" << fNX << "[x_index];\n";
197
198 // end loops j_k, j_{k + 1}, ..., j_{r - 2}
199 for (size_t j = fAttrAxis; j + 1 < r; j++) {
200 out << SP << SP << SP << "}\n";
201 }
202 // end loops i_0, i_1, ..., i_{q - 1}
203 if (q == 0)
204 out << SP << SP << "}\n"; // end of scope for q = 0
205 for (size_t i = 0; i < q; i++) {
206 out << SP << SP << "}\n";
207 }
208 // end loops j_0, j_1, ..., j_{axis - 1}
209 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
210 out << SP << "}\n";
211 }
212
213 return out.str();
214 }
215
216};
217
218}//SOFIE
219}//Experimental
220}//TMVA
221
222
223#endif //TMVA_SOFIE_ROPERATOR_RELU
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
float * q
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:94
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:227
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:192
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:202
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:288
void SetNotWritableInitializedTensor(const std::string &tensor_name)
Definition RModel.cxx:297
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
std::string Generate(std::string OpName) override
ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:46
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:44
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:42
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:47
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertValuesToString(size_t n, const T *data)
std::string ConvertShapeToString(std::vector< size_t > shape)
std::string ConvertTypeToString(ETensorType type)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations