Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseBasicBinary.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
/// \brief Build the SOFIE operator for a basic binary ONNX node (operation chosen by template parameter Op).
///
/// Both node inputs (nodeproto.input(0) and nodeproto.input(1)) must already have a tensor type
/// known to the parser, and the two types must be identical. The operator is then selected by
/// switching on that common input type. The node's first output name (nodeproto.output(0)) is
/// used as the output tensor name.
///
/// \param parser    ONNX parser whose tensor-type registry is queried for the inputs/output.
/// \param nodeproto ONNX node describing the binary operation.
/// \return The operator created in the type switch below.
/// \throws std::runtime_error if an input's type is not registered, if the two input types differ,
///         or if the common input type is not handled by the switch (default branch).
9template <EBasicBinaryOperator Op>
10std::unique_ptr<ROperator> ParseBasicBinary(RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto)
11{
13
// Validate both inputs. The type of input 0 becomes the reference type (input_type); input 1 must match it.
// NOTE(review): the registration check and the type lookups are on lines not visible in this listing.
14 for (int i = 0; i < 2; ++i) {
15 auto input_name = nodeproto.input(i);
17 // according to ONNX both inputs have same type
18 if (i == 0)
20 else {
22 if (input_type2 != input_type) {
23 throw
24 std::runtime_error("TMVA::SOFIE ONNX parser Binary op has input tensors of different types: " +
26 " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type));
27 }
28 }
29 } else {
// Reached when an input has no registered tensor type (see the error message).
30 throw std::runtime_error("TMVA::SOFIE ONNX Parser Binary op has input tensor " + input_name +
31 " but its type is not yet registered");
32 }
33 }
34
35 std::unique_ptr<ROperator> op;
36 std::string output_name = nodeproto.output(0);
37
// Dispatch on the common input type to pick the concrete operator.
// NOTE(review): the case labels and operator constructions are not visible in this listing;
// only their `break` statements remain. Unsupported types fall through to `default` and throw.
38 switch (input_type) {
41 break;
44 break;
47 break;
50 break;
51 default:
52 throw std::runtime_error("TMVA::SOFIE - Unsupported - Binary Operator does not yet support input type " +
53 std::to_string(static_cast<int>(input_type)));
54 }
55
56 // Infer the output type
// NOTE(review): the body of this step is not visible here. In ParseMod the analogous step is
// guarded by `!parser.IsRegisteredTensorType(output_name)`; presumably the same applies here — TODO confirm.
59 }
60
61 return op;
// NOTE(review): the `;` after the closing brace below is a redundant empty declaration
// (harmless, but -Wextra-semi / -Wpedantic may warn).
62};
63
64
65// Parse Add
69
70// Parse Sub
74
75// Parse Mul
79
80// Parse Div
84
85// Parse Pow
89
90// Mod (and fmod) is a special case of BasicBinary
91
/// \brief Parser for the ONNX Mod node. It works like a basic binary op but also honours the `fmod` attribute.
///
/// Both node inputs must have a tensor type known to the parser, and the two types must be identical.
/// The operator is then selected by switching on that common input type. For some types the value of
/// `fmod` picks between two operator variants.
///
/// \param parser    ONNX parser whose tensor-type registry is queried for the inputs/output.
/// \param nodeproto ONNX node describing the Mod operation.
/// \return The operator created in the type switch below.
/// \throws std::runtime_error if an input's type is not registered, if the two input types differ,
///         or if the common input type is not handled by the switch (default branch).
///
/// NOTE(review): the input-validation loop below duplicates the one in ParseBasicBinary; a shared
/// helper returning the common input type would remove the duplication.
92ParserFuncSignature ParseMod = [] (RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
93
// Validate both inputs. The type of input 0 becomes the reference type (input_type); input 1 must match it.
// NOTE(review): the registration check and the type lookups are on lines not visible in this listing.
95 for (int i = 0; i < 2; ++i) {
96 auto input_name = nodeproto.input(i);
98 // according to ONNX both inputs have same type
99 if (i == 0)
101 else {
103 if (input_type2 != input_type) {
104 throw
105 std::runtime_error("TMVA::SOFIE ONNX parser Binary op has input tensors of different types: " +
107 " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type));
108 }
109 }
110 } else {
// Reached when an input has no registered tensor type (see the error message).
111 throw std::runtime_error("TMVA::SOFIE ONNX Parser Binary op has input tensor " + input_name +
112 " but its type is not yet registered");
113 }
114 }
115 // in case of Mod there can be an attribute
// The ONNX Mod spec defines one optional int attribute, `fmod` (default 0): 0 gives integer
// (Python-style) modulus, 1 gives C `fmod` semantics. The spec requires 1 for floating-point inputs.
// NOTE(review): the attribute is read by position (attribute(0)) without checking its name or type.
// This works for conforming models because `fmod` is Mod's only attribute. Looking it up by name
// (attribute(k).name() == "fmod") would be more robust against non-conforming nodes.
116 int fmod = 0;
117 if (nodeproto.attribute_size() > 0) {
118 fmod = nodeproto.attribute(0).i();
119 }
120 std::unique_ptr<ROperator> op;
121 std::string output_name = nodeproto.output(0);
122
// Dispatch on the common input type to pick the concrete operator. Two of the cases branch on
// `fmod == 1` to choose between two operator variants.
// NOTE(review): the case labels and operator constructions are not visible in this listing.
// Unsupported types fall through to `default` and throw.
123 switch (input_type) {
126 break;
129 break;
131 if (fmod == 1)
133 else
135 break;
137 if (fmod == 1)
139 else
141 break;
142 default:
143 throw std::runtime_error("TMVA::SOFIE - Unsupported - Binary Operator does not yet support input type " +
144 std::to_string(static_cast<int>(input_type)));
145 }
146
147 // Infer the output type
// Only done when the output tensor has no registered type yet. The body is not visible here;
// presumably it registers the output's type with the parser — TODO confirm.
148 if (!parser.IsRegisteredTensorType(output_name)) {
150 }
151
152 return op;
153};
154
155
156} // namespace SOFIE
157} // namespace Experimental
158} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseDiv
ParserFuncSignature ParseSub
std::unique_ptr< ROperator > ParseBasicBinary(RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto)
ParserFuncSignature ParseAdd
ParserFuncSignature ParseMod
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseMul
ParserFuncSignature ParsePow
create variable transformations