Logo ROOT  
Reference Guide
TMVAMultipleBackgroundExample.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
/// This example shows the training of signal with three different backgrounds.
/// Then in the application a tree is created with all signal and background
/// events, where the true class ID and the three classifier outputs are added.
/// Finally, with the application tree, the significance is maximized with the
/// help of the TMVA genetic algorithm.
/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
/// - Package : TMVA
/// - Executable: TMVAGAexample
12///
13/// \macro_output
14/// \macro_code
15/// \author Andreas Hoecker
16
17
18#include <iostream> // Stream declarations
19#include <vector>
20#include <limits>
21
22#include "TChain.h"
23#include "TCut.h"
24#include "TDirectory.h"
25#include "TH1F.h"
26#include "TH1.h"
27#include "TMath.h"
28#include "TFile.h"
29#include "TStopwatch.h"
30#include "TROOT.h"
31#include "TSystem.h"
32
34#include "TMVA/GeneticFitter.h"
35#include "TMVA/IFitterTarget.h"
36#include "TMVA/Factory.h"
37#include "TMVA/DataLoader.h"//required to load dataset
38#include "TMVA/Reader.h"
39
40using namespace std;
41
42using namespace TMVA;
43
44// ----------------------------------------------------------------------------------------------
45// Training
46// ----------------------------------------------------------------------------------------------
47//
48void Training(){
49 std::string factoryOptions( "!V:!Silent:Transformations=I;D;P;G,D:AnalysisType=Classification" );
50 TString fname = "./tmva_example_multiple_background.root";
51
52 TFile *input(0);
53 input = TFile::Open( fname );
54
55 TTree *signal = (TTree*)input->Get("TreeS");
56 TTree *background0 = (TTree*)input->Get("TreeB0");
57 TTree *background1 = (TTree*)input->Get("TreeB1");
58 TTree *background2 = (TTree*)input->Get("TreeB2");
59
60 /// global event weights per tree (see below for setting event-wise weights)
61 Double_t signalWeight = 1.0;
62 Double_t background0Weight = 1.0;
63 Double_t background1Weight = 1.0;
64 Double_t background2Weight = 1.0;
65
66 // Create a new root output file.
67 TString outfileName( "TMVASignalBackground0.root" );
68 TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
69
70
71
72 // background 0
73 // ____________
74 TMVA::Factory *factory = new TMVA::Factory( "TMVAMultiBkg0", outputFile, factoryOptions );
75 TMVA::DataLoader *dataloader=new TMVA::DataLoader("datasetBkg0");
76
77 dataloader->AddVariable( "var1", "Variable 1", "", 'F' );
78 dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
79 dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
80 dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
81
82 dataloader->AddSignalTree ( signal, signalWeight );
83 dataloader->AddBackgroundTree( background0, background0Weight );
84
85 // factory->SetBackgroundWeightExpression("weight");
86 TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
87 TCut mycutb = ""; // for example: TCut mycutb = "abs(var1)<0.5";
88
89 // tell the factory to use all remaining events in the trees after training for testing:
90 dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,
91 "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" );
92
93 // Boosted Decision Trees
94 factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG",
95 "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.30:UseBaggedBoost:BaggedSampleFraction=0.6:SeparationType=GiniIndex:nCuts=20:MaxDepth=2" );
96 factory->TrainAllMethods();
97 factory->TestAllMethods();
98 factory->EvaluateAllMethods();
99
100 outputFile->Close();
101
102 delete factory;
103 delete dataloader;
104
105
106
107 // background 1
108 // ____________
109
110 outfileName = "TMVASignalBackground1.root";
111 outputFile = TFile::Open( outfileName, "RECREATE" );
112 dataloader=new TMVA::DataLoader("datasetBkg1");
113
114 factory = new TMVA::Factory( "TMVAMultiBkg1", outputFile, factoryOptions );
115 dataloader->AddVariable( "var1", "Variable 1", "", 'F' );
116 dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
117 dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
118 dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
119
120 dataloader->AddSignalTree ( signal, signalWeight );
121 dataloader->AddBackgroundTree( background1, background1Weight );
122
123 // dataloader->SetBackgroundWeightExpression("weight");
124
125 // tell the factory to use all remaining events in the trees after training for testing:
126 dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,
127 "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" );
128
129 // Boosted Decision Trees
130 factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG",
131 "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.30:UseBaggedBoost:BaggedSampleFraction=0.6:SeparationType=GiniIndex:nCuts=20:MaxDepth=2" );
132 factory->TrainAllMethods();
133 factory->TestAllMethods();
134 factory->EvaluateAllMethods();
135
136 outputFile->Close();
137
138 delete factory;
139 delete dataloader;
140
141
142 // background 2
143 // ____________
144
145 outfileName = "TMVASignalBackground2.root";
146 outputFile = TFile::Open( outfileName, "RECREATE" );
147
148 factory = new TMVA::Factory( "TMVAMultiBkg2", outputFile, factoryOptions );
149 dataloader=new TMVA::DataLoader("datasetBkg2");
150
151 dataloader->AddVariable( "var1", "Variable 1", "", 'F' );
152 dataloader->AddVariable( "var2", "Variable 2", "", 'F' );
153 dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
154 dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
155
156 dataloader->AddSignalTree ( signal, signalWeight );
157 dataloader->AddBackgroundTree( background2, background2Weight );
158
159 // dataloader->SetBackgroundWeightExpression("weight");
160
161 // tell the dataloader to use all remaining events in the trees after training for testing:
162 dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,
163 "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" );
164
165 // Boosted Decision Trees
166 factory->BookMethod( dataloader, TMVA::Types::kBDT, "BDTG",
167 "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.30:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20:MaxDepth=2" );
168 factory->TrainAllMethods();
169 factory->TestAllMethods();
170 factory->EvaluateAllMethods();
171
172 outputFile->Close();
173
174 delete factory;
175 delete dataloader;
176
177}
178
179
180
181
182
183// ----------------------------------------------------------------------------------------------
184// Application
185// ----------------------------------------------------------------------------------------------
186//
187// create a summary tree with all signal and background events and for each event the three classifier values and the true classID
188void ApplicationCreateCombinedTree(){
189
190 // Create a new root output file.
191 TString outfileName( "tmva_example_multiple_backgrounds__applied.root" );
192 TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
193 TTree* outputTree = new TTree("multiBkg","multiple backgrounds tree");
194
195 Float_t var1, var2;
196 Float_t var3, var4;
197 Int_t classID = 0;
198 Float_t weight = 1.f;
199
200 Float_t classifier0, classifier1, classifier2;
201
202 outputTree->Branch("classID", &classID, "classID/I");
203 outputTree->Branch("var1", &var1, "var1/F");
204 outputTree->Branch("var2", &var2, "var2/F");
205 outputTree->Branch("var3", &var3, "var3/F");
206 outputTree->Branch("var4", &var4, "var4/F");
207 outputTree->Branch("weight", &weight, "weight/F");
208 outputTree->Branch("cls0", &classifier0, "cls0/F");
209 outputTree->Branch("cls1", &classifier1, "cls1/F");
210 outputTree->Branch("cls2", &classifier2, "cls2/F");
211
212
213 // create three readers for the three different signal/background classifications, .. one for each background
214 TMVA::Reader *reader0 = new TMVA::Reader( "!Color:!Silent" );
215 TMVA::Reader *reader1 = new TMVA::Reader( "!Color:!Silent" );
216 TMVA::Reader *reader2 = new TMVA::Reader( "!Color:!Silent" );
217
218 reader0->AddVariable( "var1", &var1 );
219 reader0->AddVariable( "var2", &var2 );
220 reader0->AddVariable( "var3", &var3 );
221 reader0->AddVariable( "var4", &var4 );
222
223 reader1->AddVariable( "var1", &var1 );
224 reader1->AddVariable( "var2", &var2 );
225 reader1->AddVariable( "var3", &var3 );
226 reader1->AddVariable( "var4", &var4 );
227
228 reader2->AddVariable( "var1", &var1 );
229 reader2->AddVariable( "var2", &var2 );
230 reader2->AddVariable( "var3", &var3 );
231 reader2->AddVariable( "var4", &var4 );
232
233 // load the weight files for the readers
234 TString method = "BDT method";
235 reader0->BookMVA( "BDT method", "datasetBkg0/weights/TMVAMultiBkg0_BDTG.weights.xml" );
236 reader1->BookMVA( "BDT method", "datasetBkg1/weights/TMVAMultiBkg1_BDTG.weights.xml" );
237 reader2->BookMVA( "BDT method", "datasetBkg2/weights/TMVAMultiBkg2_BDTG.weights.xml" );
238
239 // load the input file
240 TFile *input(0);
241 TString fname = "./tmva_example_multiple_background.root";
242 input = TFile::Open( fname );
243
244 TTree* theTree = NULL;
245
246 // loop through signal and all background trees
247 for( int treeNumber = 0; treeNumber < 4; ++treeNumber ) {
248 if( treeNumber == 0 ){
249 theTree = (TTree*)input->Get("TreeS");
250 std::cout << "--- Select signal sample" << std::endl;
251// theTree->SetBranchAddress( "weight", &weight );
252 weight = 1;
253 classID = 0;
254 }else if( treeNumber == 1 ){
255 theTree = (TTree*)input->Get("TreeB0");
256 std::cout << "--- Select background 0 sample" << std::endl;
257// theTree->SetBranchAddress( "weight", &weight );
258 weight = 1;
259 classID = 1;
260 }else if( treeNumber == 2 ){
261 theTree = (TTree*)input->Get("TreeB1");
262 std::cout << "--- Select background 1 sample" << std::endl;
263// theTree->SetBranchAddress( "weight", &weight );
264 weight = 1;
265 classID = 2;
266 }else if( treeNumber == 3 ){
267 theTree = (TTree*)input->Get("TreeB2");
268 std::cout << "--- Select background 2 sample" << std::endl;
269// theTree->SetBranchAddress( "weight", &weight );
270 weight = 1;
271 classID = 3;
272 }
273
274
275 theTree->SetBranchAddress( "var1", &var1 );
276 theTree->SetBranchAddress( "var2", &var2 );
277 theTree->SetBranchAddress( "var3", &var3 );
278 theTree->SetBranchAddress( "var4", &var4 );
279
280
281 std::cout << "--- Processing: " << theTree->GetEntries() << " events" << std::endl;
282 TStopwatch sw;
283 sw.Start();
284 Int_t nEvent = theTree->GetEntries();
285// Int_t nEvent = 100;
286 for (Long64_t ievt=0; ievt<nEvent; ievt++) {
287
288 if (ievt%1000 == 0){
289 std::cout << "--- ... Processing event: " << ievt << std::endl;
290 }
291
292 theTree->GetEntry(ievt);
293
294 // get the classifiers for each of the signal/background classifications
295 classifier0 = reader0->EvaluateMVA( method );
296 classifier1 = reader1->EvaluateMVA( method );
297 classifier2 = reader2->EvaluateMVA( method );
298
299 outputTree->Fill();
300 }
301
302
303 // get elapsed time
304 sw.Stop();
305 std::cout << "--- End of event loop: "; sw.Print();
306 }
307 input->Close();
308
309
310 // write output tree
311/* outputTree->SetDirectory(outputFile);
312 outputTree->Write(); */
313 outputFile->Write();
314
315 outputFile->Close();
316
317 std::cout << "--- Created root file: \"" << outfileName.Data() << "\" containing the MVA output histograms" << std::endl;
318
319 delete reader0;
320 delete reader1;
321 delete reader2;
322
323 std::cout << "==> Application of readers is done! combined tree created" << std::endl << std::endl;
324
325}
326
327
328
329
330// -----------------------------------------------------------------------------------------
331// Genetic Algorithm Fitness definition
332// -----------------------------------------------------------------------------------------
333//
334class MyFitness : public IFitterTarget {
335public:
336 // constructor
337 MyFitness( TChain* _chain ) : IFitterTarget() {
338 chain = _chain;
339
340 hSignal = new TH1F("hsignal","hsignal",100,-1,1);
341 hFP = new TH1F("hfp","hfp",100,-1,1);
342 hTP = new TH1F("htp","htp",100,-1,1);
343
344 TString cutsAndWeightSignal = "weight*(classID==0)";
345 nSignal = chain->Draw("Entry$/Entries$>>hsignal",cutsAndWeightSignal,"goff");
346 weightsSignal = hSignal->Integral();
347
348 }
349
350 // the output of this function will be minimized
351 Double_t EstimatorFunction( std::vector<Double_t> & factors ){
352
353 TString cutsAndWeightTruePositive = Form("weight*((classID==0) && cls0>%f && cls1>%f && cls2>%f )",factors.at(0), factors.at(1), factors.at(2));
354 TString cutsAndWeightFalsePositive = Form("weight*((classID >0) && cls0>%f && cls1>%f && cls2>%f )",factors.at(0), factors.at(1), factors.at(2));
355
356 // Entry$/Entries$ just draws something reasonable. Could in principle anything
357 Float_t nTP = chain->Draw("Entry$/Entries$>>htp",cutsAndWeightTruePositive,"goff");
358 Float_t nFP = chain->Draw("Entry$/Entries$>>hfp",cutsAndWeightFalsePositive,"goff");
359
360 weightsTruePositive = hTP->Integral();
361 weightsFalsePositive = hFP->Integral();
362
363 efficiency = 0;
364 if( weightsSignal > 0 )
365 efficiency = weightsTruePositive/weightsSignal;
366
367 purity = 0;
368 if( weightsTruePositive+weightsFalsePositive > 0 )
369 purity = weightsTruePositive/(weightsTruePositive+weightsFalsePositive);
370
371 Float_t effTimesPur = efficiency*purity;
372
373 Float_t toMinimize = std::numeric_limits<float>::max(); // set to the highest existing number
374 if( effTimesPur > 0 ) // if larger than 0, take 1/x. This is the value to minimize
375 toMinimize = 1./(effTimesPur); // we want to minimize 1/efficiency*purity
376
377 // Print();
378
379 return toMinimize;
380 }
381
382
383 void Print(){
384 std::cout << std::endl;
385 std::cout << "======================" << std::endl
386 << "Efficiency : " << efficiency << std::endl
387 << "Purity : " << purity << std::endl << std::endl
388 << "True positive weights : " << weightsTruePositive << std::endl
389 << "False positive weights: " << weightsFalsePositive << std::endl
390 << "Signal weights : " << weightsSignal << std::endl;
391 }
392
393 Float_t nSignal;
394
395 Float_t efficiency;
396 Float_t purity;
397 Float_t weightsTruePositive;
398 Float_t weightsFalsePositive;
399 Float_t weightsSignal;
400
401
402private:
403 TChain* chain;
404 TH1F* hSignal;
405 TH1F* hFP;
406 TH1F* hTP;
407
408};
409
410
411
412
413
414
415
416
417// ----------------------------------------------------------------------------------------------
418// Call of Genetic algorithm
419// ----------------------------------------------------------------------------------------------
420//
421void MaximizeSignificance(){
422
423 // define all the parameters by their minimum and maximum value
424 // in this example 3 parameters (=cuts on the classifiers) are defined.
425 vector<Interval*> ranges;
426 ranges.push_back( new Interval(-1,1) ); // for some classifiers (especially LD) the ranges have to be taken larger
427 ranges.push_back( new Interval(-1,1) );
428 ranges.push_back( new Interval(-1,1) );
429
430 std::cout << "Classifier ranges (defined by the user)" << std::endl;
431 for( std::vector<Interval*>::iterator it = ranges.begin(); it != ranges.end(); it++ ){
432 std::cout << " range: " << (*it)->GetMin() << " " << (*it)->GetMax() << std::endl;
433 }
434
435 TChain* chain = new TChain("multiBkg");
436 chain->Add("tmva_example_multiple_backgrounds__applied.root");
437
438 IFitterTarget* myFitness = new MyFitness( chain );
439
440 // prepare the genetic algorithm with an initial population size of 20
441 // mind: big population sizes will help in searching the domain space of the solution
442 // but you have to weight this out to the number of generations
443 // the extreme case of 1 generation and populationsize n is equal to
444 // a Monte Carlo calculation with n tries
445
446 const TString name( "multipleBackgroundGA" );
447 const TString opts( "PopSize=100:Steps=30" );
448
449 GeneticFitter mg( *myFitness, name, ranges, opts);
450 // mg.SetParameters( 4, 30, 200, 10,5, 0.95, 0.001 );
451
452 std::vector<Double_t> result;
453 Double_t estimator = mg.Run(result);
454
455 dynamic_cast<MyFitness*>(myFitness)->Print();
456 std::cout << std::endl;
457
458 int n = 0;
459 for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); it++ ){
460 std::cout << " cutValue[" << n << "] = " << (*it) << ";"<< std::endl;
461 n++;
462 }
463
464
465}
466
467
468
469
470void TMVAMultipleBackgroundExample()
471{
472 // ----------------------------------------------------------------------------------------
473 // Run all
474 // ----------------------------------------------------------------------------------------
475 cout << "Start Test TMVAGAexample" << endl
476 << "========================" << endl
477 << endl;
478
479 TString createDataMacro = gROOT->GetTutorialDir() + "/tmva/createData.C";
480 gROOT->ProcessLine(TString::Format(".L %s",createDataMacro.Data()));
481 gROOT->ProcessLine("create_MultipleBackground(200)");
482
483
484 cout << endl;
485 cout << "========================" << endl;
486 cout << "--- Training" << endl;
487 Training();
488
489 cout << endl;
490 cout << "========================" << endl;
491 cout << "--- Application & create combined tree" << endl;
492 ApplicationCreateCombinedTree();
493
494 cout << endl;
495 cout << "========================" << endl;
496 cout << "--- maximize significance" << endl;
497 MaximizeSignificance();
498}
499
500int main( int argc, char** argv ) {
501 TMVAMultipleBackgroundExample();
502}
int Int_t
Definition: RtypesCore.h:45
double Double_t
Definition: RtypesCore.h:59
long long Long64_t
Definition: RtypesCore.h:80
float Float_t
Definition: RtypesCore.h:57
char name[80]
Definition: TGX11.cxx:110
#define gROOT
Definition: TROOT.h:404
char * Form(const char *fmt,...)
int main(int argc, char *argv[])
Definition: cef_main.cxx:54
A chain is a collection of files containing TTree objects.
Definition: TChain.h:33
virtual Int_t Add(TChain *chain)
Add all files referenced by the passed chain to this chain.
Definition: TChain.cxx:230
A specialized string object used for TTree selections.
Definition: TCut.h:25
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:4011
Int_t Write(const char *name=nullptr, Int_t opt=0, Int_t bufsiz=0) override
Write memory objects to this file.
Definition: TFile.cxx:2362
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:889
1-D histogram with a float per channel (see TH1 documentation)}
Definition: TH1.h:575
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:371
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:632
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:402
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:485
This is the main MVA steering class.
Definition: Factory.h:80
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1114
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:352
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition: Factory.cxx:1271
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition: Factory.cxx:1376
Fitter using a Genetic Algorithm.
Definition: GeneticFitter.h:44
Interface for a fitter 'target'.
Definition: IFitterTarget.h:44
virtual Double_t EstimatorFunction(std::vector< Double_t > &parameters)=0
The TMVA::Interval Class.
Definition: Interval.h:61
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:64
Double_t EvaluateMVA(const std::vector< Float_t > &, const TString &methodTag, Double_t aux=0)
Evaluate a std::vector<float> of input data for a given method The parameter aux is obligatory for th...
Definition: Reader.cxx:468
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition: Reader.cxx:368
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:303
@ kBDT
Definition: Types.h:86
Stopwatch class.
Definition: TStopwatch.h:28
void Start(Bool_t reset=kTRUE)
Start the stopwatch.
Definition: TStopwatch.cxx:58
void Stop()
Stop the stopwatch.
Definition: TStopwatch.cxx:77
void Print(Option_t *option="") const
Print the real and cpu time passed between the start and stop events.
Definition: TStopwatch.cxx:219
Basic string class.
Definition: TString.h:136
const char * Data() const
Definition: TString.h:369
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2336
A TTree represents a columnar dataset.
Definition: TTree.h:79
virtual Int_t Fill()
Fill all branches.
Definition: TTree.cxx:4572
virtual Int_t GetEntry(Long64_t entry, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition: TTree.cxx:5606
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=0)
Change branch address, dealing with clone trees properly.
Definition: TTree.cxx:8340
virtual Long64_t GetEntries() const
Definition: TTree.h:459
TBranch * Branch(const char *name, T *obj, Int_t bufsize=32000, Int_t splitlevel=99)
Add a new branch, and infer the data type from the type of obj being passed.
Definition: TTree.h:350
const Int_t n
Definition: legend1.C:16
void Print(std::ostream &os, const OptionType &opt)
static constexpr double mg
create variable transformations