Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_RDataFrame_JIT.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_ml
3/// \notebook -nodraw
4/// This macro provides an example of taking a model trained with Keras
5/// and running inference on it using SOFIE and RDataFrame.
6/// The macro uses as input a Keras model generated by the
7/// TMVA_Higgs_Classification.C tutorial.
8/// You need to run that macro before this one.
9/// In this case we parse the input file and then run the inference in the same
10/// macro, making use of ROOT's JIT compilation capability.
11///
12///
13/// \macro_code
14/// \macro_output
15/// \author Lorenzo Moneta
16
17using namespace TMVA::Experimental;
18
19/// function to compile the generated model and the declaration of the SofieFunctor
20/// used by RDF.
21/// Assume that the model name as in the header file
22void CompileModelForRDF(const std::string & headerModelFile, unsigned int ninputs, unsigned int nslots=0) {
23
24 std::string modelName = headerModelFile.substr(0,headerModelFile.find(".hxx"));
25 std::string cmd = std::string("#include \"") + headerModelFile + std::string("\"");
26 auto ret = gInterpreter->Declare(cmd.c_str());
27 if (!ret)
28 throw std::runtime_error("Error compiling : " + cmd);
29 std::cout << "compiled : " << cmd << std::endl;
30
31 cmd = "auto sofie_functor = TMVA::Experimental::SofieFunctor<" + std::to_string(ninputs) + ",TMVA_SOFIE_" +
32 modelName + "::Session>(" + std::to_string(nslots) + ");";
33 ret = gInterpreter->Declare(cmd.c_str());
34 if (!ret)
35 throw std::runtime_error("Error compiling : " + cmd);
36 std::cout << "compiled : " << cmd << std::endl;
37 std::cout << "Model is ready to be evaluated" << std::endl;
38 return;
39}
40
41void TMVA_SOFIE_RDataFrame_JIT(std::string modelFile = "Higgs_trained_model.h5"){
42
43 // check if the input file exists
44 if (gSystem->AccessPathName(modelFile.c_str())) {
45 Info("TMVA_SOFIE_RDataFrame","You need to run TMVA_Higgs_Classification.C to generate the Keras trained model");
46 return;
47 }
48
49 // parse the input Keras model into RModel object
50 SOFIE::RModel model = SOFIE::PyKeras::Parse(modelFile);
51
52 std::string modelName = modelFile.substr(0,modelFile.find(".h5"));
53 std::string modelHeaderFile = modelName + std::string(".hxx");
54 //Generating inference code
55 model.Generate();
57 model.PrintGenerated();
58 // check that also weigh file exists
59 std::string modelWeightFile = modelName + std::string(".dat");
61 Error("TMVA_SOFIE_RDataFrame","Generated weight file is missing");
62 return;
63 }
64
65 // now compile using ROOT JIT trained model (see function above)
67
68 std::string inputFileName = "Higgs_data.root";
69 std::string inputFile = std::string{gROOT->GetTutorialDir()} + "/machine_learning/data/" + inputFileName;
70
71 ROOT::RDataFrame df1("sig_tree", inputFile);
72 auto h1 = df1.Define("DNN_Value", "sofie_functor(rdfslot_,m_jj, m_jjj, m_lv, m_jlv, m_bb, m_wbb, m_wwbb)")
73 .Histo1D({"h_sig", "", 100, 0, 1},"DNN_Value");
74
75 ROOT::RDataFrame df2("bkg_tree", inputFile);
77 .Histo1D({"h_bkg", "", 100, 0, 1},"DNN_Value");
78
80 h2->SetLineColor(kBlue);
81
82 auto c1 = new TCanvas();
84
85 h2->DrawClone();
86 h1->DrawClone("SAME");
87 c1->BuildLegend();
88
89
90}
@ kRed
Definition Rtypes.h:67
@ kBlue
Definition Rtypes.h:67
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition TError.cxx:241
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:208
#define gInterpreter
#define gROOT
Definition TROOT.h:411
R__EXTERN TStyle * gStyle
Definition TStyle.h:442
R__EXTERN TSystem * gSystem
Definition TSystem.h:572
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition TAttLine.h:42
The Canvas class.
Definition TCanvas.h:23
void OutputGenerated(std::string filename="", bool append=false)
Definition RModel.cxx:1430
void Generate(std::underlying_type_t< Options > options, int batchSize=-1, long pos=0, bool verbose=false)
Definition RModel.cxx:1062
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad with: gROOT->SetSelectedPad(c1).
Definition TObject.cxx:318
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
Definition TStyle.cxx:1641
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1307
return c1
Definition legend1.C:41
TH1F * h1
Definition legend1.C:5