Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVAMinimalClassification.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_tmva
/// \notebook -nodraw
/// Minimal self-contained example for setting up TMVA with binary
/// classification.
///
/// This is intended as a simple foundation to build on. It assumes you are
/// already familiar with TMVA. As such, concepts like the Factory, the
/// DataLoader and others are not explained. For descriptions and tutorials see
/// the TMVA User's Guide (https://root.cern/root-user-guides-and-manuals under
/// TMVA) or the more detailed examples provided with TMVA, e.g.
/// TMVAClassification.C.
///
/// Sets up a minimal binary classification example with two overlapping
/// 2-D Gaussian distributions and trains a BDT classifier to discriminate the
/// data.
///
/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
/// - Package : TMVA
/// - Root Macro: TMVAMinimalClassification.C
///
/// \macro_output
/// \macro_code
/// \author Kim Albertsson
24
#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"

#include "TCut.h"
#include "TFile.h"
#include "TRandom.h"
#include "TString.h"
#include "TTree.h"

#include <iostream>
#include <memory>
31
32//
33// Helper function to generate 2-D gaussian data points and fill to a ROOT
34// TTree.
35//
36// Arguments:
37// nPoints Number of points to generate.
38// offset Mean of the generated numbers
39// scale Standard deviation of the generated numbers.
40// seed Seed for random number generator. Use `seed=0` for random
41// seed.
42// Returns a TTree ready to be used as input to TMVA.
43//
44TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
45{
46 TRandom rng(seed);
47 Double_t x = 0;
48 Double_t y = 0;
49
50 TTree *data = new TTree();
51 data->Branch("x", &x, "x/D");
52 data->Branch("y", &y, "y/D");
53
54 for (Int_t n = 0; n < nPoints; ++n) {
55 x = rng.Rndm() * scale;
56 y = offset + rng.Rndm() * scale;
57 data->Fill();
58 }
59
60 // Important: Disconnects the tree from the memory locations of x and y.
61 data->ResetBranchAddresses();
62 return data;
63}
64
65//
66// Minimal setup for performing binary classification in TMVA.
67//
68// Modify the setup to your liking and run with
69// `root -l -b -q TMVAMinimalClassification.C`.
70// This will generate an output file "out.root" that can be viewed with
71// `root -l -e 'TMVA::TMVAGui("out.root")'`.
72//
73void TMVAMinimalClassification()
74{
75 TString outputFilename = "out.root";
76 TFile *outFile = new TFile(outputFilename, "RECREATE");
77
78 // Data generation
79 TTree *signalTree = genTree(1000, 0.0, 2.0, 100);
80 TTree *backgroundTree = genTree(1000, 1.0, 2.0, 101);
81
82 TString factoryOptions = "AnalysisType=Classification";
83 TMVA::Factory factory{"", outFile, factoryOptions};
84
85 TMVA::DataLoader dataloader{"dataset"};
86
87 // Data specification
88 dataloader.AddVariable("x", 'D');
89 dataloader.AddVariable("y", 'D');
90
91 dataloader.AddSignalTree(signalTree, 1.0);
92 dataloader.AddBackgroundTree(backgroundTree, 1.0);
93
94 TCut signalCut = "";
95 TCut backgroundCut = "";
96 TString datasetOptions = "SplitMode=Random";
97 dataloader.PrepareTrainingAndTestTree(signalCut, backgroundCut, datasetOptions);
98
99 // Method specification
100 TString methodOptions = "";
101 factory.BookMethod(&dataloader, TMVA::Types::kBDT, "BDT", methodOptions);
102
103 // Training and Evaluation
104 factory.TrainAllMethods();
105 factory.TestAllMethods();
106 factory.EvaluateAllMethods();
107
108 // Clean up
109 outFile->Close();
110
111 delete outFile;
112 delete signalTree;
113 delete backgroundTree;
114}
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
double Double_t
Definition RtypesCore.h:59
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
A specialized string object used for TTree selections.
Definition TCut.h:25
A ROOT file is composed of a header, followed by consecutive data records (TKey instances) with a wel...
Definition TFile.h:53
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:928
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
This is the main MVA steering class.
Definition Factory.h:80
This is the base class for the ROOT Random number generators.
Definition TRandom.h:27
Basic string class.
Definition TString.h:139
A TTree represents a columnar dataset.
Definition TTree.h:79
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16