Logo ROOT   6.12/07
Reference Guide
TMVACrossValidation.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// This example explains how to use the cross-validation feature of TMVA. It
5 /// validates the Fisher algorithm with a 5-fold cross-validation.
6 /// - Project : TMVA - a Root-integrated toolkit for multivariate data analysis
7 /// - Package : TMVA
8 /// - Executable: TMVACrossValidation
9 ///
10 /// \macro_output
11 /// \macro_code
12 /// \author Stefan Wunsch
13 
14 #include "TFile.h"
15 #include "TTree.h"
16 #include "TString.h"
17 #include "TSystem.h"
18 
19 #include "TMVA/DataLoader.h"
20 #include "TMVA/CrossValidation.h"
21 #include "TMVA/Tools.h"
22 
23 void TMVACrossValidation()
24 {
25  // This loads the library
27 
28  // Load data
29  TFile *input(0);
30  TString fname = "./tmva_class_example.root";
31  if (!gSystem->AccessPathName( fname )) {
32  input = TFile::Open( fname ); // check if file in local directory exists
33  }
34  else {
36  input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
37  }
38  if (!input) {
39  std::cout << "ERROR: could not open data file" << std::endl;
40  exit(1);
41  }
42 
43  TTree* signalTree = (TTree*)input->Get("TreeS");
44  TTree* background = (TTree*)input->Get("TreeB");
45 
46  // Setup dataloader
48 
49  dataloader->AddSignalTree(signalTree);
50  dataloader->AddBackgroundTree(background);
51 
52  dataloader->AddVariable("var1");
53  dataloader->AddVariable("var2");
54  dataloader->AddVariable("var3");
55  dataloader->AddVariable("var4");
56 
57  dataloader->PrepareTrainingAndTestTree("", "SplitMode=Random:NormMode=NumEvents:!V");
58 
59  // Setup cross-validation with Fisher method
60  TMVA::CrossValidation cv(dataloader);
61  cv.BookMethod(TMVA::Types::kFisher, "Fisher", "!H:!V:Fisher");
62 
63  // Run cross-validation and print results
64  cv.Evaluate();
65  auto results = cv.GetResults();
66  for (auto r : results)
67  r.Print();
68 }
69 
70 int main()
71 {
72  TMVACrossValidation();
73 }
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:408
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1276
static Tools & Instance()
Definition: Tools.cxx:75
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:46
static Bool_t SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Definition: TFile.h:303
Basic string class.
Definition: TString.h:125
virtual void Print(Option_t *option="") const
This method must be overridden when a class wants to print itself.
Definition: TObject.cxx:550
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:491
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3950
int main(int argc, char **argv)
ROOT::R::TRInterface & r
Definition: Object.C:4
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:629
Class to perform cross validation, splitting the dataloader into folds.
A TTree object has a header with a name and a title.
Definition: TTree.h:70
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:377