TMVACrossValidationRegression.C File Reference

Detailed Description

This macro provides an example of how to use TMVA for k-fold cross evaluation.

The input data is a toy Monte Carlo sample consisting of two Gaussian distributions.

The output file "TMVARegCv.root" can be analysed with dedicated macros (simply run: root -l <macro.C>). These can be conveniently invoked through a GUI that appears at the end of the run of this macro, unless ROOT is running in batch mode. To launch the GUI manually, use the command:

root -l -e 'TMVA::TMVAGui("TMVARegCv.root")'

Cross Evaluation

Cross evaluation is a special case of k-fold cross validation in which the assignment of events to the k folds is computed deterministically. This ensures that a given event always ends up in the same fold.

In addition, all resulting classifiers are saved and can be applied to new data using MethodCrossValidation. This requires a splitting function that is evaluated for each event to determine which fold the event goes into (for training and evaluation) or which classifier it is passed to (for application).
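The fold and model bookkeeping described above can be sketched in plain C++. This is not TMVA code: the Event struct and the function names are hypothetical, the split mirrors the example expression "int(fabs([eventID]))%int([NumFolds])", and the sketch assumes the standard k-fold scheme in which model k is trained on every fold except fold k.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical event record: only an event ID is needed for this split.
struct Event {
    double eventID;
};

// Deterministic fold assignment, mirroring the split expression
// "int(fabs([eventID]))%int([NumFolds])": the same event always maps to
// the same fold, run after run.
int foldOf(const Event& ev, int numFolds) {
    assert(numFolds > 0);
    return static_cast<int>(std::fabs(ev.eventID)) % numFolds;
}

// Assumption (standard k-fold scheme): model k is trained on every fold
// except fold k, and fold k is used to evaluate it.
bool isTrainingEventFor(const Event& ev, int model, int numFolds) {
    return foldOf(ev, numFolds) != model;
}

// Application: route each event to the one model that never trained on it,
// i.e. the model whose index equals the event's fold.
int modelFor(const Event& ev, int numFolds) {
    return foldOf(ev, numFolds);
}
```

For example, with two folds an event with eventID 7 lands in fold 1: it is used to train model 0 and is later evaluated by model 1, which never saw it.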

Split Expression

Cross evaluation partitions the data into folds using a deterministic split, called the split expression. The expression can be any valid TFormula, as long as all parts used in it are defined.

For each event the split expression is evaluated to a number and the event is put in the fold corresponding to that number.

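It is recommended to always end the expression with %int([NumFolds]), so that the result is a valid fold index between 0 and NumFolds-1 (provided the value before the modulo is non-negative).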

The split expression has access to all spectators and variables defined in the dataloader. Additionally, the number of folds in the split can be accessed with NumFolds (or numFolds).
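Why a split expression should end in %int([NumFolds]) can be seen by evaluating the same arithmetic in plain C++. The functions below are hypothetical stand-ins for what a TFormula such as "int(fabs([eventID]))%int([NumFolds])" would compute; they are not a TMVA API.

```cpp
#include <cassert>
#include <cmath>

// Stand-in for "int(fabs([eventID]))": without a trailing modulo,
// any non-negative integer can come out.
int splitWithoutModulo(double eventID) {
    return static_cast<int>(std::fabs(eventID));
}

// Stand-in for "int(fabs([eventID]))%int([NumFolds])": the trailing modulo
// maps every value back into the range 0 .. numFolds-1.
int splitWithModulo(double eventID, int numFolds) {
    assert(numFolds > 0);
    return splitWithoutModulo(eventID) % numFolds;
}

// A fold index is usable only if it addresses one of the numFolds folds.
bool isValidFold(int fold, int numFolds) {
    return fold >= 0 && fold < numFolds;
}
```

The fabs matters as well: in C++ the remainder of a negative integer is negative (for instance -5 % 3 == -2), which would not be a valid fold index.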

Example

"int(fabs([eventID]))%int([NumFolds])"
  • Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
  • Package : TMVA
  • Root Macro: TMVACrossValidationRegression
DataSetInfo : [dataset] : Added class "Regression"
: Add Tree TreeR of type Regression with 10000 events
--- TMVACrossValidationRegression: Using input file: ./files/tmva_reg_example.root
: Dataset[dataset] : Class index : 0 name : Regression
<HEADER> Factory : You are running ROOT Version: 6.23/01, May 19, 2020
:
:
: ___________TMVA Version 4.2.1, Feb 5, 2015
:
: Building event vectors for type 2 Regression
: Dataset[dataset] : create input formulas for tree TreeR
<HEADER> DataSetFactory : [dataset] : Number of events in input trees
:
: Number of training and testing events
: ---------------------------------------------------------------------------
: Regression -- training events : 9999
: Regression -- testing events : 1
: Regression -- training and testing events: 10000
:
<HEADER> DataSetInfo : Correlation matrix (Regression):
: ------------------------
: var1 var2
: var1: +1.000 +0.002
: var2: +0.002 +1.000
: ------------------------
<HEADER> DataSetFactory : [dataset] :
:
:
:
: ========================================
: ========================================
:
<HEADER> Factory : Booking method: BDTG_fold1
:
: the option NegWeightTreatment=InverseBoostNegWeights does not exist for BoostType=Grad
: --> change to new default NegWeightTreatment=Pray
: Regression Loss Function: Huber
: Training 500 Decision Trees ... patience please
: Elapsed time for training with 4999 events: 1.3 sec
: Dataset[dataset] : Create results for training
: Dataset[dataset] : Evaluation of BDTG_fold1 on training sample
: Dataset[dataset] : Elapsed time for evaluation of 4999 events: 0.206 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
: Creating xml weight file: dataset/weights/TMVACrossValidationRegression_BDTG_fold1.weights.xml
<HEADER> Factory : Test all methods
<HEADER> Factory : Test method: BDTG_fold1 for Regression performance
:
: Dataset[dataset] : Create results for testing
: Dataset[dataset] : Evaluation of BDTG_fold1 on testing sample
: Dataset[dataset] : Elapsed time for evaluation of 5000 events: 0.206 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
<HEADER> Factory : Evaluate all methods
: Evaluate regression method: BDTG_fold1
: TestRegression (testing)
: Calculate regression for all events
: Elapsed time for evaluation of 5000 events: 0.205 sec
: TestRegression (training)
: Calculate regression for all events
: Elapsed time for evaluation of 4999 events: 0.204 sec
:
: Evaluation results ranked by smallest RMS on test sample:
: ("Bias" quotes the mean deviation of the regression from true target.
: "MutInf" is the "Mutual Information" between regression and target.
: Indicated by "_T" are the corresponding "truncated" quantities ob-
: tained when removing events deviating more than 2sigma from average.)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG_fold1 : 0.133 0.0851 2.22 1.67 | 3.123 3.198
: --------------------------------------------------------------------------------------------------
:
: Evaluation results ranked by smallest RMS on training sample:
: (overtraining check)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG_fold1 : 0.0474 -0.00861 2.09 1.52 | 3.136 3.206
: --------------------------------------------------------------------------------------------------
:
<HEADER> Factory : Thank you for using TMVA!
: For citation information, please visit: http://tmva.sf.net/citeTMVA.html
<HEADER> Factory : Booking method: BDTG_fold2
:
: the option NegWeightTreatment=InverseBoostNegWeights does not exist for BoostType=Grad
: --> change to new default NegWeightTreatment=Pray
: Regression Loss Function: Huber
: Training 500 Decision Trees ... patience please
: Elapsed time for training with 5000 events: 1.31 sec
: Dataset[dataset] : Create results for training
: Dataset[dataset] : Evaluation of BDTG_fold2 on training sample
: Dataset[dataset] : Elapsed time for evaluation of 5000 events: 0.21 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
: Creating xml weight file: dataset/weights/TMVACrossValidationRegression_BDTG_fold2.weights.xml
<HEADER> Factory : Test all methods
<HEADER> Factory : Test method: BDTG_fold2 for Regression performance
:
: Dataset[dataset] : Create results for testing
: Dataset[dataset] : Evaluation of BDTG_fold2 on testing sample
: Dataset[dataset] : Elapsed time for evaluation of 4999 events: 0.209 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
<HEADER> Factory : Evaluate all methods
: Evaluate regression method: BDTG_fold2
: TestRegression (testing)
: Calculate regression for all events
: Elapsed time for evaluation of 4999 events: 0.207 sec
: TestRegression (training)
: Calculate regression for all events
: Elapsed time for evaluation of 5000 events: 0.208 sec
:
: Evaluation results ranked by smallest RMS on test sample:
: ("Bias" quotes the mean deviation of the regression from true target.
: "MutInf" is the "Mutual Information" between regression and target.
: Indicated by "_T" are the corresponding "truncated" quantities ob-
: tained when removing events deviating more than 2sigma from average.)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG_fold2 : -0.0428 -0.0362 2.33 1.72 | 3.109 3.188
: --------------------------------------------------------------------------------------------------
:
: Evaluation results ranked by smallest RMS on training sample:
: (overtraining check)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG_fold2 : 0.00417 0.0137 2.05 1.51 | 3.145 3.215
: --------------------------------------------------------------------------------------------------
:
<HEADER> Factory : Thank you for using TMVA!
: For citation information, please visit: http://tmva.sf.net/citeTMVA.html
<HEADER> Factory : Booking method: BDTG
:
: Reading weightfile: dataset/weights/TMVACrossValidationRegression_BDTG_fold1.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidationRegression_BDTG_fold1.weights.xml
: Reading weightfile: dataset/weights/TMVACrossValidationRegression_BDTG_fold2.weights.xml
: Reading weight file: dataset/weights/TMVACrossValidationRegression_BDTG_fold2.weights.xml
:
:
: ========================================
: Folds processed for all methods, evaluating.
: ========================================
:
<HEADER> Factory : [dataset] : Create Transformation "I" with events from all classes.
:
<HEADER> : Transformation, Variable selection :
: Input : variable 'var1' <---> Output : variable 'var1'
: Input : variable 'var2' <---> Output : variable 'var2'
<HEADER> TFHandler_Factory : Variable Mean RMS [ Min Max ]
: -----------------------------------------------------------
: var1: 2.4948 1.4515 [ 0.00020069 5.0000 ]
: var2: 2.4837 1.4409 [ 0.00071490 5.0000 ]
: fvalue: 134.53 84.778 [ 1.6186 394.84 ]
: -----------------------------------------------------------
: Ranking input variables (method unspecific)...
<HEADER> IdTransformation : Ranking result (top variable is best ranked)
: --------------------------------------------
: Rank : Variable : |Correlation with target|
: --------------------------------------------
: 1 : var2 : 7.607e-01
: 2 : var1 : 5.995e-01
: --------------------------------------------
<HEADER> IdTransformation : Ranking result (top variable is best ranked)
: -------------------------------------
: Rank : Variable : Mutual information
: -------------------------------------
: 1 : var1 : 2.253e+00
: 2 : var2 : 2.100e+00
: -------------------------------------
<HEADER> IdTransformation : Ranking result (top variable is best ranked)
: ------------------------------------
: Rank : Variable : Correlation Ratio
: ------------------------------------
: 1 : var2 : 2.458e+00
: 2 : var1 : 2.336e+00
: ------------------------------------
<HEADER> IdTransformation : Ranking result (top variable is best ranked)
: ----------------------------------------
: Rank : Variable : Correlation Ratio (T)
: ----------------------------------------
: 1 : var1 : 5.362e-01
: 2 : var2 : 5.109e-01
: ----------------------------------------
: Elapsed time for training with 9999 events: 4.77e-06 sec
: Dataset[dataset] : Create results for training
: Dataset[dataset] : Evaluation of BDTG on training sample
: Dataset[dataset] : Elapsed time for evaluation of 9999 events: 0.364 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
: Creating xml weight file: dataset/weights/TMVACrossValidationRegression_BDTG.weights.xml
<HEADER> Factory : Test all methods
<HEADER> Factory : Test method: BDTG for Regression performance
:
: Dataset[dataset] : Create results for testing
: Dataset[dataset] : Evaluation of BDTG on testing sample
: Dataset[dataset] : Elapsed time for evaluation of 9999 events: 0.365 sec
: Create variable histograms
: Create regression target histograms
: Create regression average deviation
: Results created
<HEADER> Factory : Evaluate all methods
: Evaluate regression method: BDTG
: TestRegression (testing)
: Calculate regression for all events
: Elapsed time for evaluation of 9999 events: 0.364 sec
: TestRegression (training)
: Calculate regression for all events
: Elapsed time for evaluation of 9999 events: 0.365 sec
<HEADER> TFHandler_BDTG : Variable Mean RMS [ Min Max ]
: -----------------------------------------------------------
: var1: 2.4948 1.4515 [ 0.00020069 5.0000 ]
: var2: 2.4837 1.4409 [ 0.00071490 5.0000 ]
: fvalue: 134.53 84.778 [ 1.6186 394.84 ]
: -----------------------------------------------------------
:
: Evaluation results ranked by smallest RMS on test sample:
: ("Bias" quotes the mean deviation of the regression from true target.
: "MutInf" is the "Mutual Information" between regression and target.
: Indicated by "_T" are the corresponding "truncated" quantities ob-
: tained when removing events deviating more than 2sigma from average.)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG : 0.0449 0.0259 2.28 1.70 | 3.108 3.190
: --------------------------------------------------------------------------------------------------
:
: Evaluation results ranked by smallest RMS on training sample:
: (overtraining check)
: --------------------------------------------------------------------------------------------------
: DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T
: --------------------------------------------------------------------------------------------------
: dataset BDTG : 0.0449 0.0259 2.28 1.70 | 3.108 3.190
: --------------------------------------------------------------------------------------------------
:
<HEADER> Dataset:dataset : Created tree 'TestTree' with 9999 events
:
<HEADER> Dataset:dataset : Created tree 'TrainTree' with 9999 events
:
<HEADER> Factory : Thank you for using TMVA!
: For citation information, please visit: http://tmva.sf.net/citeTMVA.html
: Evaluation done.
==> Wrote root file: TMVARegCv.root
==> TMVACrossValidationRegression is done!
(int) 0
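The evaluation tables in the output above report <Bias>, RMS and their truncated (_T) counterparts. The following minimal sketch shows how such numbers can be obtained from predictions and true targets. It assumes unit event weights and takes RMS to be the standard deviation of the deviation d = prediction - target, so it may differ in detail from TMVA's own implementation; MutInf is not covered. The struct and function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Analogues of the <Bias>, RMS, <Bias_T> and RMS_T columns.
struct RegressionStats {
    double bias;   // mean of d = prediction - target
    double rms;    // spread of d around the bias (assumed definition)
    double biasT;  // bias after truncation at bias +/- 2 * rms
    double rmsT;   // rms after the same truncation
};

RegressionStats regressionStats(const std::vector<double>& pred,
                                const std::vector<double>& target) {
    assert(pred.size() == target.size() && !pred.empty());

    // Mean and spread of d over the events accepted by keep(d).
    auto moments = [&](auto keep, double& mean, double& spread) {
        double sum = 0.0, sum2 = 0.0;
        std::size_t n = 0;
        for (std::size_t i = 0; i < pred.size(); ++i) {
            const double d = pred[i] - target[i];
            if (!keep(d)) continue;
            sum += d;
            sum2 += d * d;
            ++n;
        }
        if (n == 0) { mean = spread = 0.0; return; }
        mean = sum / static_cast<double>(n);
        spread = std::sqrt(std::max(0.0, sum2 / static_cast<double>(n) - mean * mean));
    };

    RegressionStats s{};
    moments([](double) { return true; }, s.bias, s.rms);

    // Truncation: keep only events within 2 sigma of the average deviation.
    const double lo = s.bias - 2.0 * s.rms;
    const double hi = s.bias + 2.0 * s.rms;
    moments([&](double d) { return d >= lo && d <= hi; }, s.biasT, s.rmsT);
    return s;
}
```

With nine perfect predictions and one that is off by 10, the sketch gives a bias of 1 and an RMS of 3; the outlier lies outside bias ± 2·RMS = [-5, 7], so the truncated bias and RMS both drop to 0.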
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TCut.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/CrossValidation.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
TFile * getDataFile(TString fname) {
   TFile *input(nullptr);
   // AccessPathName() returns kFALSE if the file *can* be accessed
   if (!gSystem->AccessPathName(fname)) {
      input = TFile::Open(fname); // the file exists in the local directory
   } else {
      // if not: download from ROOT server, caching a copy in the current directory
      TFile::SetCacheFileDir(".");
      input = TFile::Open("http://root.cern.ch/files/tmva_reg_example.root", "CACHEREAD");
   }
   if (!input) {
      std::cout << "ERROR: could not open data file " << fname << std::endl;
      exit(1);
   }
   return input;
}
int TMVACrossValidationRegression()
{
   // This loads the library
   TMVA::Tools::Instance();

   // --------------------------------------------------------------------------

   // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
   TString outfileName("TMVARegCv.root");
   TFile * outputFile = TFile::Open(outfileName, "RECREATE");

   TString infileName("./files/tmva_reg_example.root");
   TFile * inputFile = getDataFile(infileName);

   TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");

   dataloader->AddVariable("var1", "Variable 1", "units", 'F');
   dataloader->AddVariable("var2", "Variable 2", "units", 'F');

   // Add the variable carrying the regression target
   dataloader->AddTarget("fvalue");

   TTree * regTree = (TTree*)inputFile->Get("TreeR");
   dataloader->AddRegressionTree(regTree, 1.0);

   // Individual events can be weighted
   // dataloader->SetWeightExpression("weight", "Regression");

   std::cout << "--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl;

   // Bypasses the normal splitting mechanism, since CV uses a new system for
   // this. Unfortunately, the old system is unhappy if the test set is left
   // empty, so we ensure that it contains at least one event by placing the
   // first event in it.
   //
   // The selection cut places a global cut on the defined variables. Only
   // events passing the cut will be used in training/testing.
   // Example: `TCut selectionCut = "var1 < 1";`
   TCut selectionCut = "";
   dataloader->PrepareTrainingAndTestTree(selectionCut, "nTest_Regression=1"
                                                        ":SplitMode=Block"
                                                        ":NormMode=NumEvents"
                                                        ":!V");

   // --------------------------------------------------------------------------
   //
   // This sets up a CrossValidation class (which wraps a TMVA::Factory
   // internally) for 2-fold cross validation. The data will be split into the
   // two folds randomly if `splitExpr` is `""`.
   //
   // One can also give a deterministic split using spectator variables,
   // e.g. `"int(fabs([spec1]))%int([NumFolds])"`.
   //
   UInt_t numFolds = 2;
   TString analysisType = "Regression";
   TString splitExpr = "";

   TString cvOptions = Form("!V"
                            ":!Silent"
                            ":ModelPersistence"
                            ":!FoldFileOutput"
                            ":AnalysisType=%s"
                            ":NumFolds=%u"
                            ":SplitExpr=%s",
                            analysisType.Data(), numFolds, splitExpr.Data());

   TMVA::CrossValidation cv{"TMVACrossValidationRegression", dataloader, outputFile, cvOptions};

   // --------------------------------------------------------------------------
   //
   // Books a method to use for evaluation
   //
   cv.BookMethod(TMVA::Types::kBDT, "BDTG",
                 "!H:!V:NTrees=500:BoostType=Grad:Shrinkage=0.1:"
                 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3");

   // --------------------------------------------------------------------------
   //
   // Train, test and evaluate the booked methods.
   // Evaluates the booked methods once for each fold and aggregates the result
   // in the specified output file.
   //
   cv.Evaluate();

   // --------------------------------------------------------------------------
   //
   // Save the output
   //
   outputFile->Close();

   std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
   std::cout << "==> TMVACrossValidationRegression is done!" << std::endl;

   // --------------------------------------------------------------------------
   //
   // Launch the GUI for the root macros
   //
   if (!gROOT->IsBatch()) {
      TMVA::TMVAGui(outfileName);
   }

   return 0;
}
//
// This is used if the macro is compiled. If run through ROOT with
// `root -l -b -q MACRO.C` or similar it is unused.
//
int main(int argc, char **argv)
{
   return TMVACrossValidationRegression();
}
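The comment on cv.Evaluate() says that the booked methods are evaluated once per fold and the results aggregated. The generic k-fold loop behind that idea is sketched below, assuming that aggregation means collecting, for every event, the prediction of the one model that held it out (consistent with the final BDTG evaluation in the log covering all 9999 events). MeanModel, a regressor that simply predicts the mean training target, is a hypothetical placeholder for BDTG; this is not TMVA's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a trained regressor: predicts the mean target
// of its training events. A real setup would train a BDTG here instead.
struct MeanModel {
    double mean = 0.0;
};

MeanModel trainMeanModel(const std::vector<double>& targets) {
    MeanModel m;
    for (double t : targets) m.mean += t;
    if (!targets.empty()) m.mean /= static_cast<double>(targets.size());
    return m;
}

// k-fold evaluation: for each fold k, train on all other folds, predict the
// held-out fold k, and collect these out-of-fold predictions so that every
// event is predicted exactly once by a model that never saw it.
std::vector<double> outOfFoldPredictions(const std::vector<double>& targets,
                                         const std::vector<int>& fold,
                                         int numFolds) {
    assert(targets.size() == fold.size() && numFolds > 0);
    std::vector<double> prediction(targets.size(), 0.0);
    for (int k = 0; k < numFolds; ++k) {
        std::vector<double> train;
        for (std::size_t i = 0; i < targets.size(); ++i)
            if (fold[i] != k) train.push_back(targets[i]);
        const MeanModel model = trainMeanModel(train);
        for (std::size_t i = 0; i < targets.size(); ++i)
            if (fold[i] == k) prediction[i] = model.mean;
    }
    return prediction;
}
```

With targets {1, 2, 3, 4} and folds {0, 1, 0, 1}, the model for fold 0 is trained on targets 2 and 4 and predicts 3 for events 0 and 2, while the model for fold 1 predicts 2 for events 1 and 3.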
Author
Kim Albertsson (adapted from code originally by Andreas Hoecker)

Definition in file TMVACrossValidationRegression.C.
