Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseWhere.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuncSignature ParseWhere = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
10
11 if (nodeproto.input_size() != 3) {
12 throw std::runtime_error("TMVA::SOFIE ONNX Parser Where op has invalid input size");
13 }
14 // condition boolean vector is input 0
15 if (!parser.IsRegisteredTensorType(nodeproto.input(0))){
16 throw std::runtime_error("TMVA::SOFIE ONNX Parser Where op has input tensor " + nodeproto.input(0)
17 + " but its type is not yet registered");
18 }
19 if (!parser.IsRegisteredTensorType(nodeproto.input(1))){
20 throw std::runtime_error("TMVA::SOFIE ONNX Parser Where op has input tensor " + nodeproto.input(1)
21 + " but its type is not yet registered");
22 }
23 if (!parser.IsRegisteredTensorType(nodeproto.input(2))){
24 throw std::runtime_error("TMVA::SOFIE ONNX Parser Where op has input tensor " + nodeproto.input(2)
25 + " but its type is not yet registered");
26 }
28 if (parser.GetTensorType(nodeproto.input(2)) != input_type) {
29 throw std::runtime_error("TMVA::SOFIE ONNX parser Where op has input tensors of different types: " +
30 nodeproto.input(2) + " : " + ConvertTypeToString(parser.GetTensorType(nodeproto.input(2))) +
31 " and " + nodeproto.input(1) + " : " + ConvertTypeToString(input_type));
32 }
33
34 std::unique_ptr<ROperator> op;
35 std::string output_name = nodeproto.output(0);
36
37 switch (input_type) {
38 //note ROPeratore_WHere signature takes as first tensor the condition
40 op.reset(new ROperator_Where<float>(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), output_name));
41 break;
43 op.reset(new ROperator_Where<int64_t>(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), output_name));
44 break;
45 default:
46 throw std::runtime_error("TMVA::SOFIE - Unsupported - Where Operator does not yet support input type " +
47 std::to_string(static_cast<int>(input_type)));
48 }
49
50 // Infer the output type
53 }
54
55 return op;
56};
57
58
59} // namespace SOFIE
60} // namespace Experimental
61} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
ParserFuncSignature ParseWhere
Definition ParseWhere.cxx:9
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
create variable transformations