Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Reshape.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
2#define TMVA_SOFIE_ROPERATOR_RESHAPE
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <cassert>
9#include <sstream>
10
11namespace TMVA{
12namespace Experimental{
13namespace SOFIE{
14
16
17
19{
20
21private:
22
23 bool fVerbose = false; // verbosity flag; Initialize() sets it from model.Verbose(), ShapeInference() prints when it is set
24 ReshapeOpMode fOpMode = Reshape; // type of Reshape operator (Reshape, Flatten, Squeeze or Unsqueeze, see Name())
25
26 int fAllowZero = 0; // (for Reshape) zero in tensor shape makes output shape equal to input tensor shape
// NOTE(review): fAxis is not read by any line visible in this listing (the Flatten branch of
// ShapeInference uses input[0][0] instead) - confirm against the full source.
27 int fAxis = 1; // (for Flatten)
28
29 std::string fNData; // input data tensor name
30 std::string fNShape; // reshape tensor name (may be empty, see Initialize())
31 std::string fNOutput; // output tensor name
32 std::vector<size_t> fShapeInput; // input shape data
33 std::vector<size_t> fShapeOutput; // output shape data
34 std::vector<int64_t> fAttrAxes; // axes attributes (provided for all version of Squeeze/Unsqueeze)
35
36public:
37
38 std::string Name() const {
39 if (fOpMode == Reshape) return "Reshape";
40 if (fOpMode == Flatten) return "Flatten";
41 if (fOpMode == Squeeze) return "Squeeze";
42 if (fOpMode == Unsqueeze) return "Unsqueeze";
43 return "";
44 }
45
// Constructor taking the operator mode, one integer attribute and the tensor names.
// All three tensor names are stored after being passed through UTILITY::Clean_name.
// NOTE(review): this listing omits source lines 51-52, 54 and 58. attr_value is not used on
// any visible line - presumably it is stored in fAllowZero / fAxis on the omitted lines;
// confirm against the full source.
47 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
48 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
49 fNOutput(UTILITY::Clean_name(nameOutput))
50 {
53
// The shape tensor is optional: it is registered as an operator input only when a name was given.
55 if(!fNShape.empty()){
56 fInputTensorNames.emplace_back(fNShape);
57 }
59 }
60
61 // for Squeeze/Unsqueeze operators following old ONNX versions (< 10)
62 // In these cases the axes are passed as attribute values
// Constructor for the case where the axes are given as an attribute vector instead of a tensor.
// fNShape is not initialized here and therefore stays empty.
// NOTE(review): this listing omits source lines 65 and 67. Line 65 has to complete the
// initializer list after the trailing comma below - presumably it initializes fAttrAxes from
// attrAxes; confirm against the full source.
63 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
64 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
66 {
68 }
69
70 // output type is same as input
71 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
72 auto ret = std::vector<ETensorType>(1, input[0]);
73 return ret;
74 }
75
76 // output shape
// Computes the output shape for the operator mode stored in fOpMode.
// input[0] is the shape of the data tensor. input[1], when present, holds the requested
// shape (Reshape) or the axes (Squeeze / Unsqueeze).
// Returns a vector holding one shape (it stays empty if fOpMode matches none of the branches).
// Throws std::runtime_error for the invalid inputs detected in the individual branches.
// NOTE(review): this listing omits source lines 84-85, 87, 98-99, 107, 109 and 140, so some
// statements below are incomplete (e.g. the definition of output_length is not visible).
77 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
78 std::vector<std::vector<size_t>> ret;
79 auto & input_shape = input[0];
80
81 if (fOpMode == Reshape) {
82 if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
83 auto output_shape = input[1]; // the provided shape
86 // (input_length == output_length) is the easy case : (2,3,4) -> (2,12)
// output_length is defined on a line omitted from this listing - presumably the element
// count of output_shape. The cast to long detects a count that wrapped around because
// of a -1 entry stored as size_t.
88 if ((output_length == 0 && fAllowZero == 0) || static_cast<long>(output_length) < 0) {
89 // in this case value 0 or -1 in shape are automatically corrected
90 bool replacementDone = false;
91 for (size_t i = 0; i < output_shape.size(); i++) {
92 if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
// only a single placeholder (0 or -1) entry is accepted
93 if (replacementDone) {
94 throw std::runtime_error("TMVA Reshape Op : output shape has multiple negative or zero values");
95 }
// copy of the requested shape without the placeholder entry; its use is on the omitted
// lines 98-99 - presumably to compute the value replacing output_shape[i]
96 auto tmp = output_shape;
97 tmp.erase(tmp.begin() + i);
100 replacementDone = true;
101 }
102 }
103 if (fVerbose)
104 std::cout << "Reshape: correct output shape from " << ConvertShapeToString(input[1])
105 << " to " << ConvertShapeToString(output_shape) << std::endl;
106 }
// the condition guarding this throw (line 107) and the end of the message (line 109) are
// omitted from this listing
108 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
110 }
111 }
112 ret.push_back(output_shape);
113
114 } else if (fOpMode == Flatten) {
115 // flattening case: the output is the 2D shape {first input dimension, remaining length}
116 size_t inputSize = ConvertShapeToLength(input_shape);
// NOTE(review): the split point is always the first dimension; fAxis is not consulted
// here. A first dimension equal to 0 makes the division below divide by zero.
117 size_t b = input[0][0];
118 std::vector<size_t> newShape = {b, inputSize / b};
119 ret.push_back(newShape);
120
121 } else if (fOpMode == Squeeze) {
122 // squeeze
123 // assume no axis is provided - remove all axes with value equal to 1
124 auto output_shape = input[0];
125 if (input.size() == 1) {
126 size_t i = 0;
127 while (i < output_shape.size()) {
128 if (output_shape[i] == 1 ) {
// after an erase the next element moves to index i, so i is not advanced
129 output_shape.erase(output_shape.begin() + i);
130 }
131 else {
132 i++;
133 }
134 }
135 } else if (input.size() == 2) {
// axes are provided: every listed axis must have size 1 and is removed
// NOTE(review): output_shape shrinks while the loop runs, so with more than one axis the
// later indices refer to already shifted dimensions. axes[i] is also used without a
// bounds check.
136 auto & axes = input[1];
137 for (size_t i = 0; i < axes.size(); i++){
138 if (output_shape[axes[i]] != 1)
139 throw std::runtime_error("TMVA Squeeze Op : Invalid axes : " + ConvertShapeToString(axes) +
141 output_shape.erase(output_shape.begin() + axes[i]);
142 }
143 }
144 ret.push_back(output_shape);
145 }
146
147 else if (fOpMode == Unsqueeze) {
148 // unsqueeze
// the axes are mandatory; this is only checked by an assert, which is compiled out when
// NDEBUG is defined
149 assert(input.size() == 2);
150 auto output_shape = input[0];
151 auto &axes = input[1];
152 // output rank
153 int64_t r = input[0].size() + axes.size();
154 for (auto & a : axes) {
// axes are stored as size_t; converting back to int64_t lets the range check and the
// negative branch below see negative axis values
155 int64_t i = static_cast<int64_t>(a);
156 if ( i < -r || i > r - 1 )
157 throw std::runtime_error("TMVA Unsqueeze Op - axes input is not in correct range");
158 if (i >= 0)
159 output_shape.insert(output_shape.begin() + i, 1);
160 else
161 //negative axes: position is counted from the end (i == -1 appends at the back)
162 output_shape.insert(output_shape.end() + i + 1, 1);
163 }
164 ret.push_back(output_shape);
165 }
166 return ret;
167 }
168
// Prepares the operator for the given model: stores the verbosity flag, checks that the
// input tensor exists, collects the shape / axes description and handles the constant
// output case.
// Throws std::runtime_error when the input or shape tensor is not found, when neither a
// shape tensor nor axes attributes fit the operator mode, or on a length mismatch.
// NOTE(review): this listing omits source lines 176, 179-180, 188, 190, 198, 200, 205, 208,
// 210, 213 and 217. Several conditions and calls (e.g. the definition of dptr, the use of
// descShape and the registration of the output tensor) are therefore not visible here.
169 void Initialize(RModel& model) override {
170
171 fVerbose = model.Verbose();
172 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
173 // input must be a graph input, or already initialized intermediate tensor
174 throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
175 }
177 // check if optional shape tensor exists
178 if (!fNShape.empty()) {
// dptr is defined on an omitted line - presumably the data pointer of the shape tensor
181 auto input_shape = static_cast<int64_t *>(dptr.get());
182 auto vec = model.GetTensorShape(fNShape);
// the shape tensor is expected to be one-dimensional (only checked when asserts are enabled)
183 assert(vec.size() == 1);
184 size_t n = vec[0]; // size of shape input tensor
185
// copy the int64_t values into a size_t vector; a negative entry such as -1 wraps around,
// which matches the static_cast<size_t>(-1) comparison in ShapeInference()
186 std::vector<size_t> descShape(n);
187 std::copy(input_shape, input_shape + n, descShape.begin());
189 // set flag to not write tensor in weight file. Its data will be hard-coded in way model is constructed
// the `if` that this `else` belongs to is on an omitted line
191 } else {
192 throw std::runtime_error("TMVA Reshape Op Shape Tensor " + fNShape + " is not found in model");
193 }
194 } else if (!fAttrAxes.empty()) {
195 // case fNShape is empty and axes are provided as attributes
196 std::vector<size_t> descShape(fAttrAxes.size());
197 std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
// Flatten and Squeeze are accepted without a shape tensor and without axes attributes;
// the body of this branch (line 200) is omitted from this listing
199 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
201 } else {
202 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
203 }
204 // check if output is constant or not
// the condition selecting the constant case (line 205) is omitted from this listing
206 fIsOutputConstant = true;
// NOTE(review): the data is read as int64_t - presumably the omitted condition restricts
// this branch to initialized int64 tensors; confirm against the full source
207 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNData).get());
209 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Output lengths");
211 if (model.Verbose()) {
212 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> " << fNOutput << " (constant) " << ConvertShapeToString(fShapeOutput) << " : " <<
214 }
215 } else {
216 // non-constant case
218 if (model.Verbose())
219 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> "<< fNOutput << " " << ConvertShapeToString(fShapeOutput) << std::endl;
220 }
221 }
222
// Generates the C++ code emitted for this operator.
// Returns an empty string when the output is constant (fIsOutputConstant); otherwise a
// comment banner with the operator name followed by a std::copy of `length` elements from
// tensor_<fNData> to tensor_<fNOutput>.
// Throws std::runtime_error when the output shape does not fit the input shape.
// Fix: the banner for the Unsqueeze mode was misspelled "Unsquueze"; it now reads "Unsqueeze".
// NOTE(review): this listing omits source lines 230-231 and 234, i.e. the definition of
// `length`, the condition guarding the throw and the end of the error message. They are
// left out here exactly as in the listing.
223 std::string Generate(std::string OpName)
224 {
225 if (fIsOutputConstant) return ""; //no op for constant tensors
226
// NOTE(review): OpName is not read again on any visible line; the banner below uses the
// local opName instead
227 OpName = "op_" + OpName;
228
229 // output of reshape is same as input
232 throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
233 ConvertShapeToString(fShapeOutput) + " and input is " +
235 }
236 std::stringstream out;
// name used only in the banner comment of the generated code
237 std::string opName = "Reshape";
238 if (fOpMode == Flatten)
239 opName = "Flatten";
240 else if (fOpMode == Squeeze)
241 opName = "Squeeze";
242 else if (fOpMode == Unsqueeze)
243 opName = "Unsqueeze";
244
245 out << SP << "///--------" << opName << " operator\n" << std::endl;
// all four modes only change the shape, so the generated code is a plain element copy
246 out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << length << ", " << "tensor_" << fNOutput
247 << ");\n";
248 return out.str();
249 }
250};
251
252}//SOFIE
253}//Experimental
254}//TMVA
255
256
257#endif //TMVA_SOFIE_ROPERATOR_RESHAPE
#define b(i)
Definition RSha256.hxx:100
#define a(i)
Definition RSha256.hxx:99
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
const_iterator begin() const
const_iterator end() const
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:94
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:227
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
void AddConstantTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition RModel.cxx:192
bool IsInitializedTensor(const std::string &name) const
Definition RModel.cxx:202
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
std::shared_ptr< void > GetInitializedTensorData(std::string tensor_name)
Definition RModel.cxx:288
void SetNotWritableInitializedTensor(const std::string &tensor_name)
Definition RModel.cxx:297
ROperator_Reshape(ReshapeOpMode opMode, std::vector< int64_t > attrAxes, std::string nameData, std::string nameOutput)
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input)
ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input)
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:46
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:44
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:42
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:47
const Int_t n
Definition legend1.C:16
std::string ConvertValuesToString(size_t n, const T *data)
std::string ConvertShapeToString(std::vector< size_t > shape)
std::size_t ConvertShapeToLength(std::vector< size_t > shape)
create variable transformations