Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Gather.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_GATHER
2#define TMVA_SOFIE_ROPERATOR_GATHER
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9#include <stdexcept>
10#include <string>
11
12namespace TMVA{
13namespace Experimental{
14namespace SOFIE{
15
16template <typename T>
18{
19private:
20
21 int64_t fAttrAxis = 0;
22
23 std::string fNX;
24 std::string fNIndices;
25 std::string fNY;
26
27 std::vector<size_t> fShapeX;
28 std::vector<size_t> fShapeIndices;
29 std::vector<size_t> fShapeY;
30
31 std::vector<int64_t> fIndices;
32
33 std::string fType;
34
35public:
37 ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY):
38 fAttrAxis(attrAxis), fNX(UTILITY::Clean_name(nameX)), fNIndices(UTILITY::Clean_name(nameIndices)), fNY(UTILITY::Clean_name(nameY)) {
39 }
40
41 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
42 return input;
43 }
44
/// Shape inference: hands the input shapes back unchanged.
///
/// The output shape of the operator is not derived here; it is computed in
/// Initialize from the shapes of X and of the indices tensor.
///
/// \param input shapes of the operator inputs
/// \return the same list of shapes that was passed in
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) {
   return input;
}
49
50 void Initialize(RModel& model) {
51 if (!model.CheckIfTensorAlreadyExist(fNX)) {
52 throw std::runtime_error("TMVA SOFIE Gather Op Input Tensor " + fNX + " is not found in model");
53 }
54 fShapeX = model.GetTensorShape(fNX);
55 if (!model.IsInitializedTensor(fNIndices)) {
56 throw
57 std::runtime_error("TMVA::SOFIE - Tensor " + fNIndices + " is not initialized.");
58 }
59 int64_t* indicesData = static_cast<int64_t*>(model.GetInitializedTensorData(fNIndices).get());
61 size_t q = fShapeIndices.size();
62 // Axis in range [0, r) where r=rank(X)
63 size_t r = fShapeX.size();
64 // Set the axis
65 if (fAttrAxis < 0) {
66 fAttrAxis = fAttrAxis + int64_t(r);
67 }
68 // Indices of size q
69 // empty fShapeIndices is a scalar value for the indices
71 fIndices = std::vector<int64_t>(indicesData, indicesData + indicesLength);
72 for (size_t i = 0; i < indicesLength; i++) {
73 if (fIndices[i] < 0) {
75 }
76 }
77 // Output shape
78 if (fShapeY.empty()) {
79 fShapeY.resize(q + r - 1);
80 if (fAttrAxis > 0) {
81 // Copy shape of X[0, ..., axis) to Shape of Y[0, ..., axis)
82 std::copy(fShapeX.begin(), fShapeX.begin() + fAttrAxis, fShapeY.begin());
83 }
84 // Set shape of Y[axis, ..., axis + q)
85 for (size_t i = 0; i < q; i++) {
87 }
88 // Copy shape of X[axis + 1, ..., axis + r) to shape of Y[axis + q, ... q + r - 1)
89 std::copy(fShapeX.begin() + fAttrAxis + 1, fShapeX.end(), fShapeY.begin() + fAttrAxis + q);
90 }
91 // Add output tensor
94 }
95
96 std::string Generate(std::string OpName) {
97 OpName = "op_" + OpName;
98 std::stringstream out;
99 // The shape of the output is q + r - 1
100 size_t r = fShapeX.size();
101 // Indices of shape q
102 size_t q = fShapeIndices.size();
103 // Strides
104 std::vector<size_t> stridesX = UTILITY::ComputeStrideFromShape(fShapeX);
105 std::vector<size_t> stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
107 // Indices vector
108 out << SP << "std::vector<int64_t> " << OpName << "_indices = {";
110 for (size_t i = 0; i < indicesLength; i++) {
111 out << fIndices[i] << (i + 1 < indicesLength? ", " : "};\n");
112 }
113 // Fill the output Y[j_0, j_1, ..., j_{axis - 1}, i_0, i_1, ..., i_{q - 1}, j_{axis + 1}, ..., j_{r - 1}]
114 // [0 ... axis) [axis ... axis + q) [axis + q ... q + r - 1)
115 // iterate in [0 ... axis) [0 ... q) [axis ... r - 1)
116 // for j_0, j_1, ..., j_{axis-1}
117 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
118 std::string index = "j_" + std::to_string(j);
119 out << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[j] << "; " << index << "++) {\n";
120 }
121 // for i_0, i_1, ..., i_{q - 1}
122 for (size_t i = 0; i < q; i++) {
123 std::string index = "i_" + std::to_string(i);
124 out << SP << SP << "for (size_t " << index << " = " << 0 << "; " << index << " < " << fShapeIndices[i] << "; " << index << "++) {\n";
125 }
126 // for j_axis, j_{axis + 1}, ..., j_{r - 1}
127 for (size_t j = fAttrAxis; j + 1 < r; j++) {
128 std::string index = "j_" + std::to_string(j);
129 out << SP << SP << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[q + j] << "; " << index << "++) {\n";
130 }
131
132 out << SP << SP << SP << "size_t y_index = 0;\n";
133 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
134 out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[j] << ";\n";
135 }
136 for (size_t i = 0; i < q; i++) {
137 out << SP << SP << SP << "y_index += i_" + std::to_string(i) + " * " << stridesY[fAttrAxis + i] << ";\n";
138 }
139 for (size_t j = fAttrAxis; j + 1 < r; j++) {
140 out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[q + j] << ";\n";
141 }
142 // Indices
143 out << SP << SP << SP << "size_t i_index = 0;\n";
144 for (size_t i = 0; i < q; i++) {
145 out << SP << SP << SP << "i_index += i_" + std::to_string(i) + " * " << stridesIndices[i] << ";\n";
146 }
147 // K
148 out << SP << SP << SP << "size_t k = static_cast<size_t>(" << OpName << "_indices[i_index]" << ");\n";
149 // Input
150 out << SP << SP << SP << "size_t x_index = k * " << stridesX[fAttrAxis] << ";\n";
151 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
152 out << SP << SP << SP << "x_index += j_" + std::to_string(j) + " * " << stridesX[j] << ";\n";
153 }
154 for (size_t j = fAttrAxis + 1; j < r; j++) {
155 out << SP << SP << SP << "x_index += j_" + std::to_string(j - 1) + " * " << stridesX[j] << ";\n";
156 }
157 out << SP << SP << SP << "tensor_" << fNY << "[y_index] = tensor_" << fNX << "[x_index];\n";
158
159 // end loops j_k, j_{k + 1}, ..., j_{r - 2}
160 for (size_t j = fAttrAxis; j + 1 < r; j++) {
161 out << SP << SP << SP << "}\n";
162 }
163 // end loops i_0, i_1, ..., i_{q - 1}
164 for (size_t i = 0; i < q; i++) {
165 out << SP << SP << "}\n";
166 }
167 // end loops j_0, j_1, ..., j_{axis - 1}
168 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
169 out << SP << "}\n";
170 }
171
172 return out.str();
173 }
174
175};
176
177}//SOFIE
178}//Experimental
179}//TMVA
180
181
182#endif //TMVA_SOFIE_ROPERATOR_GATHER
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
float * q
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:91
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:187
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:116
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:172
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:248
std::string Generate(std::string OpName)
ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:41
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertTypeToString(ETensorType type)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations