ROOT Reference Guide

TMVA_SOFIE_RSofieReader.C File Reference

Detailed Description

This macro shows how to take a model trained with Keras and run inference on it with SOFIE, using the RSofieReader class. Its input is a Keras model generated by the TMVA_Higgs_Classification.C tutorial, so you need to run that macro first to produce the trained Keras model (Higgs_trained_model.h5).

Execute the macros in this order: first TMVA_Higgs_Classification.C, then TMVA_SOFIE_RSofieReader.C.

using namespace TMVA::Experimental;
void TMVA_SOFIE_RSofieReader(){
RSofieReader model("Higgs_trained_model.h5");
// for debugging
//RSofieReader model("Higgs_trained_model.h5", {}, true);
// the input shape for this model is a tensor with shape (1,7)
std::vector<float> input = {0.1,0.2,0.3,0.4,0.5,0.6,0.7};
// run the model on a single event (Compute takes a std::vector<float>)
auto output = model.Compute(input);
std::cout << "Event prediction = " << output[0] << std::endl;
// now run the model on an input file using RDataFrame
std::string inputFileName = "Higgs_data.root";
std::string inputFile = "http://root.cern.ch/files/" + inputFileName;
ROOT::RDataFrame df1("sig_tree", inputFile);
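// Compute<7, float> wraps the model so that Define can pass the 7 columns below as float arguments
auto h1 = df1.Define("DNN_Values", Compute<7, float>(model),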
{"m_jj", "m_jjj", "m_lv", "m_jlv", "m_bb", "m_wbb", "m_wwbb"})
.Define("y","DNN_Values[0]")
.Histo1D({"h_sig", "", 100, 0, 1}, "y");
ROOT::RDataFrame df2("bkg_tree", inputFile);
auto h2 = df2.Define("DNN_Values", Compute<7, float>(model),
{"m_jj", "m_jjj", "m_lv", "m_jlv", "m_bb", "m_wbb", "m_wwbb"})
.Define("y","DNN_Values[0]")
.Histo1D({"h_bkg", "", 100, 0, 1}, "y");
// draw the background (blue) and signal DNN output distributions on the same canvas
h2->SetLineColor(kBlue);
auto c1 = new TCanvas();
h2->DrawClone();
h1->DrawClone("SAME");
c1->BuildLegend();
}
Output of the macro:
Model has not a defined batch size, assume is 1 - input shape for tensor dense_input : { 1 , 7 }
Event prediction = 0.209668
Author
Lorenzo Moneta

Definition in file TMVA_SOFIE_RSofieReader.C.