Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_RDataFrame.py
Go to the documentation of this file.
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
### Example of inference with SOFIE and RDataFrame, of a model trained with Keras.
### First, generate the input model by running `TMVA_Higgs_Classification.C`.
###
### This tutorial parses the input model and runs the inference using ROOT's JITing capability.
###
### \macro_code
### \macro_output
### \author Lorenzo Moneta

import ROOT
from os.path import exists

# PyMethodBase is initialised before the PyKeras parser is used below
ROOT.TMVA.PyMethodBase.PyInitialize()


# Every file name used by this script is derived from the model name, which is
# also the suffix of the generated `TMVA_SOFIE_<modelName>` namespace used when
# declaring the functor.
modelName = "Higgs_trained_model"
modelFile = modelName + ".h5"
generatedHeader = modelName + "_generated.hxx"
# Weight file handed to the Session constructor. The script only passes the
# header name to OutputGenerated, so this file is presumably written alongside
# the header - TODO confirm against the SOFIE documentation.
weightFile = modelName + "_generated.dat"

# Columns of the input trees fed to the model, in the order they are passed to
# the functor. Their count is the first template argument of SofieFunctor.
inputFeatures = ["m_jj", "m_jjj", "m_lv", "m_jlv", "m_bb", "m_wbb", "m_wwbb"]

# check if the input file exists
if not exists(modelFile):
    raise FileNotFoundError("You need to run TMVA_Higgs_Classification.C to generate the Keras trained model")

# parse the input Keras model into RModel object
model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse(modelFile)

# generating inference code
model.Generate()
model.OutputGenerated(generatedHeader)
model.PrintGenerated()

# compile using ROOT JIT trained model
print("compiling SOFIE model and functor....")
ROOT.gInterpreter.Declare('#include "{}"'.format(generatedHeader))
ROOT.gInterpreter.Declare(
    'auto sofie_functor = TMVA::Experimental::SofieFunctor<{},TMVA_SOFIE_{}::Session>(0,"{}");'.format(
        len(inputFeatures), modelName, weightFile
    )
)

# Expression evaluated for every entry: the JIT-declared functor is called with
# the RDataFrame slot (`rdfslot_`) followed by the input columns. The same
# expression is used for the signal and the background tree.
dnnExpression = "sofie_functor(rdfslot_," + ", ".join(inputFeatures) + ")"

# run inference over input data
inputFile = "http://root.cern/files/Higgs_data.root"
df1 = ROOT.RDataFrame("sig_tree", inputFile)
h1 = df1.Define("DNN_Value", dnnExpression).Histo1D(("h_sig", "", 100, 0, 1), "DNN_Value")

df2 = ROOT.RDataFrame("bkg_tree", inputFile)
h2 = df2.Define("DNN_Value", dnnExpression).Histo1D(("h_bkg", "", 100, 0, 1), "DNN_Value")

# run over the input data once, combining both RDataFrame graphs.
ROOT.RDF.RunGraphs([h1, h2])

print("Number of signal entries", h1.GetEntries())
print("Number of background entries", h2.GetEntries())

h1.SetLineColor(ROOT.kRed)
h2.SetLineColor(ROOT.kBlue)

c1 = ROOT.TCanvas()
ROOT.gStyle.SetOptStat(0)

# background first, then the signal overlaid on the same pad
h2.DrawClone()
h1.DrawClone("SAME")
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...
unsigned int RunGraphs(std::vector< RResultHandle > handles)
Trigger the event loop of multiple RDataFrames concurrently.