Logo ROOT  
Reference Guide
TMVAClassificationCategory.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides examples for the training and testing of the
5/// TMVA classifiers in categorisation mode.
6/// - Project : TMVA - a Root-integrated toolkit for multivariate data analysis
7/// - Package : TMVA
8/// - Root Macro: TMVAClassificationCategory
9///
10/// The input data is a toy-MC sample consisting of four Gaussian-distributed,
11/// linearly correlated input variables whose properties depend on the
12/// event category (eta).
13///
14/// For this example, only Fisher and Likelihood are used. Run via:
15///
16/// root -l TMVAClassificationCategory.C
17///
18/// The output file "TMVA.root" can be analysed with the use of dedicated
19/// macros (simply say: root -l <macro.C>), which can be conveniently
20/// invoked through a GUI that will appear at the end of the run of this macro.
21///
22/// \macro_output
23/// \macro_code
24/// \author Andreas Hoecker
25
26
27#include <cstdlib>
28#include <iostream>
29#include <map>
30#include <string>
31
32#include "TChain.h"
33#include "TFile.h"
34#include "TTree.h"
35#include "TString.h"
36#include "TObjString.h"
37#include "TSystem.h"
38#include "TROOT.h"
39
40#include "TMVA/MethodCategory.h"
41#include "TMVA/Factory.h"
42#include "TMVA/DataLoader.h"
43#include "TMVA/Tools.h"
44#include "TMVA/TMVAGui.h"
45
46
47// two types of category methods are implemented
48Bool_t UseOffsetMethod = kTRUE;
49
50void TMVAClassificationCategory()
51{
52 //---------------------------------------------------------------
53 // Example for usage of different event categories with classifiers
54
55 std::cout << std::endl << "==> Start TMVAClassificationCategory" << std::endl;
56
57 // This loads the library
59
60 bool batchMode = false;
61
62 // Create a new root output file.
63 TString outfileName( "TMVA.root" );
64 TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
65
66 // Create the factory object (see TMVAClassification.C for more information)
67
68 std::string factoryOptions( "!V:!Silent:Transformations=I;D;P;G,D" );
69 if (batchMode) factoryOptions += ":!Color:!DrawProgressBar";
70
71 TMVA::Factory *factory = new TMVA::Factory( "TMVAClassificationCategory", outputFile, factoryOptions );
72
73 // Create DataLoader
74 TMVA::DataLoader *dataloader=new TMVA::DataLoader("dataset");
75
76 // Define the input variables used for the MVA training
77 dataloader->AddVariable( "var1", 'F' );
78 dataloader->AddVariable( "var2", 'F' );
79 dataloader->AddVariable( "var3", 'F' );
80 dataloader->AddVariable( "var4", 'F' );
81
82 // You can add so-called "Spectator variables", which are not used in the MVA training,
83 // but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
84 // input variables, the response values of all trained MVAs, and the spectator variables
85 dataloader->AddSpectator( "eta" );
86
87 // Load the signal and background event samples from ROOT trees
88 TFile *input(0);
89 TString fname = gSystem->GetDirName(__FILE__) + "/data/";
90 if (gSystem->AccessPathName( fname + "toy_sigbkg_categ_offset.root")) {
91 // if directory data not found try using tutorials dir
92 fname = gROOT->GetTutorialDir() + "/tmva/data/";
93 }
94 if (UseOffsetMethod) fname += "toy_sigbkg_categ_offset.root";
95 else fname += "toy_sigbkg_categ_varoff.root";
96 if (!gSystem->AccessPathName( fname )) {
97 // first we try to find tmva_example.root in the local directory
98 std::cout << "--- TMVAClassificationCategory: Accessing " << fname << std::endl;
99 input = TFile::Open( fname );
100 }
101
102 if (!input) {
103 std::cout << "ERROR: could not open data file: " << fname << std::endl;
104 exit(1);
105 }
106
107 TTree *signalTree = (TTree*)input->Get("TreeS");
108 TTree *background = (TTree*)input->Get("TreeB");
109
110 // Global event weights per tree (see below for setting event-wise weights)
111 Double_t signalWeight = 1.0;
112 Double_t backgroundWeight = 1.0;
113
114 // You can add an arbitrary number of signal or background trees
115 dataloader->AddSignalTree ( signalTree, signalWeight );
116 dataloader->AddBackgroundTree( background, backgroundWeight );
117
118 // Apply additional cuts on the signal and background samples (can be different)
119 TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
120 TCut mycutb = ""; // for example: TCut mycutb = "abs(var1)<0.5";
121
122 // Tell the factory how to use the training and testing events
123 dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,
124 "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" );
125
126 // Book MVA methods
127
128 // Fisher discriminant
129 factory->BookMethod( dataloader, TMVA::Types::kFisher, "Fisher", "!H:!V:Fisher" );
130
131 // Likelihood
132 factory->BookMethod( dataloader, TMVA::Types::kLikelihood, "Likelihood",
133 "!H:!V:TransformOutput:PDFInterpol=Spline2:NSmoothSig[0]=20:NSmoothBkg[0]=20:NSmoothBkg[1]=10:NSmooth=1:NAvEvtPerBin=50" );
134
135 // Categorised classifier
136 TMVA::MethodCategory* mcat = 0;
137
138 // The variable sets
139 TString theCat1Vars = "var1:var2:var3:var4";
140 TString theCat2Vars = (UseOffsetMethod ? "var1:var2:var3:var4" : "var1:var2:var3");
141
142 // Fisher with categories
143 TMVA::MethodBase* fiCat = factory->BookMethod( dataloader, TMVA::Types::kCategory, "FisherCat","" );
144 mcat = dynamic_cast<TMVA::MethodCategory*>(fiCat);
145 mcat->AddMethod( "abs(eta)<=1.3", theCat1Vars, TMVA::Types::kFisher, "Category_Fisher_1","!H:!V:Fisher" );
146 mcat->AddMethod( "abs(eta)>1.3", theCat2Vars, TMVA::Types::kFisher, "Category_Fisher_2","!H:!V:Fisher" );
147
148 // Likelihood with categories
149 TMVA::MethodBase* liCat = factory->BookMethod( dataloader, TMVA::Types::kCategory, "LikelihoodCat","" );
150 mcat = dynamic_cast<TMVA::MethodCategory*>(liCat);
151 mcat->AddMethod( "abs(eta)<=1.3",theCat1Vars, TMVA::Types::kLikelihood,
152 "Category_Likelihood_1","!H:!V:TransformOutput:PDFInterpol=Spline2:NSmoothSig[0]=20:NSmoothBkg[0]=20:NSmoothBkg[1]=10:NSmooth=1:NAvEvtPerBin=50" );
153 mcat->AddMethod( "abs(eta)>1.3", theCat2Vars, TMVA::Types::kLikelihood,
154 "Category_Likelihood_2","!H:!V:TransformOutput:PDFInterpol=Spline2:NSmoothSig[0]=20:NSmoothBkg[0]=20:NSmoothBkg[1]=10:NSmooth=1:NAvEvtPerBin=50" );
155
156 // Now you can tell the factory to train, test, and evaluate the MVAs
157
158 // Train MVAs using the set of training events
159 factory->TrainAllMethods();
160
161 // Evaluate all MVAs using the set of test events
162 factory->TestAllMethods();
163
164 // Evaluate and compare performance of all configured MVAs
165 factory->EvaluateAllMethods();
166
167 // --------------------------------------------------------------
168
169 // Save the output
170 outputFile->Close();
171
172 std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
173 std::cout << "==> TMVAClassificationCategory is done!" << std::endl;
174
175 // Clean up
176 delete factory;
177 delete dataloader;
178
179 // Launch the GUI for the root macros
180 if (!gROOT->IsBatch()) TMVA::TMVAGui( outfileName );
181}
182int main( int argc, char** argv )
183{
184 TMVAClassificationCategory();
185 return 0;
186}
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define gROOT
Definition: TROOT.h:404
R__EXTERN TSystem * gSystem
Definition: TSystem.h:559
int main(int argc, char *argv[])
Definition: cef_main.cxx:54
A specialized string object used for TTree selections.
Definition: TCut.h:25
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:4011
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:889
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Add a signal tree, with a global event weight, to the data set.
Definition: DataLoader.cxx:371
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts spectator variable in data set info
Definition: DataLoader.cxx:524
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:632
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Add a background tree, with a global event weight, to the data set.
Definition: DataLoader.cxx:402
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:485
This is the main MVA steering class.
Definition: Factory.h:80
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1114
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:352
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition: Factory.cxx:1271
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition: Factory.cxx:1376
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
Class for categorizing the phase space.
TMVA::IMethod * AddMethod(const TCut &, const TString &theVariables, Types::EMVA theMethod, const TString &theTitle, const TString &theOptions)
adds sub-classifier for a category
static Tools & Instance()
Definition: Tools.cxx:71
@ kFisher
Definition: Types.h:82
@ kCategory
Definition: Types.h:97
@ kLikelihood
Definition: Types.h:79
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Basic string class.
Definition: TString.h:136
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition: TSystem.cxx:1296
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
Definition: TSystem.cxx:1032
A TTree represents a columnar dataset.
Definition: TTree.h:79
void TMVAGui(const char *fName="TMVA.root", TString dataset="")