Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseSplit.cxx
Go to the documentation of this file.
#include "onnx_proto3.pb.h"

#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuncSignature ParseSplit = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
11
12 std::string input_name = nodeproto.input(0);
15 } else {
16 throw std::runtime_error("TMVA::SOFIE ONNX Parser Split op has input tensor" + input_name +
17 " but its type is not yet registered");
18 }
19
20 std::string split_name;
21 if (nodeproto.input_size() > 1) {
22 split_name = nodeproto.input(1);
24 throw std::runtime_error("TMVA::SOFIE ONNX Parser Split op has input tensor" + split_name +
25 " but its type is not yet registered");
26 }
27 }
28
29 int axis = 0;
30 int num_outputs = 0;
31 for (int i = 0; i < nodeproto.attribute_size(); i++) {
32 std::string attribute_name = nodeproto.attribute(i).name();
33 if (attribute_name == "axis") {
34 axis = nodeproto.attribute(i).i();
35 }
36 else if (attribute_name == "num_outputs") {
37 num_outputs = nodeproto.attribute(i).i();
38 }
39 else
40 throw std::runtime_error("TMVA::SOFIE ONNX Parser Split operator: attribute" + attribute_name + "is not yet supported");
41 }
42
43 // number of splits are given by the number of output tensors
44 int output_size = nodeproto.output_size();
45 std::vector<std::string> output_names(output_size);
46 for (int i = 0; i < output_size; i++)
47 output_names[i] = nodeproto.output(i);
48
50 throw std::runtime_error("TMVA::SOFIE ONNX Parser Split - invalid output size: " + std::to_string(output_size) + " instead of " +
51 std::to_string(num_outputs));
52
53 std::unique_ptr<ROperator> op(new ROperator_Split(input_name, split_name, axis, output_names));
54
55
56 for (int i = 0; i < output_size; i++) {
57 if (!parser.IsRegisteredTensorType(output_names[i])) {
59 }
60 }
61
62 return op;
63};
64
65} // namespace SOFIE
66} // namespace Experimental
67} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseSplit
Definition ParseSplit.cxx:9
create variable transformations