ROOT Reference Guide
 
Loading...
Searching...
No Matches
ParseWhere.cxx
Go to the documentation of this file.
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_Where.hxx"
#include "onnx_proto3.pb.h"

#include <cstdint>
#include <initializer_list>
#include <memory>
#include <stdexcept>
#include <string>

5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuncSignature ParseWhere = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
10
11 if (nodeproto.input_size() != 3) {
12 throw std::runtime_error("TMVA::SOFIE ONNX Parser Where op has invalid input size");
13 }
14 // condition boolean vector is input 0
15 if (!parser.IsRegisteredTensorType(nodeproto.input(1))){
16 throw std::runtime_error("TMVA::SOFIE ONNX Parser Where op has input tensor " + nodeproto.input(1)
17 + " but its type is not yet registered");
18 }
19 if (!parser.IsRegisteredTensorType(nodeproto.input(2))){
20 throw std::runtime_error("TMVA::SOFIE ONNX Parser Where op has input tensor " + nodeproto.input(2)
21 + " but its type is not yet registered");
22 }
24 if (parser.GetTensorType(nodeproto.input(2)) != input_type) {
25 throw std::runtime_error("TMVA::SOFIE ONNX parser Where op has input tensors of different types: " +
26 nodeproto.input(2) + " : " + ConvertTypeToString(parser.GetTensorType(nodeproto.input(2))) +
27 " and " + nodeproto.input(1) + " : " + ConvertTypeToString(input_type));
28 }
29
30 std::unique_ptr<ROperator> op;
31 std::string output_name = nodeproto.output(0);
32
33 switch (input_type) {
35 op.reset(new ROperator_Where<float>(nodeproto.input(1), nodeproto.input(2), nodeproto.input(0), output_name));
36 break;
38 op.reset(new ROperator_Where<int64_t>(nodeproto.input(1), nodeproto.input(2), nodeproto.input(0), output_name));
39 break;
40 default:
41 throw std::runtime_error("TMVA::SOFIE - Unsupported - Where Operator does not yet support input type " +
42 std::to_string(static_cast<int>(input_type)));
43 }
44
45 // Infer the output type
48 }
49
50 return op;
51};
52
53
54} // namespace SOFIE
55} // namespace Experimental
56} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
ParserFuncSignature ParseWhere
Definition ParseWhere.cxx:9
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
create variable transformations