Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_Inference.py
Go to the documentation of this file.
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
### This macro provides an example of taking a model trained with Keras
### and running inference with SOFIE directly on NumPy arrays.
### This macro uses as input a Keras model generated with the
### TMVA_Higgs_Classification.C tutorial.
### You need to run that macro before this one.
### In this case we parse the input file and then run the inference in the same
### macro, making use of ROOT's JIT compilation capability.
###
###
### \macro_code
### \macro_output
### \author Lorenzo Moneta
16
17import ROOT
18import numpy as np
19
20
# Initialize TMVA's Python interface before using the PyKeras parser.
ROOT.TMVA.PyMethodBase.PyInitialize()


# Check that the trained Keras model exists.
# Note: the branch below is taken when AccessPathName() is truthy, i.e. when
# the file is NOT accessible.
modelFile = "Higgs_trained_model.h5"
if ROOT.gSystem.AccessPathName(modelFile):
    ROOT.Info("TMVA_SOFIE_Inference", "You need to run TMVA_Higgs_Classification.C to generate the Keras trained model")
    # exit() is an interactive helper injected by the `site` module and is not
    # guaranteed to exist; SystemExit ends the script with the same status 0.
    raise SystemExit
29
30
# The generated C++ header is named after the model file (.h5 -> .hxx).
generatedHeaderFile = modelFile.replace(".h5", ".hxx")

# Parse the Keras model file into a SOFIE RModel object.
model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse(modelFile)
print(
    "Generating inference code for the Keras model from ",
    modelFile,
    "in the header ",
    generatedHeaderFile,
)

# Generate the inference code, write it to the header file, then print it.
model.Generate()
model.OutputGenerated(generatedHeaderFile)
model.PrintGenerated()
40
# JIT-compile the generated header with the ROOT interpreter so that the
# model code becomes available from Python.
modelName = modelFile.replace(".h5", "")
print("compiling SOFIE model ", modelName)
ROOT.gInterpreter.Declare(f'#include "{generatedHeaderFile}"')
45
46
# Remote ROOT file with the events used for inference
# (read below through the "sig_tree" and "bkg_tree" trees).
inputFileName = "Higgs_data.root"
inputFile = "http://root.cern.ch/files/" + inputFileName
53
54
55
56
57
# Input variables of the model, in the order used to build each event row.
inputColumns = ['m_jj', 'm_jjj', 'm_lv', 'm_jlv', 'm_bb', 'm_wbb', 'm_wwbb']


def read_events(treeName):
    """Return the model input variables of one tree as a 2-D array.

    The columns listed in `inputColumns` are read from tree `treeName` of
    `inputFile` with RDataFrame.AsNumpy and stacked into a single array of
    shape (nevents x nvars).
    """
    columns = ROOT.RDataFrame(treeName, inputFile).AsNumpy(columns=inputColumns)
    # Select by name so the column order never depends on the dict ordering.
    return np.column_stack([columns[name] for name in inputColumns])


def fill_scores(session, events, hist):
    """Run `session.infer` on each row of `events` and fill `hist` with the first output value."""
    for event in events:
        hist.Fill(session.infer(event)[0])


# make SOFIE inference on signal data
xsig = read_events("sig_tree")
dataset_size = xsig.shape[0]
print("size of data", dataset_size)

# instantiate the SOFIE session class made available by the JIT-compiled header
session = ROOT.TMVA_SOFIE_Higgs_trained_model.Session()

hs = ROOT.TH1D("hs", "Signal result", 100, 0, 1)
fill_scores(session, xsig, hs)


# make SOFIE inference on background data
xbkg = read_events("bkg_tree")

hb = ROOT.TH1D("hb", "Background result", 100, 0, 1)
fill_scores(session, xbkg, hb)
89
90
# Overlay the two score distributions: signal in red, background in blue.
hs.SetLineColor(ROOT.kRed)
hb.SetLineColor(ROOT.kBlue)
ROOT.gStyle.SetOptStat(0)  # statistics option 0: no statistics box

c1 = ROOT.TCanvas()
hs.Draw()
hb.Draw("SAME")  # draw on top of the signal histogram
c1.BuildLegend()
c1.Draw()
99
100
# Report how many entries were filled into each histogram.
for label, hist in (("signal", hs), ("background", hb)):
    print("Number of " + label + " entries", hist.GetEntries())
103
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...