ROOT
v6-32
Reference Guide
Loading...
Searching...
No Matches
ParseGather.cxx
Go to the documentation of this file.
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_Gather.hxx"
#include "onnx_proto3.pb.h"

#include <memory>
#include <stdexcept>
#include <string>

namespace
TMVA
{
7
namespace
Experimental {
8
namespace
SOFIE {
9
10
ParserFuncSignature
ParseGather
= [](
RModelParser_ONNX
&parser,
const
onnx::NodeProto &
nodeproto
) {
11
ETensorType
input_type
=
ETensorType::UNDEFINED
;
12
auto
input_name
=
nodeproto
.input(0);
13
if
(parser.
IsRegisteredTensorType
(
input_name
)) {
14
input_type
= parser.
GetTensorType
(
input_name
);
15
}
else
{
16
throw
std::runtime_error(
"TMVA::SOFIE ONNX Parser Gather op has input tensor"
+
input_name
+
17
" but its type is not yet registered"
);
18
}
19
20
ETensorType
indices_type
=
ETensorType::UNDEFINED
;
21
auto
indices_name
=
nodeproto
.input(1);
22
// indices_type can be an initialized tensor, no need to emit an error if it is not registered
23
if
(parser.
IsRegisteredTensorType
(
indices_name
)) {
24
indices_type
= parser.
GetTensorType
(
indices_name
);
25
if
(
indices_type
!=
ETensorType::INT64
&&
indices_type
!=
ETensorType::INT32
) {
26
throw
27
std::runtime_error(
"TMVA::SOFIE ONNX Parser Gather op Indices tensor type not supported."
);
28
}
29
}
30
31
std::unique_ptr<ROperator>
op
;
32
std::string
output_name
=
nodeproto
.output(0);
33
int64_t
attr_axis
= 0;
34
if
(
nodeproto
.attribute_size() == 1) {
35
attr_axis
=
nodeproto
.attribute(0).i();
36
}
37
38
switch
(
input_type
) {
39
case
ETensorType::FLOAT
:
40
op
.reset(
new
ROperator_Gather<float>
(
attr_axis
,
input_name
,
indices_name
,
nodeproto
.output(0)));
41
break
;
42
default
:
43
throw
std::runtime_error(
"TMVA::SOFIE - Unsupported - Operator Gather does not yet support input type "
+
44
ConvertTypeToString
(
input_type
));
45
}
46
47
if
(!parser.
IsRegisteredTensorType
(
output_name
)) {
48
parser.
RegisterTensorType
(
output_name
,
input_type
);
49
}
50
51
return
op
;
52
};
53
54
}
// namespace SOFIE
55
}
// namespace Experimental
56
}
// namespace TMVA
RModelParser_ONNX.hxx
ROperator_Gather.hxx
TRangeDynCast
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Definition
TCollection.h:358
ROOT::Detail::TRangeCast
Definition
TCollection.h:311
TMVA::Experimental::SOFIE::RModelParser_ONNX
Definition
RModelParser_ONNX.hxx:27
TMVA::Experimental::SOFIE::RModelParser_ONNX::IsRegisteredTensorType
bool IsRegisteredTensorType(const std::string &)
Definition
RModelParser_ONNX.cxx:184
TMVA::Experimental::SOFIE::RModelParser_ONNX::RegisterTensorType
void RegisterTensorType(const std::string &, ETensorType)
Definition
RModelParser_ONNX.cxx:179
TMVA::Experimental::SOFIE::RModelParser_ONNX::GetTensorType
ETensorType GetTensorType(const std::string &name)
Definition
RModelParser_ONNX.cxx:189
TMVA::Experimental::SOFIE::ParseGather
ParserFuncSignature ParseGather
Definition
ParseGather.cxx:10
TMVA::Experimental::SOFIE::ETensorType
ETensorType
Definition
SOFIE_common.hxx:25
TMVA::Experimental::SOFIE::ETensorType::UNDEFINED
@ UNDEFINED
TMVA::Experimental::SOFIE::ETensorType::INT64
@ INT64
TMVA::Experimental::SOFIE::ETensorType::INT32
@ INT32
TMVA::Experimental::SOFIE::ETensorType::FLOAT
@ FLOAT
TMVA::Experimental::SOFIE::ParserFuncSignature
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
Definition
RModelParser_ONNX.hxx:22
TMVA::Experimental::SOFIE::ConvertTypeToString
std::string ConvertTypeToString(ETensorType type)
Definition
SOFIE_common.cxx:44
TMVA
create variable transformations
Definition
GeneticMinimizer.h:22
tmva
sofie_parsers
src
ParseGather.cxx
ROOT v6-32 - Reference Guide Generated on Fri Sep 12 2025 04:31:12 (GVA Time) using Doxygen 1.10.0