Logo ROOT  
Reference Guide
Loading...
Searching...
No Matches
TMVACrossValidationRegression.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_ml
3/// \notebook -nodraw
/// This macro provides an example of how to use TMVA for k-fold cross
/// evaluation of a regression.
6///
/// The input data is a toy-MC sample consisting of two Gaussian
/// distributions.
9///
10/// The output file "TMVARegCv.root" can be analysed with the use of dedicated
11/// macros (simply say: root -l <macro.C>), which can be conveniently
12/// invoked through a GUI that will appear at the end of the run of this macro.
13/// You can also launch the GUI in another ROOT session via the command:
14///
15/// ```
16/// root -l -e 'TMVA::TMVAGui("TMVARegCv.root")'
17/// ```
18///
19/// ## Cross Evaluation
/// Cross evaluation is a special case of k-fold cross validation where the
/// splitting into k folds is computed deterministically. This ensures that
/// a given event will always end up in the same fold.
23///
24/// In addition all resulting classifiers are saved and can be applied to new
25/// data using `MethodCrossValidation`. One requirement for this to work is a
26/// splitting function that is evaluated for each event to determine into what
27/// fold it goes (for training/evaluation) or to what classifier (for
28/// application).
29///
30/// ## Split Expression
/// Cross evaluation partitions the data into folds using a deterministic
/// rule called the split expression. The expression can be any valid
/// `TFormula` as long as all parts used are defined.
34///
35/// For each event the split expression is evaluated to a number and the event
36/// is put in the fold corresponding to that number.
37///
38/// It is recommended to always use `%int([NumFolds])` at the end of the
39/// expression.
40///
41/// The split expression has access to all spectators and variables defined in
42/// the dataloader. Additionally, the number of folds in the split can be
43/// accessed with `NumFolds` (or `numFolds`).
44///
45/// ### Example
46/// ```
47/// "int(fabs([eventID]))%int([NumFolds])"
48/// ```
49///
50/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51/// - Package : TMVA
52/// - Root Macro: TMVACrossValidationRegression
53///
54/// \macro_code
55/// \macro_output
56/// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TCut.h"
#include "TFile.h"
#include "TObjString.h"
#include "TROOT.h"
#include "TString.h"
#include "TSystem.h"
#include "TTree.h"

#include "TMVA/CrossValidation.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"
#include "TMVA/TMVAGui.h"
#include "TMVA/Tools.h"

