Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Reshape.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
2#define TMVA_SOFIE_ROPERATOR_RESHAPE
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <cassert>
9#include <cctype>
10#include <sstream>
11#include <algorithm>
12
13namespace TMVA{
14namespace Experimental{
15namespace SOFIE{
16
18
19
21{
22
23private:
24
// Operator state. The flags and shapes below are filled in by Initialize().
25 bool fVerbose = false; // copied from model.Verbose() in Initialize; enables diagnostic printouts
26 bool fDimInput = false; // NOTE(review): not read or written in the visible part of this listing
27 bool fDynamicShape = false; // set when the target shape is only known at run time (see Initialize)
28 ReshapeOpMode fOpMode = Reshape; // type of Reshape operator
29
30 int fAllowZero = 0; // (for Reshape) when non-zero, a 0 in the target shape is kept as a literal zero dimension
31 int fAxis = 1; // (for Flatten) axis at which the input shape is split into two parts
32
33 std::string fNData; // input data tensor name
34 std::string fNInput2; // reshape or axes tensor name depending on operator
35 std::string fNOutput; // output tensor name
36 std::vector<Dim> fShapeInput; // input shape data
37 std::vector<Dim> fShapeOutput; // output shape data
38 std::vector<Dim> fOutputShapeData; // in case output is a shape tensor we store here the shape value data (can be parametric)
39 std::vector<int64_t> fAttrAxes; // axes attributes (provided for all versions of Squeeze/Unsqueeze)
40 std::vector<int64_t> fShape; // integer values of the shape tensor provided for Reshape
41
42public:
43
44 std::string Name() const {
45 if (fOpMode == Reshape) return "Reshape";
46 if (fOpMode == Flatten) return "Flatten";
47 if (fOpMode == Squeeze) return "Squeeze";
48 if (fOpMode == Unsqueeze) return "Unsqueeze";
49 return "";
50 }
51
/// Constructor for the form where the target shape (Reshape) or the axes
/// (Squeeze/Unsqueeze) may be supplied through a second input tensor.
/// \param opMode      operator variant
/// \param attr_value  integer attribute of the operator
///                    (NOTE(review): its use is on listing lines missing from this extract -
///                    presumably fAllowZero / fAxis; confirm against the full source)
/// \param nameData    name of the input data tensor
/// \param nameInput2  name of the optional second input tensor (may be empty)
/// \param nameOutput  name of the output tensor
/// All tensor names are normalised with UTILITY::Clean_name before being stored.
53 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameInput2, std::string nameOutput)
54 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNInput2(UTILITY::Clean_name(nameInput2)),
55 fNOutput(UTILITY::Clean_name(nameOutput))
56 {
59
   // the second input is registered as an operator input only when it is present
61 if(!fNInput2.empty()){
62 fInputTensorNames.emplace_back(fNInput2);
63 }
65 }
66
67 // for Squeeze/Unsqueeze operators following old ONNX versions (< 10)
68 // In this case the axes are passed as attribute values
/// \param opMode      operator variant
/// \param attrAxes    axes given as operator attributes
/// \param nameData    name of the input data tensor
/// \param nameOutput  name of the output tensor
/// NOTE(review): the end of the initializer list and the constructor body are on listing
/// lines missing from this extract; presumably attrAxes is stored in fAttrAxes - confirm.
69 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
70 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
72 {
76 }
77
78
79 // output shape
/// Compute the output shape for the operator variant selected by fOpMode.
///
/// - Reshape:   starts from target_shape and resolves the special values 0 and -1
///              (a -1 is stored as static_cast<size_t>(-1) in a non-parametric Dim).
/// - Flatten:   splits input_shape at fAxis and returns a rank-2 shape.
/// - Squeeze:   removes dimensions equal to 1, either all of them (no axes given)
///              or only those listed in fAttrAxes.
/// - Unsqueeze: inserts dimensions of size 1 at the positions listed in fAttrAxes.
///
/// \param input_shape   shape of the input data tensor (entries may be parametric)
/// \param target_shape  requested shape; only read in the Reshape branch
/// \return the inferred output shape
/// \throws std::runtime_error for inconsistent shapes, invalid axes or an unknown mode
///
/// NOTE(review): several listing lines inside this function are missing from this
/// extract (e.g. the definitions of input_length, tmp_length, res_shape, l1, l2 and of
/// output_shape in the Squeeze/Unsqueeze branches); comments below only describe what is visible.
80 std::vector<Dim> DoShapeInference(const std::vector<Dim> & input_shape, const std::vector<Dim> & target_shape) {
81 if (fOpMode == Reshape) {
82 // correct the provided shape (here we have the value) for 0 or -1
83 // the target_shape can be a scalar in case of not present shape input tensor
84 std::vector<Dim> output_shape = target_shape;
85 bool hasMinusOne = false;
86 bool hasZero = false;
87 for (size_t i = 0; i < output_shape.size(); i++) {
88 // case for zero values in given shape: in this case we take the corresponding value from input shape
89 if (!output_shape[i].isParam) {
90 if (output_shape[i].dim == 0) {
91 hasZero = true;
   // with fAllowZero set the zero is kept as a literal zero-sized dimension
92 if (fAllowZero)
93 output_shape[i] = Dim{0};
94 else {
95 if (i > 0 && output_shape.size() != input_shape.size())
96 std::cout << "WARNING: TMVA Reshape Op : output shape has zero value at index " << i <<
97 " but input shape has a different rank than output shape" << std::endl;
98 if (i >= input_shape.size())
99 throw std::runtime_error("TMVA Reshape Op : output shape has zero value at index " + std::to_string(i) +
100 " but input shape does not have corresponding index");
   // NOTE(review): listing line 102 is missing here - presumably it copies input_shape[i]
   // into output_shape[i]; confirm against the full source
101 }
103 } else if (output_shape[i].dim == static_cast<size_t>(-1)) {
104 hasMinusOne = true;
105 }
106 }
107 }
108 if (hasZero && hasMinusOne) {
109 throw std::runtime_error("TMVA Reshape Op : zero value in shape is not allowed when there is also a -1 in shape");
110 }
111 // now case of -1 in shape - we can infer the value of -1 from all other values
112 for (size_t i = 0; i < output_shape.size(); i++) {
113 if (output_shape[i] == static_cast<size_t>(-1) && !output_shape[i].isParam) {
114 auto tmp = output_shape;
115 tmp.erase(tmp.begin() + i); // erase -1 value to compute the length of the other dimensions
   // NOTE(review): listing lines 116-117 (definitions of input_length and tmp_length) are
   // missing; from their use below they are strings holding length expressions
118 if (fVerbose)
119 std::cout << "reshape- try simplifying " << ConvertDimShapeToString(input_shape) << " with length "
120 << input_length << " to " << tmp_length << std::endl;
121
   // NOTE(review): listing line 122 (the `if` guarding the next statement) is missing
123 output_shape[i] = Dim{static_cast<size_t>(std::stoi(input_length) / std::stoi(tmp_length))};
   // remaining dimensions multiply to 1: the inferred dimension is the whole input length
124 else if (IsInteger(tmp_length) && std::stoi(tmp_length) == 1) {
125 output_shape[i] = Dim{input_length, static_cast<size_t>(-1)};
126 }
127 else {
128 //we can try simplifying expression if tmp_length is integer and part of input_length
129 // contains tmp_length
130 bool canSimplify = false;
131 std::vector <Dim> reduced_input;
132 if (IsInteger(tmp_length)) {
133
134 // try to tokenize with * the input length
135
136 std::stringstream ss(input_length);
137
138 std::string token;
139
140 // Tokenizing w.r.t. '*'
141 while(getline(ss, token, '*'))
142 {
143 // remove any whitespace
144 token.erase(std::remove_if(token.begin(), token.end(),
145 [](unsigned char x) { return std::isspace(x); }), token.end());
146 if (token != tmp_length) {
147 if (IsInteger(token)) {
   // NOTE(review): the whole input_length string is parsed here, not `token` -
   // confirm this is intended when input_length is a product expression
148 size_t il = static_cast<size_t>(std::stoi(input_length));
149 size_t tl = static_cast<size_t>(std::stoi(tmp_length));
150 if ((il % tl) == 0) {
151 canSimplify = true;
152 reduced_input.push_back(Dim{il / tl});
153 }
154 } else {
   // non-integer factor: kept as a parametric dimension
155 reduced_input.push_back(Dim{token});
156 }
157 } else {
158 // token is equal to tmp_length, can be not considered and is simplified
159 canSimplify = true;
160 }
161 }
162 }
163 if (canSimplify) {
164 // if length contains * we need to add some brackets
   // NOTE(review): listing lines 165 and 169 (definition of res_shape and the statement
   // of the `else` branch) are missing from this extract
166 if (res_shape.find('*') != std::string::npos)
167 output_shape[i] = Dim{std::string("(") + res_shape + ")", static_cast<size_t>(-1)};
168 else
170 }
   // fallback: keep the division as a symbolic expression
171 if (!canSimplify)
172 output_shape[i] = Dim{std::string("(") + input_length + " / (" + tmp_length + "))", static_cast<size_t>(-1)};
173 }
174
175 break; // cannot have more than -1
176 }
177 // throw std::runtime_error(
178 // "TMVA Reshape Op : output shape has multiple negative or zero values");
179 }
180
181 if (fVerbose)
182 std::cout << "Reshape: correct output shape to " << ConvertDimShapeToString(output_shape) << std::endl;
183
   // NOTE(review): listing lines 184 and 186 (the condition guarding this throw and the end
   // of the message) are missing from this extract
185 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertDimShapeToString(input_shape) +
187 }
188 return output_shape;
189
190 } else if (fOpMode == Flatten) {
191 // flatten case
   // a negative axis counts from the end; note that this modifies the member fAxis
192 if (fAxis < 0)
193 fAxis += input_shape.size();
   // s1 = dimensions before fAxis, s2 = dimensions from fAxis onwards
194 auto s1 = std::vector<Dim>(input_shape.begin(), input_shape.begin() + fAxis);
195 auto s2 = std::vector<Dim>(input_shape.begin() + fAxis, input_shape.end());
   // NOTE(review): listing lines 196-197 (definitions of l1 and l2) are missing -
   // presumably the lengths of s1 and s2; confirm against the full source
198 std::vector<Dim> newShape = {Dim{l1}, Dim{l2}};
199 return newShape;
200 } else if (fOpMode == Squeeze) {
201 // squeeze
202 // assume no axis is provided - remove all axes with value equal to 1
   // NOTE(review): listing line 203 (declaration of output_shape) is missing from this extract
204 if (fAttrAxes.empty()) {
205 size_t i = 0;
   // the index advances only when nothing was erased, so consecutive 1s are all removed
206 while (i < output_shape.size()) {
207 if (output_shape[i] == Dim{1}) {
208 output_shape.erase(output_shape.begin() + i);
209 } else {
210 i++;
211 }
212 }
213 } else {
   // work on a copy so that fAttrAxes keeps the original (possibly negative) values
214 auto axes = fAttrAxes;
215 for (size_t i = 0; i < axes.size(); i++) {
216 if (axes[i] < 0)
217 axes[i] += input_shape.size();
   // NOTE(review): no range check on axes[i] is visible before it is used as an index
218 if (!(output_shape[axes[i]] == Dim{1}))
219 throw std::runtime_error("TMVA Squeeze Op : Invalid axis value " + std::to_string(axes[i]) +
221 }
222 // for calling vector::erase we must sort axes in decreasing order to avoid
   // invalidating the positions of the axes that are still to be erased
   // NOTE(review): the comparator is std::greater<int> although axes holds int64_t
223 std::sort(axes.begin(), axes.end(), std::greater<int>());
224 for (auto & axis : axes) {
225 output_shape.erase(output_shape.begin() + axis);
226 }
227 }
228 return output_shape;
229 }
230 else if (fOpMode == Unsqueeze) {
231 // unsqueeze
232 assert(!fAttrAxes.empty());
   // NOTE(review): listing line 233 (declaration of output_shape) is missing from this extract
234 auto &axes = fAttrAxes;
235 // output rank
236 int64_t r = input_shape.size() + axes.size();
   // NOTE(review): axes are inserted in the order given, without sorting - confirm that
   // callers provide them in an order for which successive insertion is correct
237 for (auto &a : axes) {
238 int64_t i = static_cast<int64_t>(a);
239 if (i < -r || i > r - 1)
240 throw std::runtime_error("TMVA Unsqueeze Op - axes input is not in correct range");
241 if (i >= 0)
242 output_shape.insert(output_shape.begin() + i, Dim{1});
243 else
244 // negative axes
245 output_shape.insert(output_shape.end() + i + 1, Dim{1});
246 }
247 return output_shape;
248 }
249 throw std::runtime_error("TMVA Reshape Op : Invalid ReshapeOpMode");
   // not reached: the statement above always throws
250 return {Dim{}};
251 }
252
/// Resolve input and output shapes against the model and decide how the output is produced.
///
/// Visible steps:
///  1. check that the data tensor fNData exists in the model;
///  2. obtain the target shape / axes: from an initialized second input tensor, from a
///     shape tensor, at run time (fDynamicShape), or from the attributes in fAttrAxes;
///  3. classify the output as constant (fIsOutputConstant), as a parametric shape tensor
///     (fIsOutputParamShape) or as a regular tensor.
///
/// \param model  model the operator belongs to
/// \throws std::runtime_error when a required tensor is missing or the inputs are inconsistent
///
/// NOTE(review): many listing lines inside this function are missing from this extract
/// (e.g. where fShapeInput, dptr and shapeData are defined and where the output tensor is
/// registered in the model); comments below only describe what is visible.
253 void Initialize(RModel& model) override {
254
255 fVerbose = model.Verbose();
256 if (fVerbose)
257 std::cout << "initialize reshape op type " << fOpMode << " - for input " << fNData
258 << " to shape given by " << fNInput2 << std::endl;
259
260 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
261 // input must be a graph input, or already initialized intermediate tensor
262 throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
263 }
266 // check if optional tensor exists defining shape or axes
267 if (!fNInput2.empty()) {
   // NOTE(review): listing line 268 is missing - it opens the branch closed by the
   // "is not found in model" `else` further down
269 if (model.IsInitializedTensor(fNInput2)) {
270 // assume input shape is an initialized tensor
   // the tensor data is read as int64_t values (dptr is defined on a missing listing line)
272 auto values = static_cast<int64_t *>(dptr.get());
273 auto vec = model.GetTensorShape(fNInput2);
   // a tensor with an empty shape is treated as holding a single value
274 size_t n = 1;
275 if (vec.size() > 0)
276 n = vec[0]; // size of shape input tensor
277 // copy values in fShape vector or fAttrAxes
278 if (fOpMode == Reshape)
279 fShape = std::vector<int64_t>(values, values + n);
280 else
281 fAttrAxes = std::vector<int64_t>(values, values + n);
282
283 std::vector<Dim> targetShape(fShape.begin(),fShape.end());
285 // set flag to not write tensor in weight file. Its data will be hard-coded in way model is constructed
287 } else if (model.IsShapeTensor(fNInput2)) {
290 if (model.Verbose())
291 std::cout << "Reshape op - get output shape from shape tensor " << fNInput2 << " with value " << ConvertDimShapeToString(shapeData) << std::endl;
292 } else {
293 // we cannot get shape at initialization time but at run-time
294 fDynamicShape = true;
295 // size of shape output is given by size of shape input tensor
296 if (model.IsDynamicTensor(fNInput2)) {
297 throw std::runtime_error("TMVA Reshape Op 2nd input Tensor " + fNInput2 + " cannot have dynamic shape");
298 }
299 auto shapeInput2 = model.GetTensorShape(fNInput2);
   // NOTE(review): shapeInput2[0] is read without a visible check that the shape is non-empty
300 fShapeOutput.resize(shapeInput2[0]);
   // each output dimension becomes a named parameter "s_<output>_<i>", declared by Generate()
301 for (size_t i = 0; i < fShapeOutput.size(); i++) {
302 fShapeOutput[i] = Dim{ std::string("s_") + fNOutput + "_" + std::to_string(i)};
303 }
304 }
305 } else {
306 throw std::runtime_error("TMVA Reshape Op 2nd input Tensor " + fNInput2 + " is not found in model");
307 }
308 } else if (!fAttrAxes.empty()) {
309 // case fNShape is empty and axes are provided as attributes (e.g. for Unsqueeze)
310 fShapeOutput = DoShapeInference(fShapeInput, std::vector<Dim>{});
311 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
   // Flatten and Squeeze do not need a target shape
312 fShapeOutput = DoShapeInference(fShapeInput, std::vector<Dim>{});
313 } else {
314 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
315 }
316 // check if output is constant or not
   // NOTE(review): listing line 317 (the condition opening the constant-output branch) and
   // lines 320-321, 323 and 326 are missing from this extract
318 fIsOutputConstant = true;
319 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNData).get());
322 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Output lengths");
324 if (model.Verbose()) {
325 std::cout << Name() << " : " << fNData << " " << ConvertDimShapeToString(fShapeInput) << " --> " << fNOutput << " (constant) " << ConvertDimShapeToString(fShapeOutput) << " : " <<
327 }
328 }
329 // for input shape tensors we can have it if output shape is size==1 or a scalar
330 else if (model.IsShapeTensor(fNData) && fShapeOutput.size() <=1) {
331 // not sure if we ever end-up here - maybe reshaping from scalar to vector or vice versa
332 fIsOutputParamShape = true;
335 if (model.Verbose()) {
336 std::cout << Name() << " : " << fNData << " " << ConvertDimShapeToString(fShapeInput) << " --> " << fNOutput << " (shape) " << ConvertDimShapeToString(fShapeOutput) << " : " <<
338 }
339 }
340 else {
341 // non-constant case
343 if (model.Verbose())
344 std::cout << Name() << " : " << fNData << " " << ConvertDimShapeToString(fShapeInput) << " --> "<< fNOutput << " " << ConvertDimShapeToString(fShapeOutput) << std::endl;
345 }
346 }
347
/// Emit the C++ inference code for this operator as a string.
///
/// The emitted text always starts with a banner comment naming the operator type and the
/// output shape. For a constant output nothing else is emitted. Otherwise the code either
/// fills a shape tensor element by element, or (regular case) copies the input buffer to
/// the output buffer, preceded by run-time shape set-up and a length check where needed.
///
/// \param opName  operator name, used in the banner and in the emitted error message
/// \return the generated code
///
/// NOTE(review): listing lines 364 and 386-387 (the condition opening the shape-tensor
/// branch and the definitions of lengthOut / lengthIn) are missing from this extract.
348 std::string Generate(std::string opName) override {
349
350
351 std::stringstream out;
352 std::string opType = "Reshape";
353 if (fOpMode == Flatten)
354 opType = "Flatten";
355 else if (fOpMode == Squeeze)
356 opType = "Squeeze";
357 else if (fOpMode == Unsqueeze)
   // NOTE(review): "Unsquueze" is misspelled; it only appears in the emitted banner comment
358 opType = "Unsquueze";
359
360 out << SP << "///--------" << opType << " operator " << opName << " --> " << ConvertDimShapeToString(fShapeOutput) << "\n";
361
362 if (fIsOutputConstant) return out.str(); //no op for constant tensors
363
365 // no code to generate here for param shape output. Tensor output is defined in Session constructor
366 out << "//----------------output is a shape tensor----------\n";
   // one assignment per element; the element count is taken from the first output dimension
367 for (int i = 0; i < static_cast<int>(fShapeOutput[0].dim); i++) {
368 out << SP << "tensor_" << fNOutput << "[" << i << " ] = " << fOutputShapeData[i].GetVal() << ";\n";
369 }
370 return out.str();
371 }
372
373 // in case of dynamic output shape we need to set the shape value from input shape tensor
374 // and take care of the zero values
375 if (fDynamicShape) {
376 for (size_t i = 0; i < fShapeOutput.size(); i++) {
377 // since fNInput2 values are int64_t, should we check if they are negative?
   // declares the run-time dimension variable named in Initialize ("s_<output>_<i>")
378 out << SP << "size_t " << fShapeOutput[i].param << " = " << "tensor_" << fNInput2 << "[" << i << "];\n";
   // NOTE(review): fShapeInput[i] is indexed with the output index - confirm the behaviour
   // when the output rank exceeds the input rank
379 if (!fAllowZero)
380 out << SP << "if (tensor_" << fNInput2 << "[" << i << "] <= 0 ) "
381 << fShapeOutput[i].param << " = " << fShapeInput[i] << ";\n";
382 }
383 }
384
385 // output of reshape is same as input
   // when the two length expressions differ textually, the comparison is emitted as a run-time check
