Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
classification.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva_envelope
3/// \notebook -nodraw
4
5/// \macro_code
6/// \macro_output
7
8
9
10#include "TMVA/Factory.h"
11#include "TMVA/DataLoader.h"
12#include "TMVA/Tools.h"
13#include "TROOT.h"
14#include "TMVA/Classification.h"
15
17{
19
20 TFile *input(nullptr);
21 TString fname = gROOT->GetTutorialDir() + "/machine_learning/data/tmva_class_example.root";
23 input = TFile::Open(fname); // check if file in local directory exists
24 }
25 if (!input) {
26 std::cout << "ERROR: could not open data file" << fname << std::endl;
27 exit(1);
28 }
29
30 // Register the training and test trees
31
32 TTree *signalTree = (TTree *)input->Get("TreeS");
33 TTree *background = (TTree *)input->Get("TreeB");
34
36 // If you wish to modify default settings
37 // (please check "src/Config.h" to see all available global options)
38 //
39 // (TMVA::gConfig().GetVariablePlotting()).fTimesRMS = 8.0;
40 // (TMVA::gConfig().GetIONames()).fWeightFileDir = "myWeightDirectory";
41
42 // Define the input variables that shall be used for the MVA training
43 // note that you may also use variable expressions, such as: "3*var1/var2*abs(var3)"
44 // [all types of expressions that can also be parsed by TTree::Draw( "expression" )]
45 dataloader->AddVariable("myvar1 := var1+var2", 'F');
46 dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
47 dataloader->AddVariable("var3", "Variable 3", "units", 'F');
48 dataloader->AddVariable("var4", "Variable 4", "units", 'F');
49
50 // You can add so-called "Spectator variables", which are not used in the MVA training,
51 // but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
52 // input variables, the response values of all trained MVAs, and the spectator variables
53
54 dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
55 dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
56
57 // global event weights per tree (see below for setting event-wise weights)
60
61 // You can add an arbitrary number of signal or background trees
62 dataloader->AddSignalTree(signalTree, signalWeight);
63 dataloader->AddBackgroundTree(background, backgroundWeight);
64
65 // Set individual event weights (the variables must exist in the original TTree)
66 // - for signal : `dataloader->SetSignalWeightExpression ("weight1*weight2");`
67 // - for background: `dataloader->SetBackgroundWeightExpression("weight1*weight2");`
68 dataloader->SetBackgroundWeightExpression("weight");
69 dataloader->PrepareTrainingAndTestTree(
70 "", "", "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");
71
72 TFile *outputFile = TFile::Open("TMVAClass.root", "RECREATE");
73
75
76 cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
77 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
78 cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
79
80 cl->BookMethod(TMVA::Types::kBDT, "BDTB", "!H:!V:NTrees=2000:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
81
82 cl->BookMethod(TMVA::Types::kCuts, "Cuts", "!H:!V:FitMethod=MC:EffSel:SampleSize=200000:VarProp=FSmart");
83
84 cl->Evaluate(); // Train and Test all methods
85
86 auto &results = cl->GetResults();
87
88 TCanvas *c = new TCanvas(Form("ROC"));
89 c->SetTitle("ROC-Integral Curve");
90
91 auto mg = new TMultiGraph();
92 for (UInt_t i = 0; i < results.size(); i++) {
93 if (!results[i].IsCutsMethod()) {
94 auto roc = results[i].GetROCGraph();
95 roc->SetLineColorAlpha(i + 1, 0.1);
96 mg->Add(roc);
97 }
98 }
99 mg->Draw("AL");
100 mg->GetXaxis()->SetTitle(" Signal Efficiency ");
101 mg->GetYaxis()->SetTitle(" Background Rejection ");
102 c->BuildLegend(0.15, 0.15, 0.3, 0.3);
103 c->Draw();
104
105 outputFile->Close();
106 delete cl;
107}
#define c(i)
Definition RSha256.hxx:101
unsigned int UInt_t
Definition RtypesCore.h:46
double Double_t
Definition RtypesCore.h:59
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
#define gROOT
Definition TROOT.h:406
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
R__EXTERN TSystem * gSystem
Definition TSystem.h:572
The Canvas class.
Definition TCanvas.h:23
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-like hierarchy.
Definition TFile.h:131
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4131
virtual void BookMethod(TString methodname, TString methodtitle, TString options="")
Books a machine learning method (identified by name, title and option string) to be run by the algorithm.
Definition Envelope.cxx:163
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Evaluate()
Performs training and testing of all booked ML methods.
static Tools & Instance()
Definition Tools.cxx:71
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:34
Basic string class.
Definition TString.h:139
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode (note the inverted convention: TRUE means the file cannot be accessed).
Definition TSystem.cxx:1308
A TTree represents a columnar dataset.
Definition TTree.h:79
void classification(UInt_t jobs=4)