Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVAMinimalClassification.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_ml
/// \notebook -nodraw
/// Minimal self-contained example for setting up TMVA with binary
/// classification.
///
/// This is intended as a simple foundation to build on. It assumes you are
/// already familiar with TMVA. As such, concepts like the Factory, the DataLoader
/// and others are not explained. For descriptions and tutorials, use the TMVA online manual
/// https://root.cern/manual/tmva/ or the more detailed examples provided with TMVA
/// (e.g. TMVAClassification.C) or the TMVA Users Guide
/// https://github.com/root-project/root/blob/master/documentation/tmva/UsersGuide/TMVAUsersGuide.pdf
///
/// Sets up a minimal binary classification example with two slightly overlapping
/// 2-D Gaussian distributions and trains a BDT classifier to discriminate the
/// data.
///
/// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
/// - Package : TMVA
/// - Root Macro: TMVAMinimalClassification.C
///
/// \macro_output
/// \macro_code
/// \author Kim Albertsson
25
26#include "TMVA/DataLoader.h"
27#include "TMVA/Factory.h"
28
29#include "TFile.h"
30#include "TString.h"
31#include "TTree.h"
32
33//
34// Helper function to generate 2-D gaussian data points and fill to a ROOT
35// TTree.
36//
37// Arguments:
38// nPoints Number of points to generate.
39// offset Mean of the generated numbers
40// scale Standard deviation of the generated numbers.
41// seed Seed for random number generator. Use `seed=0` for random
42// seed.
43// Returns a TTree ready to be used as input to TMVA.
44//
46{
47 TRandom rng(seed);
48 Double_t x = 0;
49 Double_t y = 0;
50
51 TTree *data = new TTree();
52 data->Branch("x", &x, "x/D");
53 data->Branch("y", &y, "y/D");
54
55 for (Int_t n = 0; n < nPoints; ++n) {
56 x = rng.Rndm() * scale;
57 y = offset + rng.Rndm() * scale;
58 data->Fill();
59 }
60
61 // Important: Disconnects the tree from the memory locations of x and y.
62 data->ResetBranchAddresses();
63 return data;
64}
65
66//
67// Minimal setup for performing binary classification in TMVA.
68//
69// Modify the setup to your liking and run with
70// `root -l -b -q TMVAMinimalClassification.C`.
71// This will generate an output file "out.root" that can be viewed with
72// `root -l -e 'TMVA::TMVAGui("out.root")'`.
73//
75{
76 TString outputFilename = "out.root";
77 TFile *outFile = new TFile(outputFilename, "RECREATE");
78
79 // Data generation
80 TTree *signalTree = genTree(1000, 0.0, 2.0, 100);
81 TTree *backgroundTree = genTree(1000, 1.0, 2.0, 101);
82
83 TString factoryOptions = "AnalysisType=Classification";
85
87
88 // Data specification
89 dataloader.AddVariable("x", 'D');
90 dataloader.AddVariable("y", 'D');
91
92 dataloader.AddSignalTree(signalTree, 1.0);
93 dataloader.AddBackgroundTree(backgroundTree, 1.0);
94
95 TCut signalCut = "";
97 TString datasetOptions = "SplitMode=Random";
98 dataloader.PrepareTrainingAndTestTree(signalCut, backgroundCut, datasetOptions);
99
100 // Method specification
102 factory.BookMethod(&dataloader, TMVA::Types::kBDT, "BDT", methodOptions);
103
104 // Training and Evaluation
105 factory.TrainAllMethods();
106 factory.TestAllMethods();
107 factory.EvaluateAllMethods();
108
109 // Clean up
110 outFile->Close();
111
112 delete outFile;
113 delete signalTree;
114 delete backgroundTree;
115}
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
double Double_t
Definition RtypesCore.h:59
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
A specialized string object used for TTree selections.
Definition TCut.h:25
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:131
This is the main MVA steering class.
Definition Factory.h:80
This is the base class for the ROOT Random number generators.
Definition TRandom.h:27
Basic string class.
Definition TString.h:139
A TTree represents a columnar dataset.
Definition TTree.h:79
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16