1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
/// This tutorial shows how to use the modern interfaces to apply models saved in
/// TMVA XML files.
7/// \macro_code
8/// \macro_output
10/// \date July 2019
11/// \author Stefan Wunsch
#include <algorithm>
#include <memory>
#include <stdexcept>

using namespace TMVA::Experimental;
15void train(const std::string &filename)
17 // Create factory
18 auto output = TFile::Open("TMVARR.root", "RECREATE");
19 auto factory = new TMVA::Factory("tmva003",
20 output, "!V:!DrawProgressBar:AnalysisType=Classification");
22 // Open trees with signal and background events
23 auto data = TFile::Open(filename.c_str());
24 auto signal = (TTree *)data->Get("TreeS");
25 auto background = (TTree *)data->Get("TreeB");
27 // Add variables and register the trees with the dataloader
28 auto dataloader = new TMVA::DataLoader("tmva003_BDT");
29 const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
30 for (const auto &var : variables) {
31 dataloader->AddVariable(var);
32 }
33 dataloader->AddSignalTree(signal, 1.0);
34 dataloader->AddBackgroundTree(background, 1.0);
35 dataloader->PrepareTrainingAndTestTree("", "");
37 // Train a TMVA method
38 factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=300:MaxDepth=2");
39 factory->TrainAllMethods();
42void tmva003_RReader()
44 // First, let's train a model with TMVA.
45 const std::string filename = "http://root.cern/files/tmva_class_example.root";
46 train(filename);
48 // Next, we load the model from the TMVA XML file.
49 RReader model("tmva003_BDT/weights/tmva003_BDT.weights.xml");
51 // In case you need a reminder of the names and order of the variables during
52 // training, you can ask the model for it.
53 auto variables = model.GetVariableNames();
55 // The model can now be applied in different scenarios:
56 // 1) Event-by-event inference
57 // 2) Batch inference on data of multiple events
58 // 3) Inference as part of an RDataFrame graph
60 // 1) Event-by-event inference
61 // The event-by-event inference takes the values of the variables as a std::vector<float>.
62 // Note that the return value is as well a std::vector<float> since the reader
63 // is also capable to process models with multiple outputs.
64 auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});
65 std::cout << "Single-event inference: " << prediction[0] << "\n\n";
67 // 2) Batch inference on data of multiple events
68 // For batch inference, the data needs to be structured as a matrix. For this
69 // purpose, TMVA makes use of the RTensor class. For convenience, we use RDataFrame
70 // and the AsTensor utility to make the read-out from the ROOT file.
71 ROOT::RDataFrame df("TreeS", filename);
72 auto df2 = df.Range(3); // Read only a small subset of the dataset
73 auto x = AsTensor<float>(df2, variables);
74 auto y = model.Compute(x);
76 std::cout << "RTensor input for inference on data of multiple events:\n" << x << "\n\n";
77 std::cout << "Prediction performed on multiple events: " << y << "\n\n";
79 // 3) Perform inference as part of an RDataFrame graph
80 // We write a small lambda function that performs for us the inference on
81 // a dataframe to omit code duplication.
82 auto make_histo = [&](const std::string &treename) {
83 ROOT::RDataFrame df(treename, filename);
84 auto df2 = df.Define("y", Compute<4, float>(model), variables);
85 return df2.Histo1D({treename.c_str(), ";BDT score;N_{Events}", 30, -0.5, 0.5}, "y");
86 };
88 auto sig = make_histo("TreeS");
89 auto bkg = make_histo("TreeB");
91 // Make plot
93 auto c = new TCanvas("", "", 800, 800);
95 sig->SetLineColor(kRed);
96 bkg->SetLineColor(kBlue);
97 sig->SetLineWidth(2);
98 bkg->SetLineWidth(2);
99 bkg->Draw("HIST");
100 sig->Draw("HIST SAME");
102 TLegend legend(0.7, 0.7, 0.89, 0.89);
103 legend.SetBorderSize(0);
104 legend.AddEntry("TreeS", "Signal", "l");
105 legend.AddEntry("TreeB", "Background", "l");
106 legend.Draw();
108 c->DrawClone();
