Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
classification.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva_envelope
3/// \notebook -nodraw
4
5/// \macro_output
6/// \macro_code
7
8
9#include "TMVA/Factory.h"
10#include "TMVA/DataLoader.h"
11#include "TMVA/Tools.h"
12#include "TMVA/Classification.h"
13
14void classification(UInt_t jobs = 4)
15{
17
18 TFile *input(0);
19 TString fname = "./tmva_class_example.root";
20 if (!gSystem->AccessPathName(fname)) {
21 input = TFile::Open(fname); // check if file in local directory exists
22 } else {
24 input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
25 }
26 if (!input) {
27 std::cout << "ERROR: could not open data file" << std::endl;
28 exit(1);
29 }
30
31 // Register the training and test trees
32
33 TTree *signalTree = (TTree *)input->Get("TreeS");
34 TTree *background = (TTree *)input->Get("TreeB");
35
36 TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
37 // If you wish to modify default settings
38 // (please check "src/Config.h" to see all available global options)
39 //
40 // (TMVA::gConfig().GetVariablePlotting()).fTimesRMS = 8.0;
41 // (TMVA::gConfig().GetIONames()).fWeightFileDir = "myWeightDirectory";
42
43 // Define the input variables that shall be used for the MVA training
44 // note that you may also use variable expressions, such as: "3*var1/var2*abs(var3)"
45 // [all types of expressions that can also be parsed by TTree::Draw( "expression" )]
46 dataloader->AddVariable("myvar1 := var1+var2", 'F');
47 dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
48 dataloader->AddVariable("var3", "Variable 3", "units", 'F');
49 dataloader->AddVariable("var4", "Variable 4", "units", 'F');
50
51 // You can add so-called "Spectator variables", which are not used in the MVA training,
52 // but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
53 // input variables, the response values of all trained MVAs, and the spectator variables
54
55 dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
56 dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
57
58 // global event weights per tree (see below for setting event-wise weights)
59 Double_t signalWeight = 1.0;
60 Double_t backgroundWeight = 1.0;
61
62 // You can add an arbitrary number of signal or background trees
63 dataloader->AddSignalTree(signalTree, signalWeight);
64 dataloader->AddBackgroundTree(background, backgroundWeight);
65
66 // Set individual event weights (the variables must exist in the original TTree)
67 // - for signal : `dataloader->SetSignalWeightExpression ("weight1*weight2");`
68 // - for background: `dataloader->SetBackgroundWeightExpression("weight1*weight2");`
69 dataloader->SetBackgroundWeightExpression("weight");
71 "", "", "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");
72
73 TFile *outputFile = TFile::Open("TMVAClass.root", "RECREATE");
74
76
77 cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
78 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
79 cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
80
81 cl->BookMethod(TMVA::Types::kBDT, "BDTB", "!H:!V:NTrees=2000:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
82
83 cl->BookMethod(TMVA::Types::kCuts, "Cuts", "!H:!V:FitMethod=MC:EffSel:SampleSize=200000:VarProp=FSmart");
84
85 cl->Evaluate(); // Train and Test all methods
86
87 auto &results = cl->GetResults();
88
89 TCanvas *c = new TCanvas(Form("ROC"));
90 c->SetTitle("ROC-Integral Curve");
91
92 auto mg = new TMultiGraph();
93 for (UInt_t i = 0; i < results.size(); i++) {
94 if (!results[i].IsCutsMethod()) {
95 auto roc = results[i].GetROCGraph();
96 roc->SetLineColorAlpha(i + 1, 0.1);
97 mg->Add(roc);
98 }
99 }
100 mg->Draw("AL");
101 mg->GetXaxis()->SetTitle(" Signal Efficiency ");
102 mg->GetYaxis()->SetTitle(" Background Rejection ");
103 c->BuildLegend(0.15, 0.15, 0.3, 0.3);
104 c->Draw();
105
106 outputFile->Close();
107 delete cl;
108}
#define c(i)
Definition RSha256.hxx:101
unsigned int UInt_t
Definition RtypesCore.h:46
double Double_t
Definition RtypesCore.h:59
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition TSystem.h:559
The Canvas class.
Definition TCanvas.h:23
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
static Bool_t SetCacheFileDir(ROOT::Internal::TStringView cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Definition TFile.h:324
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3997
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:879
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Add a signal tree to the data loader, with a global event weight applied to every event of that tree.
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts spectator variable in data set info
void SetBackgroundWeightExpression(const TString &variable)
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Add a background tree to the data loader, with a global event weight applied to every event of that tree.
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Book a machine learning method (type, title and option string) to be trained and tested by the envelope.
Definition Envelope.cxx:163
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Evaluate()
Perform training and testing of all booked ML methods.
static Tools & Instance()
Definition Tools.cxx:75
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:36
Basic string class.
Definition TString.h:136
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode (note the inverted convention: TRUE means the path is NOT accessible).
Definition TSystem.cxx:1294
A TTree represents a columnar dataset.
Definition TTree.h:79
void classification(UInt_t jobs=4)