Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_RSofieReader.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides an example of using a model trained with Keras
5/// and of running inference on it with SOFIE via the RSofieReader class.
6/// This macro uses as input a Keras model generated with the
7/// TMVA_Higgs_Classification.C tutorial
8/// You need to run that macro first to generate the trained Keras model.
9///
10///
11/// Execute in this order:
12/// ```
13/// root TMVA_Higgs_Classification.C
14/// root TMVA_SOFIE_RSofieReader.C
15/// ```
16///
17/// \macro_code
18/// \macro_output
19/// \author Lorenzo Moneta
20
21using namespace TMVA::Experimental;
22
23void TMVA_SOFIE_RSofieReader(){
24
25 RSofieReader model("Higgs_trained_model.h5");
26 // for debugging
27 //RSofieReader model("Higgs_trained_model.h5", {}, true);
28
29 // the input shape for this model is a tensor with shape (1,7)
30
31 std::vector<float> input = {0.1,0.2,0.3,0.4,0.5,0.6,0.7};
32
33 // predict model on a single event (takes a std::vector<float>)
34
35 auto output = model.Compute(input);
36
37 std::cout << "Event prediction = " << output[0] << std::endl;
38
39 // predict model now on a input file using RDataFrame
40
41 std::string inputFileName = "Higgs_data.root";
42 std::string inputFile = "http://root.cern.ch/files/" + inputFileName;
43
44
45 ROOT::RDataFrame df1("sig_tree", inputFile);
46
47 auto h1 = df1.Define("DNN_Values", Compute<7, float>(model),
48 {"m_jj", "m_jjj", "m_lv", "m_jlv", "m_bb", "m_wbb", "m_wwbb"})
49 .Define("y","DNN_Values[0]")
50 .Histo1D({"h_sig", "", 100, 0, 1}, "y");
51
52 ROOT::RDataFrame df2("bkg_tree", inputFile);
53 auto h2 = df2.Define("DNN_Values", Compute<7, float>(model),
54 {"m_jj", "m_jjj", "m_lv", "m_jlv", "m_bb", "m_wbb", "m_wwbb"})
55 .Define("y","DNN_Values[0]")
56 .Histo1D({"h_bkg", "", 100, 0, 1}, "y");
57
59 h2->SetLineColor(kBlue);
60
61 auto c1 = new TCanvas();
63
64 h2->DrawClone();
65 h1->DrawClone("SAME");
66 c1->BuildLegend();
67
68}
@ kRed
Definition Rtypes.h:66
@ kBlue
Definition Rtypes.h:66
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
R__EXTERN TStyle * gStyle
Definition TStyle.h:436
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition TAttLine.h:40
The Canvas class.
Definition TCanvas.h:23
TMVA::RSofieReader class for reading external Machine Learning models in ONNX files,...
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad with: gROOT->SetSelectedPad(c1).
Definition TObject.cxx:299
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
Definition TStyle.cxx:1640
return c1
Definition legend1.C:41
TH1F * h1
Definition legend1.C:5
static void output()