Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVACrossValidationRegression.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_ml
3/// \notebook -nodraw
4/// This macro provides an example of how to use TMVA for k-fold cross
5/// evaluation of a regression.
6///
7/// The input data is a toy-MC sample consisting of two Gaussian
8/// distributions.
9///
10/// The output file "TMVARegCv.root" can be analysed with dedicated macros
11/// (run them with: root -l <macro.C>), which can be conveniently invoked
12/// through a GUI that appears at the end of an interactive run of this macro.
13/// Launch the GUI via the command:
14///
15/// ```
16/// root -l -e 'TMVA::TMVAGui("TMVARegCv.root")'
17/// ```
18///
19/// ## Cross Evaluation
20/// Cross evaluation is a special case of k-fold cross validation where the
21/// splitting into k folds is computed deterministically. This ensures that
22/// a given event will always end up in the same fold.
23///
24/// In addition, all resulting models (one per fold) are saved and can be
25/// applied to new data using `MethodCrossValidation`. One requirement for this
26/// to work is a splitting function that is evaluated for each event to
27/// determine which fold it goes into (for training/evaluation) or which model
28/// is applied to it (for application).
29///
30/// ## Split Expression
31/// Cross evaluation partitions the data into folds using a deterministic
32/// split, called the split expression. The expression can be any valid
33/// `TFormula` as long as all parts used are defined.
34///
35/// For each event the split expression is evaluated to a number and the event
36/// is put in the fold corresponding to that number.
37///
38/// It is recommended to always end the expression with `%int([NumFolds])` so
39/// that, for a non-negative left operand, it yields a fold in [0, NumFolds).
40///
41/// The split expression has access to all spectators and variables defined in
42/// the dataloader. Additionally, the number of folds in the split can be
43/// accessed with `NumFolds` (or `numFolds`).
44///
45/// ### Example
46/// ```
47/// "int(fabs([eventID]))%int([NumFolds])"
48/// ```
49///
50/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51/// - Package : TMVA
52/// - Root Macro: TMVACrossValidationRegression
53///
54/// \macro_output
55/// \macro_code
56/// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57
58#include <cstdlib>
59#include <iostream>
60#include <map>
61#include <string>
62
63#include "TChain.h"
64#include "TFile.h"
65#include "TTree.h"
66#include "TString.h"
67#include "TObjString.h"
68#include "TSystem.h"
69#include "TROOT.h"
70
71#include "TMVA/Factory.h"
72#include "TMVA/DataLoader.h"
73#include "TMVA/Tools.h"
74#include "TMVA/TMVAGui.h"
76
// Body of the helper that opens the input ROOT file `fname` and returns the
// resulting TFile handle. If the file could not be opened, the process is
// terminated with exit(1), so a caller never receives a null pointer.
// NOTE(review): this extract is missing original line 77 (the function
// signature, which presumably takes `TString fname` and returns `TFile *`)
// and original line 80 (the condition whose `{` is closed on line 82; the
// doxygen cross-references list gSystem->AccessPathName, presumably used
// here as a file-existence check) - confirm against the upstream macro.
78 TFile *input(nullptr);
79
81 input = TFile::Open(fname); // check if file in local directory exists
82 }
83
84 if (!input) {
85 std::cout << "ERROR: could not open data file " << fname << std::endl;
86 exit(1); // hard stop: the tutorial cannot proceed without its input data
87 }
88
89 return input;
90}
91
// Body of the tutorial's entry-point function. It books a gradient-boosted
// BDT ("BDTG") regression on tree `TreeR` of tmva_reg_example.root. It then
// trains and evaluates that method with 2-fold cross validation and writes
// the results to TMVARegCv.root. In interactive (non-batch) sessions it
// presumably also opens the TMVA GUI.
// NOTE(review): the function signature (original line 92) is missing from
// this extract; `return 0;` at the end implies an integer return type.
93{
94 // This loads the library
// NOTE(review): original line 95 (the statement the comment above refers to)
// is missing from this extract. The doxygen cross-references list
// `TMVA::Tools::Instance()`, which is presumably what was called here.
96
97 // --------------------------------------------------------------------------
98
99 // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
// NOTE(review): the pointer returned by TFile::Open is not checked for
// nullptr before it is passed to TMVA::CrossValidation below.
100 TString outfileName("TMVARegCv.root");
101 TFile * outputFile = TFile::Open(outfileName, "RECREATE");
102
103 TString infileName = gROOT->GetTutorialDir() + "/machine_learning/data/tmva_reg_example.root";
// NOTE(review): original line 104 is missing from this extract. It must
// declare `inputFile` (used below), presumably by opening `infileName` with
// the file-opening helper defined earlier in this macro - confirm.
105
// NOTE(review): `dataloader` is heap-allocated and no matching delete is
// visible in this function; confirm whether TMVA::CrossValidation takes
// ownership of it.
106 TMVA::DataLoader *dataloader=new TMVA::DataLoader("datasetcvreg");
107
// Input features: two float ('F') variables.
108 dataloader->AddVariable("var1", "Variable 1", "units", 'F');
109 dataloader->AddVariable("var2", "Variable 2", "units", 'F');
110
111 // Add the variable carrying the regression target
112 dataloader->AddTarget("fvalue");
113
// NOTE(review): `regTree` is not checked for nullptr (e.g. if "TreeR" is
// absent from the input file) before it is passed to AddRegressionTree.
// The 1.0 argument is presumably a global weight for the tree's events.
114 TTree * regTree = (TTree*)inputFile->Get("TreeR");
115 dataloader->AddRegressionTree(regTree, 1.0);
116
117 // Individual events can be weighted
118 // dataloader->SetWeightExpression("weight", "Regression");
119
120 std::cout << "--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl;
121
122 // Bypass the normal train/test splitting: cross validation (CV) uses its
123 // own fold-based mechanism. The classic mechanism rejects an empty test set,
124 // so `nTest_Regression=1` below makes sure it contains a single event.
125 //
126 // The selection cut can be used to place a global cut on the defined
127 // variables. Only events passing the cut will be used in training/testing.
128 // Example: `TCut selectionCut = "var1 < 1";`
129 TCut selectionCut = "";
130 dataloader->PrepareTrainingAndTestTree(selectionCut, "nTest_Regression=1"
131 ":SplitMode=Block"
132 ":NormMode=NumEvents"
133 ":!V");
134
135 // --------------------------------------------------------------------------
136
137 //
138 // This sets up a CrossValidation object (which wraps a TMVA::Factory
139 // internally) for 2-fold cross validation. The data will be split into the
140 // two folds randomly if `splitExpr` is `""`.
141 //
142 // One can also give a deterministic split using spectator variables. An
143 // example would be e.g. `"int(fabs([spec1]))%int([NumFolds])"`.
144 //
// NOTE(review): this dataloader declares no spectators (only var1, var2 and
// the target fvalue), so `spec1` would first have to be added to it.
145 UInt_t numFolds = 2;
146 TString analysisType = "Regression";
147 TString splitExpr = "";
148
// NOTE(review): `numFolds` is a UInt_t (unsigned int), for which `%u` is the
// matching printf-style conversion; the format below uses `%i`.
149 TString cvOptions = Form("!V"
150 ":!Silent"
151 ":ModelPersistence"
152 ":!FoldFileOutput"
153 ":AnalysisType=%s"
154 ":NumFolds=%i"
155 ":SplitExpr=%s",
156 analysisType.Data(), numFolds, splitExpr.Data());
157
158 TMVA::CrossValidation cv{"TMVACrossValidationRegression", dataloader, outputFile, cvOptions};
159
160 // --------------------------------------------------------------------------
161
162 //
163 // Book the method to evaluate: a gradient-boosted decision tree ("BDTG")
164 //
165 cv.BookMethod(TMVA::Types::kBDT, "BDTG",
166 "!H:!V:NTrees=500:BoostType=Grad:Shrinkage=0.1:"
167 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3");
168
169 // --------------------------------------------------------------------------
170
171 //
172 // Train, test and evaluate the booked methods.
173 // Evaluates the booked methods once for each fold and aggregates the result
174 // in the specified output file.
175 //
176 cv.Evaluate();
177
178 // --------------------------------------------------------------------------
179
180 //
181 // Save the output
182 //
// NOTE(review): only the output file is closed here; no Close() of
// `inputFile` is visible in this function.
183 outputFile->Close();
184
185 std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
186 std::cout << "==> TMVACrossValidationRegression is done!" << std::endl;
187
188 // --------------------------------------------------------------------------
189
190 //
191 // Launch the GUI for the root macros
192 //
193 if (!gROOT->IsBatch()) {
// NOTE(review): original line 194 (the body of this branch) is missing from
// this extract. The doxygen cross-references list TMVAGui(const char *fName,
// ...), presumably called here with `outfileName` - confirm.
195 }
196
197 return 0;
198}
199
200//
201// Entry point used only when the macro is compiled as a standalone program.
202// When the macro is run through ROOT (`root -l -b -q MACRO.C` or similar),
// this function is not used.
203//
204int main(int argc, char **argv)
205{
// NOTE(review): original line 206 (the whole body) is missing from this
// extract; presumably it calls the tutorial's entry-point function defined
// above. `argc` and `argv` are not used in any visible line.
207}
int main()
Definition Prototype.cxx:12
unsigned int UInt_t
Definition RtypesCore.h:46
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
#define gROOT
Definition TROOT.h:406
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
R__EXTERN TSystem * gSystem
Definition TSystem.h:572
A specialized string object used for TTree selections.
Definition TCut.h:25
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:131
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4130
Class to perform cross validation, splitting the dataloader into folds.
static Tools & Instance()
Definition Tools.cxx:71
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:376
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1296
A TTree represents a columnar dataset.
Definition TTree.h:79
void TMVAGui(const char *fName="TMVA.root", TString dataset="")