Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseRandom.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9ParserFuncSignature ParseRandom = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
10
12 auto op_type = nodeproto.op_type();
13 if (op_type == "RandomNormal" || op_type == "RandomNormalLike")
15
16
17 ETensorType input_type = ETensorType::FLOAT; // default value
18 std::string input_name;
19 // case of NormalLike and UniformLike , type is given by the input
20 if (nodeproto.input_size() > 0) {
21 input_name = nodeproto.input(0);
24 } else {
25 throw std::runtime_error("TMVA::SOFIE ONNX Parser Randomr op has input tensor" + input_name +
26 " but its type is not yet registered");
27 }
28 }
29 // get attributes
30 float seed = 0;
31 std::map<std::string, float> paramMap;
32 std::vector<size_t> shape;
33 for (int i = 0; i < nodeproto.attribute_size(); i++) {
34 std::string attribute_name = nodeproto.attribute(i).name();
35 auto attr_type = nodeproto.attribute(i).type();
36 if (attribute_name == "dtype")
37 input_type = static_cast<ETensorType>(nodeproto.attribute(i).i());
38 else if (attribute_name == "seed") {
39 if (attr_type == onnx::AttributeProto::FLOAT )
40 seed = nodeproto.attribute(i).f();
41 else if (attr_type == onnx::AttributeProto::INT)
42 seed = nodeproto.attribute(i).i();
43 else
44 throw std::runtime_error("TMVA::SOFIE ONNX Parser Random op has invalid type for attribute seed");
45 }
46 else if (attribute_name == "shape") {
47 if (attr_type != onnx::AttributeProto::INTS)
48 throw std::runtime_error("TMVA::SOFIE ONNX Parser Random op has invalid type for attribute shape");
49 shape = std::vector<size_t>(nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end());
50 }
51 else {
52 float value = 0;
53 if (attr_type == onnx::AttributeProto::FLOAT)
54 value = nodeproto.attribute(i).f();
55 else if (attr_type == onnx::AttributeProto::INT)
56 value = nodeproto.attribute(i).i();
57 else
58 throw std::runtime_error("TMVA::SOFIE ONNX Parser Random op has invalid type for attribute " + attribute_name);
60 }
61 }
62
63 std::string output_name = nodeproto.output(0);
64
65 std::unique_ptr<ROperator> op(new ROperator_Random(opMode, input_type, input_name, output_name, shape, paramMap, seed));
66
69 }
70
71 return op;
72};
73
74} // namespace SOFIE
75} // namespace Experimental
76} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
const_iterator begin() const
const_iterator end() const
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseRandom
create variable transformations