ROOT
v6-32
Reference Guide
Loading...
Searching...
No Matches
ParseMatMul.cxx
Go to the documentation of this file.
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_Gemm.hxx"
#include "onnx_proto3.pb.h"

#include <memory>
#include <stdexcept>
#include <string>

namespace
TMVA
{
6
namespace
Experimental {
7
namespace
SOFIE {
8
9
ParserFuncSignature
ParseMatMul
= [](
RModelParser_ONNX
&parser,
const
onnx::NodeProto &
matmulnode
) {
10
ETensorType
input_type
=
ETensorType::UNDEFINED
;
11
12
// check input type - only first input from MatMul
13
auto
input_name
=
matmulnode
.input(0);
14
if
(parser.
IsRegisteredTensorType
(
input_name
)) {
15
input_type
= parser.
GetTensorType
(
input_name
);
16
}
else
{
17
throw
std::runtime_error(
"TMVA::SOFIE ONNX Parser MatMul op has input tensor "
+
input_name
+
18
" but its type is not yet registered"
);
19
}
20
21
std::unique_ptr<ROperator>
op
;
22
23
// for MatMul there is no alpha and beta : use alpha=1 and beta=0
24
float
attr_alpha
= 1.0;
25
float
attr_beta
= 0.0;
26
int_t
attr_transA
= 0;
27
int_t
attr_transB
= 0;
28
29
switch
(
input_type
) {
30
case
ETensorType::FLOAT
:
31
op
.reset(
new
ROperator_Gemm<float>
(
attr_alpha
,
attr_beta
,
attr_transA
,
attr_transB
,
matmulnode
.input(0),
32
matmulnode
.input(1),
matmulnode
.output(0)));
33
break
;
34
default
:
35
throw
std::runtime_error(
36
"TMVA::SOFIE - Unsupported - Operator for fusing MatMul and Add to Gemm does not yet support input type "
+
37
std::to_string(
static_cast<
int
>
(
input_type
)));
38
}
39
40
std::string
output_name
=
matmulnode
.output(0);
41
if
(!parser.
IsRegisteredTensorType
(
output_name
)) {
42
parser.
RegisterTensorType
(
output_name
,
input_type
);
43
}
44
45
return
op
;
46
};
47
48
}
// namespace SOFIE
49
}
// namespace Experimental
50
}
// namespace TMVA
RModelParser_ONNX.hxx
ROperator_Gemm.hxx
TRangeDynCast
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Definition
TCollection.h:358
ROOT::Detail::TRangeCast
Definition
TCollection.h:311
TMVA::Experimental::SOFIE::RModelParser_ONNX
Definition
RModelParser_ONNX.hxx:27
TMVA::Experimental::SOFIE::RModelParser_ONNX::IsRegisteredTensorType
bool IsRegisteredTensorType(const std::string &)
Definition
RModelParser_ONNX.cxx:184
TMVA::Experimental::SOFIE::RModelParser_ONNX::RegisterTensorType
void RegisterTensorType(const std::string &, ETensorType)
Definition
RModelParser_ONNX.cxx:179
TMVA::Experimental::SOFIE::RModelParser_ONNX::GetTensorType
ETensorType GetTensorType(const std::string &name)
Definition
RModelParser_ONNX.cxx:189
TMVA::Experimental::SOFIE::ETensorType
ETensorType
Definition
SOFIE_common.hxx:25
TMVA::Experimental::SOFIE::ETensorType::UNDEFINED
@ UNDEFINED
TMVA::Experimental::SOFIE::ETensorType::FLOAT
@ FLOAT
TMVA::Experimental::SOFIE::ParserFuncSignature
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
Definition
RModelParser_ONNX.hxx:22
TMVA::Experimental::SOFIE::ParseMatMul
ParserFuncSignature ParseMatMul
Definition
ParseMatMul.cxx:9
TMVA::Experimental::SOFIE::int_t
std::int64_t int_t
Definition
SOFIE_common.hxx:30
TMVA
create variable transformations
Definition
GeneticMinimizer.h:22
tmva
sofie_parsers
src
ParseMatMul.cxx
ROOT v6-32 - Reference Guide Generated on Fri Jan 24 2025 04:12:01 (GVA Time) using Doxygen 1.10.0