Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseClip.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9// ---------------------------------------------------------------------------
10// ParseClip
11//
12// ONNX Clip node inputs (all optional except X):
13// input(0) : X — data tensor to clip (required)
14// input(1) : min — scalar lower bound (optional)
15// input(2) : max — scalar upper bound (optional)
16//
17// ONNX Clip node output:
18// output(0): Y — clipped output tensor
19//
20// Trailing optional inputs (min / max) may be omitted entirely, in which
21// case input_size is < 3. An optional input that is skipped but still
22// listed is represented in the ONNX protobuf as an empty string "".
23// ---------------------------------------------------------------------------
24
26 const onnx::NodeProto &nodeproto) {
27
28 // ---- validate input count -------------------------------------------
29 // Clip requires at least 1 input (X); min and max are optional
30 if (nodeproto.input_size() < 1) {
31 throw std::runtime_error(
32 "TMVA::SOFIE ONNX Parser Clip op has invalid input size " +
33 std::to_string(nodeproto.input_size()) + " (expected 1, 2 or 3)");
34 }
35
36 // ---- main input X must be registered --------------------------------
37 if (!parser.IsRegisteredTensorType(nodeproto.input(0))) {
38 throw std::runtime_error(
39 "TMVA::SOFIE ONNX Parser Clip op has input tensor " +
40 nodeproto.input(0) + " but its type is not yet registered");
41 }
42
44
45
46 std::string minName = (nodeproto.input_size() > 1) ? nodeproto.input(1) : "";
47 std::string maxName = (nodeproto.input_size() > 2) ? nodeproto.input(2) : "";
48
49 // ---- if min/max are provided they must match the data type ----------
50 if (!minName.empty() && parser.IsRegisteredTensorType(minName)) {
51 if (parser.GetTensorType(minName) != input_type) {
52 throw std::runtime_error(
53 "TMVA::SOFIE ONNX Parser Clip op: min tensor " + minName +
54 " type " + ConvertTypeToString(parser.GetTensorType(minName)) +
55 " does not match input type " + ConvertTypeToString(input_type));
56 }
57 }
58 if (!maxName.empty() && parser.IsRegisteredTensorType(maxName)) {
59 if (parser.GetTensorType(maxName) != input_type) {
60 throw std::runtime_error(
61 "TMVA::SOFIE ONNX Parser Clip op: max tensor " + maxName +
62 " type " + ConvertTypeToString(parser.GetTensorType(maxName)) +
63 " does not match input type " + ConvertTypeToString(input_type));
64 }
65 }
66
67 // ---- build the operator ---------------------------------------------
68 std::unique_ptr<ROperator> op;
69 std::string output_name = nodeproto.output(0);
70
71 switch (input_type) {
73 op.reset(new ROperator_Clip<float>(
75 break;
77 op.reset(new ROperator_Clip<double>(
79 break;
83 break;
87 break;
88 default:
89 throw std::runtime_error(
90 "TMVA::SOFIE - Unsupported - Clip Operator does not yet support "
91 "input type " + ConvertTypeToString(input_type));
92 }
93
94 // ---- register output tensor type ------------------------------------
97 }
98
99 return op;
100};
101
102} // namespace SOFIE
103} // namespace Experimental
104} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseClip
Definition ParseClip.cxx:25
std::string ConvertTypeToString(ETensorType type)
create variable transformations