Logo ROOT  
Reference Guide
TMVACrossValidationApplication.C File Reference

Detailed Description

View in nbviewer · Open in SWAN. This macro provides an example of how to use TMVA to apply classifiers trained with k-fold cross validation to data (the application phase).

This requires that CrossValidation was run with a deterministic split, such as "...:splitExpr=int([eventID])%int([numFolds]):...".

  • Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
  • Package : TMVA
  • Root Macro: TMVACrossValidationApplication
: Booking "BDTG" of type "CrossValidation" from dataset/weights/TMVACrossValidation_BDTG.weights.xml.
: Reading weight file: dataset/weights/TMVACrossValidation_BDTG.weights.xml
<HEADER> DataSetInfo : [Default] : Added class "Signal"
<HEADER> DataSetInfo : [Default] : Added class "Background"
: Reading weightfile: dataset/weights/TMVACrossValidation_BDTG_fold1.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidation_BDTG_fold1.weights.xml
: Reading weightfile: dataset/weights/TMVACrossValidation_BDTG_fold2.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidation_BDTG_fold2.weights.xml
: Booked classifier "BDTG" of type: "CrossValidation"
: Booking "Fisher" of type "CrossValidation" from dataset/weights/TMVACrossValidation_Fisher.weights.xml.
: Reading weight file: dataset/weights/TMVACrossValidation_Fisher.weights.xml
: Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
: Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml
: Booked classifier "Fisher" of type: "CrossValidation"
(int) 0
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TCanvas.h"
#include "TChain.h"
#include "TFile.h"
#include "TH1F.h"
#include "TObjString.h"
#include "TRandom3.h"
#include "TROOT.h"
#include "TString.h"
#include "TSystem.h"
#include "TTree.h"

