TMVA_SOFIE_Inference.py
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
### This macro provides an example of using a model trained with Keras
### and running inference with SOFIE directly on NumPy arrays.
### It uses as input the Keras model generated by the
### TMVA_Higgs_Classification.C tutorial, which you need to run before this one.
### The input model file is parsed and the inference is then run in the same
### macro, making use of the ROOT JIT compilation capability.
###
### \macro_code
### \macro_output
### \author Lorenzo Moneta

import ROOT
import numpy as np

ROOT.TMVA.PyMethodBase.PyInitialize()

# check if the input file exists
modelFile = "Higgs_trained_model.h5"
# note: AccessPathName returns True when the file can NOT be accessed
if ROOT.gSystem.AccessPathName(modelFile):
    ROOT.Info("TMVA_SOFIE_Inference", "You need to run TMVA_Higgs_Classification.C to generate the Keras trained model")
    raise SystemExit(1)

# parse the input Keras model into an RModel object
model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse(modelFile)

generatedHeaderFile = modelFile.replace(".h5", ".hxx")
print("Generating inference code for the Keras model from", modelFile, "in the header", generatedHeaderFile)
# generate the inference code, write it to the header file and print it
model.Generate()
model.OutputGenerated(generatedHeaderFile)
model.PrintGenerated()
40
41
# now compile using ROOT JIT trained model
42
modelName =
modelFile.replace
(
".h5"
,
""
)
43
print(
"compiling SOFIE model "
, modelName)
44
ROOT.gInterpreter.Declare
(
'#include "'
+ generatedHeaderFile +
'"'
)
45

inputFileName = "Higgs_data.root"
inputFile = str(ROOT.gROOT.GetTutorialDir()) + "/tmva/data/" + inputFileName

# make SOFIE inference on signal data
df1 = ROOT.RDataFrame("sig_tree", inputFile)
sigData = df1.AsNumpy(columns=['m_jj', 'm_jjj', 'm_lv', 'm_jlv', 'm_bb', 'm_wbb', 'm_wwbb'])
#print(sigData)

# stack all the 7 numpy arrays in a single array (nevents x nvars)
xsig = np.column_stack(list(sigData.values()))
dataset_size = xsig.shape[0]
print("size of data", dataset_size)

# instantiate the SOFIE session class
# (the generated code lives in the namespace TMVA_SOFIE_<model name>)
session = ROOT.TMVA_SOFIE_Higgs_trained_model.Session()

hs = ROOT.TH1D("hs", "Signal result", 100, 0, 1)
for i in range(dataset_size):
    result = session.infer(xsig[i, :])
    hs.Fill(result[0])

# make SOFIE inference on background data
df2 = ROOT.RDataFrame("bkg_tree", inputFile)
bkgData = df2.AsNumpy(columns=['m_jj', 'm_jjj', 'm_lv', 'm_jlv', 'm_bb', 'm_wbb', 'm_wwbb'])

# stack the background arrays in the same way (nevents x nvars)
xbkg = np.column_stack(list(bkgData.values()))
dataset_size = xbkg.shape[0]

hb = ROOT.TH1D("hb", "Background result", 100, 0, 1)
for i in range(dataset_size):
    result = session.infer(xbkg[i, :])
    hb.Fill(result[0])

# draw the signal (red) and background (blue) output distributions
c1 = ROOT.TCanvas()
ROOT.gStyle.SetOptStat(0)
hs.SetLineColor(ROOT.kRed)
hs.Draw()
hb.SetLineColor(ROOT.kBlue)
hb.Draw("SAME")
c1.BuildLegend()
c1.Draw()

print("Number of signal entries", hs.GetEntries())
print("Number of background entries", hb.GetEntries())
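In the listing, the input matrix is built with `np.column_stack(list(sigData.values()))`. That relies on the dictionary returned by `AsNumpy` iterating in the requested column order, because the network expects its 7 inputs in the training order.

The sketch below is a minimal alternative that makes the order explicit. It also converts the result to a contiguous `float32` array. That conversion assumes, without confirmation from this tutorial, that the generated `infer()` reads each row through a `float` pointer. The helper name `to_model_input` is hypothetical, and a plain dict stands in for the `AsNumpy` result.

```python
import numpy as np

# Input variables in the order used by the tutorial.
COLUMNS = ['m_jj', 'm_jjj', 'm_lv', 'm_jlv', 'm_bb', 'm_wbb', 'm_wwbb']


def to_model_input(data, columns=COLUMNS):
    """Hypothetical helper: stack a dict of 1D arrays into (nevents x nvars).

    The column order comes from `columns`, not from the dict, so the result
    does not depend on how the dict happens to be ordered.
    """
    x = np.column_stack([np.asarray(data[c]) for c in columns])
    # Assumption: the generated infer() expects float32 values, so convert
    # to C-contiguous float32 storage (each row x[i, :] is then contiguous).
    return np.ascontiguousarray(x, dtype=np.float32)


# A plain dict standing in for df.AsNumpy(...), deliberately filled
# in the reverse order of COLUMNS.
fake = {c: np.arange(3, dtype=np.float64) + 10 * k
        for k, c in reversed(list(enumerate(COLUMNS)))}
x = to_model_input(fake)
print(x.shape, x.dtype)
print(x[0])
```

Each row `x[i, :]` would then take the place of `xsig[i, :]` in the inference loop.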
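The session class is taken from the namespace `TMVA_SOFIE_Higgs_trained_model`, that is, `TMVA_SOFIE_` followed by the model name, which the tutorial hard-codes.

The small sketch below derives that name from the model file path instead. It assumes the same naming pattern holds for other model files and that the model name is the bare file name without directory or extension. The helper `sofie_namespace` is hypothetical.

```python
import os


def sofie_namespace(model_file):
    """Hypothetical helper: guess the namespace of the SOFIE-generated code.

    Assumes the pattern seen in this tutorial, "TMVA_SOFIE_" + model name,
    with the model name being the file name without directory and extension.
    """
    base = os.path.basename(model_file)
    name, _ext = os.path.splitext(base)
    return "TMVA_SOFIE_" + name


print(sofie_namespace("Higgs_trained_model.h5"))
# With ROOT available, the session could then be created as:
#   session = getattr(ROOT, sofie_namespace(modelFile)).Session()
```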
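As a cross-check of the `hs` and `hb` histograms, the same binning (100 bins between 0 and 1) can be reproduced with NumPy once the scores are collected in an array. The minimal sketch below uses made-up scores in place of the `result[0]` values.

The two approaches differ in one respect. `np.histogram` puts a value equal to the upper edge 1 into the last bin, whereas `TH1D::Fill` sends it to the overflow bin, where it is still counted by `GetEntries`.

```python
import numpy as np

# Made-up classifier scores standing in for the result[0] values.
scores = np.array([0.0, 0.004, 0.25, 0.5, 0.999, 1.0])

# Same binning as TH1D("hs", "Signal result", 100, 0, 1).
counts, edges = np.histogram(scores, bins=100, range=(0.0, 1.0))

print(len(edges), counts.sum())
# Scores exactly equal to 1.0 are counted in the last bin by np.histogram;
# a TH1D with the same axis would put them in its overflow bin instead.
print(counts[-1], np.count_nonzero(scores >= 1.0))
```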
ROOT v6-34 - Reference Guide Generated on Tue Apr 1 2025 05:42:14 (GVA Time) using Doxygen 1.10.0