Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseConstant.cxx
Go to the documentation of this file.
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_Constant.hxx"

#include "onnx_proto3.pb.h"

#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9// same function used to parse Constant and ConstantOfShape
10
11ParserFuncSignature ParseConstant = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
12 std::string input_name;
13 auto ninputs = nodeproto.input_size();
14 bool isConstantOfShape = false;
15 // case of ConstantOfShape (Constant has zero inputs)
16 if (ninputs > 0) {
17 input_name = nodeproto.input(0);
18 isConstantOfShape = true;
20 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape op has input tensor" + input_name +
21 " but its type is not yet registered");
22 }
23 }
24
25 if (parser.Verbose()) {
26 std::cout << "\t.... ";
28 std::cout << "ConstantOfShape " << nodeproto.input(0) << " -> ";
29 else
30 std::cout << "Constant --> ";
31 std::cout << nodeproto.output(0) << std::endl;
32 }
33
34 std::unique_ptr<ROperator> op;
35 std::string attr_type;
36
37 std::string output_name = nodeproto.output(0);
39 std::vector<std::size_t> shape; // output shape (use in case of constant operator)
40 // it should be only one attribute (Constant or 1 or 0 COnstant of Shape)
41 if (nodeproto.attribute_size() > 1)
42 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant or ConstantOfShape and attribute size is larger than 1");
43 if (nodeproto.attribute_size() > 0) {
44 std::string attribute_name = nodeproto.attribute(0).name();
45 // tensor input
46 if (attribute_name == "value") {
47 const onnx::TensorProto & t = nodeproto.attribute(0).t();
48 output_type = static_cast<ETensorType>(t.data_type());
49
50 std::size_t length = 1;
51 for (int j = 0; j < t.dims_size(); j++) {
52 shape.push_back(t.dims(j));
53 length *= t.dims(j);
54 }
56 // value tensor should be one-element tensor
57 if (length != 1)
58 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape has invalid tensor size " + std::to_string(length));
59 }
60 switch(output_type) {
61 // to get the tensor values one needs to use the given data types or the raw_data.
62 // it depends how the operator was created. We cannot get size of the raw_data
63 case ETensorType::INT32: {
64 std::vector<int32_t> values(length);
65 if (t.int32_data_size() == int(length)) {
66 for (size_t i = 0; i < length; i++)
67 values[i] = t.int32_data(i);
68 } else {
69 auto raw_data_ptr = reinterpret_cast<int32_t *>(const_cast<char *>(t.raw_data().c_str()));
70 std::memcpy(values.data(), raw_data_ptr, length * sizeof(int32_t));
71 }
72 op.reset(new ROperator_Constant<int32_t>("int32_t", values, shape, input_name, output_name));
73 break;
74 }
75 case ETensorType::INT64: {
76 std::vector<int64_t> values(length);
77 if (t.int64_data_size() == int(length)) {
78 for (size_t i = 0; i < length; i++)
79 values[i] = t.int64_data(i);
80 } else { // cannot get size of raw data : assume is ok
81 auto raw_data_ptr = reinterpret_cast<int64_t *>(const_cast<char *>(t.raw_data().c_str()));
82 std::memcpy(values.data(), raw_data_ptr, length * sizeof(int64_t));
83 }
84 op.reset(new ROperator_Constant<int64_t>("int64_t", values, shape, input_name, output_name));
85 break;
86 }
87 case ETensorType::FLOAT: {
88 std::vector<float> values(length);
89 if (t.float_data_size() == int(length)) {
90 for (size_t i = 0; i < length; i++)
91 values[i] = t.float_data(i);
92 } else {
93 auto raw_data_ptr = reinterpret_cast<float *>(const_cast<char *>(t.raw_data().c_str()));
94 std::memcpy(values.data(), raw_data_ptr, length * sizeof(float));
95 }
96 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
97 break;
98 }
100 std::vector<double> values(length);
101 if (t.double_data_size() == int(length)) {
102 for (size_t i = 0; i < length; i++)
103 values[i] = t.double_data(i);
104 } else {
105 auto raw_data_ptr = reinterpret_cast<double *>(const_cast<char *>(t.raw_data().c_str()));
106 std::memcpy(values.data(), raw_data_ptr, length * sizeof(double));
107 }
108 op.reset(new ROperator_Constant<double>("double",values, shape, input_name, output_name));
109 break;
110 }
111 case ETensorType::BOOL: {
112 std::vector<bool> values(length);
113 auto raw_data_ptr = reinterpret_cast<bool *>(const_cast<char *>(t.raw_data().c_str()));
114 // cannot use values.data() for vector of bools
115 std::copy(raw_data_ptr, raw_data_ptr + length, values.begin());
116 //std::memcpy(values.data(), raw_data_ptr, length * sizeof(float));
117 op.reset(new ROperator_Constant<bool>("bool",values, shape, input_name, output_name));
118 break;
119 }
120 default:
121 throw std::runtime_error("Data type in Constant op attribute " + ConvertTypeToString(output_type) +
122 " is not supported!\n");
123 }
124 }
125 else {
126 // neither constant nor ConstantOfShape
127 if (!isConstantOfShape) {
128 // case of ConstantOfShape
129 if (attribute_name == "value_float") {
130 std::vector<float> values(1);
131 values[0] = nodeproto.attribute(0).f();
132 shape.push_back(1);
133 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
134 }
135 else if (attribute_name == "value_floats") {
136 auto values = std::vector<float>({nodeproto.attribute(0).floats().begin(), nodeproto.attribute(0).floats().end()});
137 shape.push_back(values.size());
138 op.reset(new ROperator_Constant<float>("float",values, shape, input_name, output_name));
139 }
140 else if (attribute_name == "value_int") {
141 std::vector<int64_t> values(1);
142 values[0] = nodeproto.attribute(0).i();
143 shape.push_back(1);
144 op.reset(new ROperator_Constant<int64_t>("int64_t",values, shape, input_name, output_name));
145 }
146 else if (attribute_name == "value_ints") {
147 auto values = std::vector<int64_t>({nodeproto.attribute(0).ints().begin(), nodeproto.attribute(0).ints().end()});
148 shape.push_back(values.size());
149 op.reset(new ROperator_Constant<int64_t>("int64_t",values, shape, input_name, output_name));
150 } else {
151 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant op: not yet supporting attribute " + attribute_name);
152 }
153 } else {
154 throw std::runtime_error("TMVA::SOFIE ONNX Parser ConstantOfShape op: parsed invalid attribute " + attribute_name);
155 }
156 }
157
158 // case when there is no attribute
159 } else {
160 // case of Constant of Shape : if attribute is not there use by default float type with zero values
161 if (isConstantOfShape) {
162 std::vector<float> values(1);
163 std::vector<size_t> constantShape(1,1);
165 } else {
166 throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant has no attribute");
167 }
168 }
169
170 if (!parser.IsRegisteredTensorType(output_name)) {
172 }
173
174 if (parser.Verbose())
175 std::cout << "\t ParseConstant: operator created\n";
176
177 return op;
178};
179
180} // namespace SOFIE
181} // namespace Experimental
182} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
const_iterator end() const
void RegisterTensorType(const std::string &, ETensorType)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseConstant
create variable transformations