Logo ROOT  
Reference Guide
TMVACrossValidationApplication.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides an example of how to use TMVA for k-folds cross
5/// evaluation in application.
6///
7/// This requires that CrossValidation was run with a deterministic split, such
8/// as `"...:splitExpr=int([eventID])%int([numFolds]):..."`.
9///
10/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
11/// - Package : TMVA
12/// - Root Macro: TMVACrossValidationApplication
13///
14/// \macro_output
15/// \macro_code
16/// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
17
18#include <cstdlib>
19#include <iostream>
20#include <map>
21#include <string>
22
23#include "TChain.h"
24#include "TFile.h"
25#include "TTree.h"
26#include "TString.h"
27#include "TObjString.h"
28#include "TSystem.h"
29#include "TROOT.h"
30
31#include "TMVA/Factory.h"
32#include "TMVA/DataLoader.h"
33#include "TMVA/Tools.h"
34#include "TMVA/TMVAGui.h"
35
36// Helper function to load data into TTrees.
37TTree *fillTree(TTree * tree, Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
38{
39 TRandom3 rng(seed);
40 Float_t x = 0;
41 Float_t y = 0;
42 Int_t eventID = 0;
43
44 tree->SetBranchAddress("x", &x);
45 tree->SetBranchAddress("y", &y);
46 tree->SetBranchAddress("eventID", &eventID);
47
48 for (Int_t n = 0; n < nPoints; ++n) {
49 x = rng.Gaus(offset, scale);
50 y = rng.Gaus(offset, scale);
51
52 // For our simple example it is enough that the id's are uniformly
53 // distributed and independent of the data.
54 ++eventID;
55
56 tree->Fill();
57 }
58
59 // Important: Disconnects the tree from the memory locations of x and y.
60 tree->ResetBranchAddresses();
61 return tree;
62}
63
64int TMVACrossValidationApplication()
65{
66 // This loads the library
68
69 // Set up the TMVA::Reader
70 TMVA::Reader *reader = new TMVA::Reader("!Color:!Silent:!V");
71
72 Float_t x;
73 Float_t y;
74 Int_t eventID;
75
76 reader->AddVariable("x", &x);
77 reader->AddVariable("y", &y);
78 reader->AddSpectator("eventID", &eventID);
79
80 // Book the serialised methods
81 TString jobname("TMVACrossEvaluation");
82 {
83 TString methodName = "BDTG";
84 TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
85
86 Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
87 if (weightfileExists) {
88 reader->BookMVA(methodName, weightfile);
89 } else {
90 std::cout << "Weightfile for method " << methodName << " not found."
91 " Did you run TMVACrossValidation with a specified"
92 " splitExpr?" << std::endl;
93 exit(0);
94 }
95
96 }
97 {
98 TString methodName = "Fisher";
99 TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
100
101 Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
102 if (weightfileExists) {
103 reader->BookMVA(methodName, weightfile);
104 } else {
105 std::cout << "Weightfile for method " << methodName << " not found."
106 " Did you run TMVACrossValidation with a specified"
107 " splitExpr?" << std::endl;
108 exit(0);
109 }
110 }
111
112 // Load data
113 TTree *tree = new TTree();
114 tree->Branch("x", &x, "x/F");
115 tree->Branch("y", &y, "y/F");
116 tree->Branch("eventID", &eventID, "eventID/I");
117
118 fillTree(tree, 1000, 1.0, 1.0, 100);
119 fillTree(tree, 1000, -1.0, 1.0, 101);
120 tree->SetBranchAddress("x", &x);
121 tree->SetBranchAddress("y", &y);
122 tree->SetBranchAddress("eventID", &eventID);
123
124 // Prepare histograms
125 Int_t nbin = 100;
126 TH1F histBDTG{"BDTG", "BDTG", nbin, -1, 1};
127 TH1F histFisher{"Fisher", "Fisher", nbin, -1, 1};
128
129 // Evaluate classifiers
130 for (Long64_t ievt = 0; ievt < tree->GetEntries(); ievt++) {
131 tree->GetEntry(ievt);
132
133 Double_t valBDTG = reader->EvaluateMVA("BDTG");
134 Double_t valFisher = reader->EvaluateMVA("Fisher");
135
136 histBDTG.Fill(valBDTG);
137 histFisher.Fill(valFisher);
138 }
139
140 tree->ResetBranchAddresses();
141 delete tree;
142
143 { // Write histograms to output file
144 TFile *target = new TFile("TMVACrossEvaluationApp.root", "RECREATE");
145 histBDTG.Write();
146 histFisher.Write();
147 target->Close();
148 delete target;
149 }
150
151 delete reader;
152
153 return 0;
154}
155
156//
157// This is used if the macro is compiled. If run through ROOT with
158// `root -l -b -q MACRO.C` or similar it is unused.
159//
160int main(int argc, char **argv)
161{
162 TMVACrossValidationApplication();
163}
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
long long Long64_t
Definition: RtypesCore.h:69
float Float_t
Definition: RtypesCore.h:53
R__EXTERN TSystem * gSystem
Definition: TSystem.h:560
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:856
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:571
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
Double_t EvaluateMVA(const std::vector< Float_t > &, const TString &methodTag, Double_t aux=0)
Evaluate a std::vector<float> of input data for a given method. The parameter aux is obligatory for the cuts method, where it represents the efficiency cutoff.
Definition: Reader.cxx:473
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition: Reader.cxx:373
void AddSpectator(const TString &expression, Float_t *)
Add a float spectator or expression to the reader.
Definition: Reader.cxx:326
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:308
static Tools & Instance()
Definition: Tools.cxx:75
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister generator.
Definition: TRandom3.h:27
Basic string class.
Definition: TString.h:131
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1287
A TTree represents a columnar dataset.
Definition: TTree.h:72
int main(int argc, char **argv)
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
Definition: tree.py:1