388 if (lengthOut != lengthIn) {
389 // check needs to be done at run-time
390 out << SP << "if (" << lengthOut << "!=" << lengthIn << ")\n";
391 out << SP << SP << "throw std::runtime_error(\"TMVA SOFIE Reshape " << opName << " output length "
392 << lengthOut << " is different than input one " << lengthIn << "\");\n";
393 }
394
395
   // the reshape itself is a plain element copy from the input buffer to the output buffer
396 out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << lengthIn << ", " << "tensor_" << fNOutput
397 << ");\n";
398 return out.str();
399 }
400};
401
402}//SOFIE
403}//Experimental
404}//TMVA
405
406
407#endif //TMVA_SOFIE_ROPERATOR_RESHAPE
#define a(i)
Definition RSha256.hxx:99
#define s1(x)
Definition RSha256.hxx:91
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
const_iterator begin() const
const_iterator end() const
std::vector< size_t > GetTensorShape(const std::string &name) const
Definition RModel.cxx:64
std::vector< Dim > GetDimTensorShape(const std::string &name) const
Definition RModel.cxx:100
bool IsDynamicTensor(const std::string &name) const
Definition RModel.cxx:286
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:301
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:157
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:232
bool IsShapeTensor(const std::string &name) const
check if a tensor is a shape tensor
Definition RModel.cxx:260
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:273
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:366
void SetNotWritableInitializedTensor(const std::string &tensor_name)
Definition RModel.cxx:375
ETensorType GetTensorType(std::string name) const
Definition RModel.cxx:125
const std::vector< Dim > & GetShapeTensorValues(const std::string &tensor_name) const
Definition RModel.cxx:268
void AddShapeTensor(const std::string &name, const std::vector< Dim > &shapeValues, bool scalar=false)
Definition RModel.cxx:242
ROperator_Reshape(ReshapeOpMode opMode, std::vector< int64_t > attrAxes, std::string nameData, std::string nameOutput)
ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameInput2, std::string nameOutput)
std::vector< Dim > DoShapeInference(const std::vector< Dim > &input_shape, const std::vector< Dim > &target_shape)
std::string Generate(std::string opName) override
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:50
bool fIsOutputParamShape
flag to identify if the output represents a parametric shape (can be known at compile time)
Definition ROperator.hxx:48
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:47
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:45
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:51
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
std::string ConvertDimShapeToString(const std::vector< Dim > &shape)
std::size_t ConvertShapeToLength(const std::vector< size_t > &shape)
std::string ConvertValuesToString(size_t n, const T *data, size_t maxprint=-1)
std::vector< size_t > ConvertShapeToInt(const std::vector< Dim > &shape)
Convert shape based on Dim to integer format.
std::string ConvertDimShapeToLength(const std::vector< Dim > &shape)
bool IsInteger(const std::string &s)
create variable transformations