Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVACrossValidationRegression.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides an example of how to use TMVA for k-folds cross
5/// evaluation.
6///
7/// The input data is a toy-MC sample consisting of two Gaussian
8/// distributions.
9///
10/// The output file "TMVARegCv.root" can be analysed with the use of dedicated
11/// macros (simply say: root -l <macro.C>), which can be conveniently
12/// invoked through a GUI that will appear at the end of the run of this macro.
13/// Launch the GUI via the command:
14///
15/// ```
16/// root -l -e 'TMVA::TMVAGui("TMVARegCv.root")'
17/// ```
18///
19/// ## Cross Evaluation
20/// Cross evaluation is a special case of k-folds cross validation where the
21/// splitting into k folds is computed deterministically. This ensures that
22/// a given event will always end up in the same fold.
23///
24/// In addition all resulting classifiers are saved and can be applied to new
25/// data using `MethodCrossValidation`. One requirement for this to work is a
26/// splitting function that is evaluated for each event to determine into what
27/// fold it goes (for training/evaluation) or to what classifier (for
28/// application).
29///
30/// ## Split Expression
31/// Cross evaluation uses a deterministic split to partition the data into
32/// folds called the split expression. The expression can be any valid
33/// `TFormula` as long as all parts used are defined.
34///
35/// For each event the split expression is evaluated to a number and the event
36/// is put in the fold corresponding to that number.
37///
38/// It is recommended to always use `%int([NumFolds])` at the end of the
39/// expression.
40///
41/// The split expression has access to all spectators and variables defined in
42/// the dataloader. Additionally, the number of folds in the split can be
43/// accessed with `NumFolds` (or `numFolds`).
44///
45/// ### Example
46/// ```
47/// "int(fabs([eventID]))%int([NumFolds])"
48/// ```
49///
50/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51/// - Package : TMVA
52/// - Root Macro: TMVACrossValidationRegression
53///
54/// \macro_output
55/// \macro_code
56/// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
#include "TMVA/CrossValidation.h"
76
77TFile * getDataFile(TString fname) {
78 TFile *input(0);
79
80 if (!gSystem->AccessPathName(fname)) {
81 input = TFile::Open(fname); // check if file in local directory exists
82 } else {
83 // if not: download from ROOT server
85 input = TFile::Open("http://root.cern/files/tmva_reg_example.root", "CACHEREAD");
86 }
87
88 if (!input) {
89 std::cout << "ERROR: could not open data file " << fname << std::endl;
90 exit(1);
91 }
92
93 return input;
94}
95
96int TMVACrossValidationRegression()
97{
98 // This loads the library
100
101 // --------------------------------------------------------------------------
102
103 // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
104 TString outfileName("TMVARegCv.root");
105 TFile * outputFile = TFile::Open(outfileName, "RECREATE");
106
107 TString infileName("./files/tmva_reg_example.root");
108 TFile * inputFile = getDataFile(infileName);
109
110 TMVA::DataLoader *dataloader=new TMVA::DataLoader("datasetcvreg");
111
112 dataloader->AddVariable("var1", "Variable 1", "units", 'F');
113 dataloader->AddVariable("var2", "Variable 2", "units", 'F');
114
115 // Add the variable carrying the regression target
116 dataloader->AddTarget("fvalue");
117
118 TTree * regTree = (TTree*)inputFile->Get("TreeR");
119 dataloader->AddRegressionTree(regTree, 1.0);
120
121 // Individual events can be weighted
122 // dataloader->SetWeightExpression("weight", "Regression");
123
124 std::cout << "--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl;
125
126 // Bypasses the normal splitting mechanism, CV uses a new system for this.
127 // Unfortunately the old system is unhappy if we leave the test set empty so
128 // we ensure that there is at least one event by placing the first event in
129 // it.
130 // You can with the selection cut place a global cut on the defined
131 // variables. Only events passing the cut will be using in training/testing.
132 // Example: `TCut selectionCut = "var1 < 1";`
133 TCut selectionCut = "";
134 dataloader->PrepareTrainingAndTestTree(selectionCut, "nTest_Regression=1"
135 ":SplitMode=Block"
136 ":NormMode=NumEvents"
137 ":!V");
138
139 // --------------------------------------------------------------------------
140
141 //
142 // This sets up a CrossValidation class (which wraps a TMVA::Factory
143 // internally) for 2-fold cross validation. The data will be split into the
144 // two folds randomly if `splitExpr` is `""`.
145 //
146 // One can also give a deterministic split using spectator variables. An
147 // example would be e.g. `"int(fabs([spec1]))%int([NumFolds])"`.
148 //
149 UInt_t numFolds = 2;
150 TString analysisType = "Regression";
151 TString splitExpr = "";
152
153 TString cvOptions = Form("!V"
154 ":!Silent"
155 ":ModelPersistence"
156 ":!FoldFileOutput"
157 ":AnalysisType=%s"
158 ":NumFolds=%i"
159 ":SplitExpr=%s",
160 analysisType.Data(), numFolds, splitExpr.Data());
161
162 TMVA::CrossValidation cv{"TMVACrossValidationRegression", dataloader, outputFile, cvOptions};
163
164 // --------------------------------------------------------------------------
165
166 //
167 // Books a method to use for evaluation
168 //
169 cv.BookMethod(TMVA::Types::kBDT, "BDTG",
170 "!H:!V:NTrees=500:BoostType=Grad:Shrinkage=0.1:"
171 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3");
172
173 // --------------------------------------------------------------------------
174
175 //
176 // Train, test and evaluate the booked methods.
177 // Evaluates the booked methods once for each fold and aggregates the result
178 // in the specified output file.
179 //
180 cv.Evaluate();
181
182 // --------------------------------------------------------------------------
183
184 //
185 // Save the output
186 //
187 outputFile->Close();
188
189 std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
190 std::cout << "==> TMVACrossValidationRegression is done!" << std::endl;
191
192 // --------------------------------------------------------------------------
193
194 //
195 // Launch the GUI for the root macros
196 //
197 if (!gROOT->IsBatch()) {
198 TMVA::TMVAGui(outfileName);
199 }
200
201 return 0;
202}
203
204//
205// This is used if the macro is compiled. If run through ROOT with
206// `root -l -b -q MACRO.C` or similar it is unused.
207//
208int main(int argc, char **argv)
209{
210 TMVACrossValidationRegression();
211}
int main()
Definition Prototype.cxx:12
unsigned int UInt_t
Definition RtypesCore.h:46
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
#define gROOT
Definition TROOT.h:406
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
R__EXTERN TSystem * gSystem
Definition TSystem.h:555
A specialized string object used for TTree selections.
Definition TCut.h:25
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
A ROOT file is composed of a header, followed by consecutive data records (TKey instances) with a wel...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4067
static Bool_t SetCacheFileDir(std::string_view cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Sets the directory where to locally stage/cache remote files.
Definition TFile.cxx:4603
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:928
Class to perform cross validation, splitting the dataloader into folds.
void AddRegressionTree(TTree *tree, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition DataLoader.h:103
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
static Tools & Instance()
Definition Tools.cxx:71
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:376
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1296
A TTree represents a columnar dataset.
Definition TTree.h:79
void TMVAGui(const char *fName="TMVA.root", TString dataset="")