ROOT
master
Reference Guide
Loading...
Searching...
No Matches
ParseTopK.cxx
Go to the documentation of this file.
1
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_TopK.hxx"
#include "onnx_proto3.pb.h"

#include <memory>
#include <stdexcept>
#include <string>
4
5
namespace
TMVA
{
6
namespace
Experimental {
7
namespace
SOFIE {
8
9
ParserFuncSignature
ParseTopK
= [](
RModelParser_ONNX
&parser,
const
onnx::NodeProto &
nodeproto
) {
10
ETensorType
input_type
=
ETensorType::UNDEFINED
;
11
12
std::string
input_name
=
nodeproto
.input(0);
13
if
(parser.
IsRegisteredTensorType
(
input_name
)) {
14
input_type
= parser.
GetTensorType
(
input_name
);
15
}
else
{
16
throw
std::runtime_error(
"TMVA::SOFIE ONNX Parser TopK op has input tensor "
+
input_name
+
17
" but its type is not yet registered"
);
18
}
19
std::string
k_name
=
nodeproto
.input(1);
20
if
(!parser.
IsRegisteredTensorType
(
k_name
)) {
21
throw
std::runtime_error(
"TMVA::SOFIE ONNX Parser TopK op has input tensor "
+
k_name
+
22
" but its type is not yet registered"
);
23
}
24
25
std::unique_ptr<ROperator>
op
;
26
27
std::string
outputVal_name
=
nodeproto
.output(0);
28
std::string
outputInd_name
=
nodeproto
.output(1);
29
int
attr_axis
= -1;
30
int
attr_largest
= 1;
31
int
attr_sorted
= 1;
32
33
for
(
int_t
i = 0; i <
nodeproto
.attribute_size(); i++) {
34
std::string
attribute_name
=
nodeproto
.attribute(i).name();
35
if
(
attribute_name
==
"axis"
)
36
attr_axis
=
nodeproto
.attribute(i).i();
37
if
(
attribute_name
==
"largest"
)
38
attr_largest
=
nodeproto
.attribute(i).i();
39
if
(
attribute_name
==
"sorted"
)
40
attr_sorted
=
nodeproto
.attribute(i).i();
41
}
42
op
.reset(
new
ROperator_TopK<float>
(
attr_axis
,
attr_largest
,
attr_sorted
,
k_name
,
input_name
,
outputVal_name
,
outputInd_name
));
43
44
if
(!parser.
IsRegisteredTensorType
(
outputVal_name
)) {
45
parser.
RegisterTensorType
(
outputVal_name
,
input_type
);
46
}
47
if
(!parser.
IsRegisteredTensorType
(
outputInd_name
)) {
48
parser.
RegisterTensorType
(
outputInd_name
,
ETensorType::INT64
);
49
}
50
51
return
op
;
52
};
53
54
}
// namespace SOFIE
55
}
// namespace Experimental
56
}
// namespace TMVA
RModelParser_ONNX.hxx
ROperator_TopK.hxx
TRangeDynCast
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Definition
TCollection.h:358
ROOT::Detail::TRangeCast
Definition
TCollection.h:311
TMVA::Experimental::SOFIE::RModelParser_ONNX
Definition
RModelParser_ONNX.hxx:28
TMVA::Experimental::SOFIE::RModelParser_ONNX::IsRegisteredTensorType
bool IsRegisteredTensorType(const std::string &)
Definition
RModelParser_ONNX.cxx:269
TMVA::Experimental::SOFIE::RModelParser_ONNX::RegisterTensorType
void RegisterTensorType(const std::string &, ETensorType)
Definition
RModelParser_ONNX.cxx:264
TMVA::Experimental::SOFIE::RModelParser_ONNX::GetTensorType
ETensorType GetTensorType(const std::string &name)
Definition
RModelParser_ONNX.cxx:274
TMVA::Experimental::SOFIE::ETensorType
ETensorType
Definition
SOFIE_common.hxx:28
TMVA::Experimental::SOFIE::ETensorType::UNDEFINED
@ UNDEFINED
TMVA::Experimental::SOFIE::ETensorType::INT64
@ INT64
TMVA::Experimental::SOFIE::ParserFuncSignature
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
Definition
RModelParser_ONNX.hxx:23
TMVA::Experimental::SOFIE::ParseTopK
ParserFuncSignature ParseTopK
Definition
ParseTopK.cxx:9
TMVA::Experimental::SOFIE::int_t
std::int64_t int_t
Definition
SOFIE_common.hxx:55
TMVA
create variable transformations
Definition
GeneticMinimizer.h:22
tmva
sofie_parsers
src
ParseTopK.cxx
ROOT master - Reference Guide Generated on Tue Apr 22 2025 16:24:44 (GVA Time) using Doxygen 1.10.0