Loading [MathJax]/extensions/tex2jax.js
Logo ROOT  
Reference Guide
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
TMVACrossValidationApplication.C File Reference

Detailed Description

View in nbviewer | Open in SWAN

This macro provides an example of how to apply classifiers trained with TMVA k-fold cross-validation to data (application mode).

This requires that CrossValidation was run with a deterministic split, such as "...:splitExpr=int([eventID])%int([numFolds]):...".

  • Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
  • Package : TMVA
  • Root Macro: TMVACrossValidationApplication
: Booking "BDTG" of type "CrossValidation" from dataset/weights/TMVACrossValidation_BDTG.weights.xml.
: Reading weight file: dataset/weights/TMVACrossValidation_BDTG.weights.xml
<HEADER> DataSetInfo : [Default] : Added class "Signal"
<HEADER> DataSetInfo : [Default] : Added class "Background"
: Reading weightfile: dataset/weights/TMVACrossValidation_BDTG_fold1.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidation_BDTG_fold1.weights.xml
: Reading weightfile: dataset/weights/TMVACrossValidation_BDTG_fold2.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidation_BDTG_fold2.weights.xml
: Booked classifier "BDTG" of type: "CrossValidation"
: Booking "Fisher" of type "CrossValidation" from dataset/weights/TMVACrossValidation_Fisher.weights.xml.
: Reading weight file: dataset/weights/TMVACrossValidation_Fisher.weights.xml
: Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
: Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml
: Booked classifier "Fisher" of type: "CrossValidation"
(int) 0
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>
#include "TCanvas.h"
#include "TChain.h"
#include "TFile.h"
#include "TH1F.h"
#include "TObjString.h"
#include "TRandom3.h"
#include "TROOT.h"
#include "TString.h"
#include "TSystem.h"
#include "TTree.h"
#include "TMVA/Factory.h"
#include "TMVA/Reader.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
// Helper that appends synthetic events to an existing tree.
//
// The tree must already own branches named "x", "y" and "eventID". For each of
// the nPoints events, x and y are drawn independently from a Gaussian with mean
// `offset` and width `scale`, using a TRandom3 seeded with `seed`. The event id
// starts at 1 and increases by one per event on every call.
//
// Returns the same tree pointer that was passed in.
TTree *fillTree(TTree * tree, Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
{
   // Local buffers the branches read from while this function is filling.
   Float_t xValue = 0;
   Float_t yValue = 0;
   Int_t id = 0;

   TRandom3 generator(seed);

   tree->SetBranchAddress("x", &xValue);
   tree->SetBranchAddress("y", &yValue);
   tree->SetBranchAddress("eventID", &id);

   for (Int_t remaining = nPoints; remaining > 0; --remaining) {
      xValue = generator.Gaus(offset, scale);
      yValue = generator.Gaus(offset, scale);

      // A simple running counter is sufficient for this example: the ids only
      // need to be spread evenly and to be independent of the generated data.
      id += 1;

      tree->Fill();
   }

   // Important: detach the branches from the local buffers above, which go out
   // of scope as soon as this function returns.
   tree->ResetBranchAddresses();

   return tree;
}
int TMVACrossValidationApplication()
{
// This loads the library
// Set up the TMVA::Reader
TMVA::Reader *reader = new TMVA::Reader("!Color:!Silent:!V");
Int_t eventID;
reader->AddVariable("x", &x);
reader->AddVariable("y", &y);
reader->AddSpectator("eventID", &eventID);
// Book the serialised methods
TString jobname("TMVACrossValidation");
{
TString methodName = "BDTG";
TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
if (weightfileExists) {
reader->BookMVA(methodName, weightfile);
} else {
std::cout << "Weightfile for method " << methodName << " not found."
" Did you run TMVACrossValidation with a specified"
" splitExpr?" << std::endl;
exit(0);
}
}
{
TString methodName = "Fisher";
TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
if (weightfileExists) {
reader->BookMVA(methodName, weightfile);
} else {
std::cout << "Weightfile for method " << methodName << " not found."
" Did you run TMVACrossValidation with a specified"
" splitExpr?" << std::endl;
exit(0);
}
}
// Load data
TTree *tree = new TTree();
tree->Branch("x", &x, "x/F");
tree->Branch("y", &y, "y/F");
tree->Branch("eventID", &eventID, "eventID/I");
fillTree(tree, 1000, 1.0, 1.0, 100);
fillTree(tree, 1000, -1.0, 1.0, 101);
tree->SetBranchAddress("x", &x);
tree->SetBranchAddress("y", &y);
tree->SetBranchAddress("eventID", &eventID);
// Prepare histograms
Int_t nbin = 100;
TH1F histBDTG{"BDTG", "BDTG", nbin, -1, 1};
TH1F histFisher{"Fisher", "Fisher", nbin, -1, 1};
// Evaluate classifiers
for (Long64_t ievt = 0; ievt < tree->GetEntries(); ievt++) {
tree->GetEntry(ievt);
Double_t valBDTG = reader->EvaluateMVA("BDTG");
Double_t valFisher = reader->EvaluateMVA("Fisher");
histBDTG.Fill(valBDTG);
histFisher.Fill(valFisher);
}
tree->ResetBranchAddresses();
delete tree;
if (!gROOT->IsBatch()) {
auto c = new TCanvas();
c->Divide(2,1);
c->cd(1);
histBDTG.DrawClone();
c->cd(2);
histFisher.DrawClone();
}
else
{ // Write histograms to output file
TFile *target = new TFile("TMVACrossEvaluationApp.root", "RECREATE");
histBDTG.Write();
histFisher.Write();
target->Close();
delete target;
}
delete reader;
return 0;
}
//
// This is used if the macro is compiled. If run through ROOT with
// `root -l -b -q MACRO.C` or similar it is unused.
//
int main(int argc, char **argv)
{
   // Command-line arguments are not used by this example.
   (void)argc;
   (void)argv;

   // Forward the macro's status so that it becomes the process exit code
   // instead of being silently discarded.
   return TMVACrossValidationApplication();
}
#define c(i)
Definition: RSha256.hxx:101
int Int_t
Definition: RtypesCore.h:43
unsigned int UInt_t
Definition: RtypesCore.h:44
const Bool_t kFALSE
Definition: RtypesCore.h:90
bool Bool_t
Definition: RtypesCore.h:61
double Double_t
Definition: RtypesCore.h:57
long long Long64_t
Definition: RtypesCore.h:71
float Float_t
Definition: RtypesCore.h:55
#define gROOT
Definition: TROOT.h:406
R__EXTERN TSystem * gSystem
Definition: TSystem.h:556
The Canvas class.
Definition: TCanvas.h:27
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:53
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:873
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:571
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
Double_t EvaluateMVA(const std::vector< Float_t > &, const TString &methodTag, Double_t aux=0)
Evaluate a std::vector<float> of input data for a given method The parameter aux is obligatory for th...
Definition: Reader.cxx:473
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition: Reader.cxx:373
void AddSpectator(const TString &expression, Float_t *)
Add a float spectator or expression to the reader.
Definition: Reader.cxx:326
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:308
static Tools & Instance()
Definition: Tools.cxx:74
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister generator.
Definition: TRandom3.h:27
Basic string class.
Definition: TString.h:131
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1291
A TTree represents a columnar dataset.
Definition: TTree.h:78
int main(int argc, char **argv)
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
Definition: tree.py:1
Author
Kim Albertsson (adapted from code originally by Andreas Hoecker)

Definition in file TMVACrossValidationApplication.C.