Logo ROOT   6.14/05
Reference Guide
TMVAMinimalClassification.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// Minimal self-contained example for setting up TMVA with binary
5 /// classification.
6 ///
7 /// This is intended as a simple foundation to build on. It assumes you are
8 /// familiar with TMVA already. As such concepts like the Factory, the DataLoader
9 /// and others are not explained. For descriptions and tutorials use the TMVA
10 /// User's Guide (https://root.cern.ch/root-user-guides-and-manuals under TMVA)
11 /// or the more detailed examples provided with TMVA e.g. TMVAClassification.C.
12 ///
13 /// Sets up a minimal binary classification example with two slightly overlapping
14 /// 2-D gaussian distributions and trains a BDT classifier to discriminate the
15 /// data.
16 ///
17 /// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
18 /// - Package : TMVA
19 /// - Root Macro: TMVAMinimalClassification.C
20 ///
21 /// \macro_output
22 /// \macro_code
23 /// \author Kim Albertsson
24 
25 #include "TMVA/DataLoader.h"
26 #include "TMVA/Factory.h"
27 
28 #include "TFile.h"
29 #include "TString.h"
30 #include "TTree.h"
31 
32 //
33 // Helper function to generate 2-D gaussian data points and fill to a ROOT
34 // TTree.
35 //
36 // Arguments:
37 // nPoints Number of points to generate.
38 // offset Mean of the generated numbers
39 // scale Standard deviation of the generated numbers.
40 // seed Seed for random number generator. Use `seed=0` for random
41 // seed.
42 // Returns a TTree ready to be used as input to TMVA.
43 //
44 TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
45 {
46  TRandom rng(seed);
47  Double_t x = 0;
48  Double_t y = 0;
49 
50  TTree *data = new TTree();
51  data->Branch("x", &x, "x/D");
52  data->Branch("y", &y, "y/D");
53 
54  for (Int_t n = 0; n < nPoints; ++n) {
55  x = rng.Rndm() * scale;
56  y = offset + rng.Rndm() * scale;
57  data->Fill();
58  }
59 
60  // Important: Disconnects the tree from the memory locations of x and y.
61  data->ResetBranchAddresses();
62  return data;
63 }
64 
65 //
66 // Minimal setup for perfroming binary classification in TMVA.
67 //
68 // Modify the setup to your liking and run with
69 // `root -l -b -q TMVAMinimalClassification.C`.
70 // This will generate an output file "out.root" that can be viewed with
71 // `root -l -e 'TMVA::TMVAGui("out.root")'`.
72 //
73 void TMVAMinimalClassification()
74 {
75  TString outputFilename = "out.root";
76  TFile *outFile = new TFile(outputFilename, "RECREATE");
77 
78  // Data generatration
79  TTree *signalTree = genTree(1000, 0.0, 2.0, 100);
80  TTree *backgroundTree = genTree(1000, 1.0, 2.0, 101);
81 
82  TString factoryOptions = "AnalysisType=Classification";
83  TMVA::Factory factory{"", outFile, factoryOptions};
84 
85  TMVA::DataLoader dataloader{"dataset"};
86 
87  // Data specification
88  dataloader.AddVariable("x", 'D');
89  dataloader.AddVariable("y", 'D');
90 
91  dataloader.AddSignalTree(signalTree, 1.0);
92  dataloader.AddBackgroundTree(backgroundTree, 1.0);
93 
94  TCut signalCut = "";
95  TCut backgroundCut = "";
96  TString datasetOptions = "SplitMode=Random";
97  dataloader.PrepareTrainingAndTestTree(signalCut, backgroundCut, datasetOptions);
98 
99  // Method specification
100  TString methodOptions = "";
101  factory.BookMethod(&dataloader, TMVA::Types::kBDT, "BDT", methodOptions);
102 
103  // Training and Evaluation
104  factory.TrainAllMethods();
105  factory.TestAllMethods();
106  factory.EvaluateAllMethods();
107 
108  // Clean up
109  outFile->Close();
110 
111  delete outFile;
112  delete signalTree;
113  delete backgroundTree;
114 }
virtual Int_t Fill()
Fill all branches.
Definition: TTree.cxx:4374
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:47
Basic string class.
Definition: TString.h:131
int Int_t
Definition: RtypesCore.h:41
Double_t x[n]
Definition: legend1.C:17
This is the base class for the ROOT Random number generators.
Definition: TRandom.h:27
A specialized string object used for TTree selections.
Definition: TCut.h:25
unsigned int UInt_t
Definition: RtypesCore.h:42
This is the main MVA steering class.
Definition: Factory.h:81
double Double_t
Definition: RtypesCore.h:55
virtual void ResetBranchAddresses()
Tell all of our branches to drop their current objects and allocate new ones.
Definition: TTree.cxx:7714
Double_t y[n]
Definition: legend1.C:17
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1711
A TTree object has a header with a name and a title.
Definition: TTree.h:70
const Int_t n
Definition: legend1.C:16
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:917