Logo ROOT  
Reference Guide
TMVA_SOFIE_ONNX.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
/// This macro provides a simple example of parsing an ONNX file into an
/// RModel object and then generating the .hxx header file used for inference.
6///
7/// \macro_code
8/// \macro_output
9/// \author Sanjiban Sengupta
10
11using namespace TMVA::Experimental;
12
13void TMVA_SOFIE_ONNX(){
14 //Creating parser object to parse ONNX files
16 SOFIE::RModel model = Parser.Parse("../../tmva/sofie/test/input_models/Linear_16.onnx");
17
18 //Generating inference code
19 model.Generate();
20 model.OutputGenerated("Linear_16.hxx");
21
22 //Printing required input tensors
24
25 //Printing initialized tensors (weights)
26 std::cout<<"\n\n";
28
29 //Printing intermediate tensors
30 std::cout<<"\n\n";
32
33 //Checking if tensor already exist in model
34 std::cout<<"\n\nTensor \"16weight\" already exist: "<<std::boolalpha<<model.CheckIfTensorAlreadyExist("16weight")<<"\n\n";
35 std::vector<size_t> tensorShape = model.GetTensorShape("16weight");
36 std::cout<<"Shape of tensor \"16weight\": ";
37 for(auto& it:tensorShape){
38 std::cout<<it<<",";
39 }
40 std::cout<<"\n\nData type of tensor \"16weight\": ";
41 SOFIE::ETensorType tensorType = model.GetTensorType("16weight");
42 std::cout<<SOFIE::ConvertTypeToString(tensorType);
43
44 //Printing generated inference code
45 std::cout<<"\n\n";
46 model.PrintGenerated();
47}
const ETensorType & GetTensorType(std::string name)
Definition: RModel.cxx:68
void Generate(bool useSession=true, bool useWeightFile=true)
Definition: RModel.cxx:173
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition: RModel.cxx:89
void OutputGenerated(std::string filename="")
Definition: RModel.cxx:516
const std::vector< size_t > & GetTensorShape(std::string name)
Definition: RModel.cxx:47
std::string ConvertTypeToString(ETensorType type)