77TFile * getDataFile(TString fname) {
78 TFile *input(nullptr);
79
80 if (!gSystem->AccessPathName(fname)) {
81 input = TFile::Open(fname); // check if file in local directory exists
82 }
83
84 if (!input) {
85 std::cout << "ERROR: could not open data file " << fname << std::endl;
86 exit(1);
87 }
88
89 return input;
90}
91
92int TMVACrossValidationRegression()
93{
94 // This loads the library
96
97 // --------------------------------------------------------------------------
98
99 // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
100 TString outfileName("TMVARegCv.root");
101 TFile * outputFile = TFile::Open(outfileName, "RECREATE");
102
103 TString infileName = gROOT->GetTutorialDir() + "/machine_learning/data/tmva_reg_example.root";
104 TFile * inputFile = getDataFile(infileName);
105
106 TMVA::DataLoader *dataloader=new TMVA::DataLoader("datasetcvreg");
107
108 dataloader->AddVariable("var1", "Variable 1", "units", 'F');
109 dataloader->AddVariable("var2", "Variable 2", "units", 'F');
110
111 // Add the variable carrying the regression target
112 dataloader->AddTarget("fvalue");
113
114 TTree * regTree = (TTree*)inputFile->Get("TreeR");
115 dataloader->AddRegressionTree(regTree, 1.0);
116
117 // Individual events can be weighted
118 // dataloader->SetWeightExpression("weight", "Regression");
119
120 std::cout << "--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl;
121
122 // Bypasses the normal splitting mechanism, CV uses a new system for this.
123 // Unfortunately the old system is unhappy if we leave the test set empty so
124 // we ensure that there is at least one event by placing the first event in
125 // it.
126 // You can with the selection cut place a global cut on the defined
127 // variables. Only events passing the cut will be using in training/testing.
128 // Example: `TCut selectionCut = "var1 < 1";`
129 TCut selectionCut = "";
130 dataloader->PrepareTrainingAndTestTree(selectionCut, "nTest_Regression=1"
131 ":SplitMode=Block"
132 ":NormMode=NumEvents"
133 ":!V");
134
135 // --------------------------------------------------------------------------
136
137 //
138 // This sets up a CrossValidation class (which wraps a TMVA::Factory
139 // internally) for 2-fold cross validation. The data will be split into the
140 // two folds randomly if `splitExpr` is `""`.
141 //
142 // One can also give a deterministic split using spectator variables. An
143 // example would be e.g. `"int(fabs([spec1]))%int([NumFolds])"`.
144 //
145 UInt_t numFolds = 2;
146 TString analysisType = "Regression";
147 TString splitExpr = "";
148
149 TString cvOptions = Form("!V"
150 ":!Silent"
151 ":ModelPersistence"
152 ":!FoldFileOutput"
153 ":AnalysisType=%s"
154 ":NumFolds=%i"
155 ":SplitExpr=%s",
156 analysisType.Data(), numFolds, splitExpr.Data());
157
158 TMVA::CrossValidation cv{"TMVACrossValidationRegression", dataloader, outputFile, cvOptions};
159
160 // --------------------------------------------------------------------------
161
162 //
163 // Books a method to use for evaluation
164 //
165 cv.BookMethod(TMVA::Types::kBDT, "BDTG",
166 "!H:!V:NTrees=500:BoostType=Grad:Shrinkage=0.1:"
167 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3");
168
169 // --------------------------------------------------------------------------
170
171 //
172 // Train, test and evaluate the booked methods.
173 // Evaluates the booked methods once for each fold and aggregates the result
174 // in the specified output file.
175 //
176 cv.Evaluate();
177
178 // --------------------------------------------------------------------------
179
180 //
181 // Save the output
182 //
183 outputFile->Close();
184
185 std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
186 std::cout << "==> TMVACrossValidationRegression is done!" << std::endl;
187
188 // --------------------------------------------------------------------------
189
190 //
191 // Launch the GUI for the root macros
192 //
193 if (!gROOT->IsBatch()) {
194 TMVA::TMVAGui(outfileName);
195 }
196
197 return 0;
198}
199
200//
201// This is used if the macro is compiled. If run through ROOT with
202// `root -b -q MACRO.C` or similar it is unused.
203//
204int main(int argc, char **argv)
205{
206 TMVACrossValidationRegression();
207}
int main()
Definition Prototype.cxx:12
unsigned int UInt_t
Unsigned integer 4 bytes (unsigned int).
Definition RtypesCore.h:60
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
#define gROOT
Definition TROOT.h:417
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2496
externTSystem * gSystem
Definition TSystem.h:582
A specialized string object used for TTree selections.
Definition TCut.h:25
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
Definition TFile.h:130
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
void Close(Option_t *option="") override
Delete all objects from memory and directory structure itself.
Class to perform cross validation, splitting the dataloader into folds.
void Evaluate() override
Does training, test set evaluation and performance evaluation of using cross-evalution.
void AddRegressionTree(TTree *tree, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition DataLoader.h:103
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Method to book the machine learning method to perform the algorithm.
Definition Envelope.cxx:163
static Tools & Instance()
Definition Tools.cxx:72
const char * GetName() const override
Returns name of object.
Definition TNamed.h:49
Basic string class.
Definition TString.h:138
const char * Data() const
Definition TString.h:384
A TTree represents a columnar dataset.
Definition TTree.h:89
void TMVAGui(const char *fName="TMVA.root", TString dataset="")