#include "TMVA/Factory.h"
#include "TMVA/Reader.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
// Helper that appends `nPoints` synthetic events to `tree`.
//
// The tree must already own the branches "x", "y" and "eventID". Both x and y
// are drawn from a Gaussian with mean `offset` and width `scale`, using a
// TRandom3 generator seeded with `seed`. Within one call eventID takes the
// values 1..nPoints. The same tree pointer is handed back to the caller.
TTree *fillTree(TTree * tree, Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
{
   TRandom3 generator(seed);

   // Local buffers the branches read from while this function is filling.
   Float_t xValue = 0;
   Float_t yValue = 0;
   Int_t id = 0;

   tree->SetBranchAddress("x", &xValue);
   tree->SetBranchAddress("y", &yValue);
   tree->SetBranchAddress("eventID", &id);

   for (Int_t i = 0; i < nPoints; ++i) {
      // Draw x first, then y, so the random sequence is consumed in a fixed order.
      xValue = generator.Gaus(offset, scale);
      yValue = generator.Gaus(offset, scale);

      // For this simple example it is enough that the ids are uniformly
      // distributed and independent of the data.
      id = i + 1;

      tree->Fill();
   }

   // Important: detach the tree from the local buffers before they go out of scope.
   tree->ResetBranchAddresses();

   return tree;
}
int TMVACrossValidationApplication()
{
// This loads the library
// Set up the TMVA::Reader
TMVA::Reader *reader = new TMVA::Reader("!Color:!Silent:!V");
Int_t eventID;
reader->AddVariable("x", &x);
reader->AddVariable("y", &y);
reader->AddSpectator("eventID", &eventID);
// Book the serialised methods
TString jobname("TMVACrossValidation");
{
TString methodName = "BDTG";
TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
if (weightfileExists) {
reader->BookMVA(methodName, weightfile);
} else {
std::cout << "Weightfile for method " << methodName << " not found."
" Did you run TMVACrossValidation with a specified"
" splitExpr?" << std::endl;
exit(0);
}
}
{
TString methodName = "Fisher";
TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
if (weightfileExists) {
reader->BookMVA(methodName, weightfile);
} else {
std::cout << "Weightfile for method " << methodName << " not found."
" Did you run TMVACrossValidation with a specified"
" splitExpr?" << std::endl;
exit(0);
}
}
// Load data
TTree *tree = new TTree();
tree->Branch("x", &x, "x/F");
tree->Branch("y", &y, "y/F");
tree->Branch("eventID", &eventID, "eventID/I");
fillTree(tree, 1000, 1.0, 1.0, 100);
fillTree(tree, 1000, -1.0, 1.0, 101);
tree->SetBranchAddress("x", &x);
tree->SetBranchAddress("y", &y);
tree->SetBranchAddress("eventID", &eventID);
// Prepare histograms
Int_t nbin = 100;
TH1F histBDTG{"BDTG", "BDTG", nbin, -1, 1};
TH1F histFisher{"Fisher", "Fisher", nbin, -1, 1};
// Evaluate classifiers
for (Long64_t ievt = 0; ievt < tree->GetEntries(); ievt++) {
tree->GetEntry(ievt);
Double_t valBDTG = reader->EvaluateMVA("BDTG");
Double_t valFisher = reader->EvaluateMVA("Fisher");
histBDTG.Fill(valBDTG);
histFisher.Fill(valFisher);
}
tree->ResetBranchAddresses();
delete tree;
if (!gROOT->IsBatch()) {
auto c = new TCanvas();
c->Divide(2,1);
c->cd(1);
histBDTG.DrawClone();
c->cd(2);
histFisher.DrawClone();
}
else
{ // Write histograms to output file
TFile *target = new TFile("TMVACrossEvaluationApp.root", "RECREATE");
histBDTG.Write();
histFisher.Write();
target->Close();
delete target;
}
delete reader;
return 0;
}
//
// This is used if the macro is compiled. If run through ROOT with
// `root -l -b -q MACRO.C` or similar it is unused.
//
// The command-line arguments are not used. The status returned by the macro
// is forwarded as the process exit code.
//
int main(int /*argc*/, char ** /*argv*/)
{
   return TMVACrossValidationApplication();
}
Author
Kim Albertsson (adapted from code originally by Andreas Hoecker)

Definition in file TMVACrossValidationApplication.C.

c
#define c(i)
Definition: RSha256.hxx:119
n
const Int_t n
Definition: legend1.C:16
TRandom::Gaus
virtual Double_t Gaus(Double_t mean=0, Double_t sigma=1)
Samples a random number from a Normal (Gaussian) distribution with the given mean and sigma.
Definition: TRandom.cxx:263
TMVA::Reader::AddVariable
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:303
tree
Definition: tree.py:1
TObjString.h
Long64_t
long long Long64_t
Definition: RtypesCore.h:73
TTree
Definition: TTree.h:79
DataLoader.h
Float_t
float Float_t
Definition: RtypesCore.h:57
Int_t
int Int_t
Definition: RtypesCore.h:45
x
Double_t x[n]
Definition: legend1.C:17
TMVAGui.h
TTree.h
TString
Definition: TString.h:136
TSystem::AccessPathName
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1294
Bool_t
bool Bool_t
Definition: RtypesCore.h:63
TString.h
TFile.h
TROOT.h
TChain.h
TSystem.h
TRandom3
Definition: TRandom3.h:27
TMVA::Reader::BookMVA
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition: Reader.cxx:368
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
main
int main(int argc, char **argv)
Definition: histspeedtest.cxx:751
UInt_t
unsigned int UInt_t
Definition: RtypesCore.h:46
y
Double_t y[n]
Definition: legend1.C:17
TFile
Definition: TFile.h:54
gSystem
R__EXTERN TSystem * gSystem
Definition: TSystem.h:559
TMVA::Tools::Instance
static Tools & Instance()
Definition: Tools.cxx:75
TMVA::Reader::AddSpectator
void AddSpectator(const TString &expression, Float_t *)
Add a float spectator or expression to the reader.
Definition: Reader.cxx:321
Double_t
double Double_t
Definition: RtypesCore.h:59
TCanvas
Definition: TCanvas.h:23
TMVA::Reader::EvaluateMVA
Double_t EvaluateMVA(const std::vector< Float_t > &, const TString &methodTag, Double_t aux=0)
Evaluate a std::vector<float> of input data for a given method. The parameter aux is obligatory for the cuts method, where it represents the efficiency cutoff.
Definition: Reader.cxx:468
TH1F
1-D histogram with a float per channel (see TH1 documentation).
Definition: TH1.h:572
TFile::Close
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:876
Factory.h
Tools.h
TMVA::Reader
Definition: Reader.h:92
gROOT
#define gROOT
Definition: TROOT.h:406