Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVAMulticlassApplication.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides a simple example of how to use the trained multiclass
5/// classifiers within an analysis module.
6/// - Project : TMVA - a Root-integrated toolkit for multivariate data analysis
7/// - Package : TMVA
8/// - Root Macro: TMVAMulticlassApplication
9///
10/// \macro_output
11/// \macro_code
12/// \author Andreas Hoecker
13
14
15#include <cstdlib>
16#include <iostream>
17#include <map>
18#include <string>
19#include <vector>
20
21#include "TFile.h"
22#include "TTree.h"
23#include "TString.h"
24#include "TSystem.h"
25#include "TROOT.h"
26#include "TStopwatch.h"
27#include "TH1F.h"
28
29#include "TMVA/Tools.h"
30#include "TMVA/Reader.h"
31
32using namespace TMVA;
33
34void TMVAMulticlassApplication( TString myMethodList = "" )
35{
36
38
39 //---------------------------------------------------------------
40 // Default MVA methods to be trained + tested
41 std::map<std::string,int> Use;
42 Use["MLP"] = 1;
43 Use["BDTG"] = 1;
44 Use["DL_CPU"] = 1;
45 Use["DL_GPU"] = 1;
46 Use["FDA_GA"] = 1;
47 Use["PDEFoam"] = 1;
48 //---------------------------------------------------------------
49
50 std::cout << std::endl;
51 std::cout << "==> Start TMVAMulticlassApp" << std::endl;
52 if (myMethodList != "") {
53 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
54
55 std::vector<TString> mlist = gTools().SplitString( myMethodList, ',' );
56 for (UInt_t i=0; i<mlist.size(); i++) {
57 std::string regMethod(mlist[i]);
58
59 if (Use.find(regMethod) == Use.end()) {
60 std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
61 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " " << std::endl;
62 std::cout << std::endl;
63 return;
64 }
65 Use[regMethod] = 1;
66 }
67 }
68
69
70 // create the Reader object
71 TMVA::Reader *reader = new TMVA::Reader( "!Color:!Silent" );
72
73 // create a set of variables and declare them to the reader
74 // - the variable names must corresponds in name and type to
75 // those given in the weight file(s) that you use
76 Float_t var1, var2, var3, var4;
77 reader->AddVariable( "var1", &var1 );
78 reader->AddVariable( "var2", &var2 );
79 reader->AddVariable( "var3", &var3 );
80 reader->AddVariable( "var4", &var4 );
81
82 // book the MVA methods
83 TString dir = "dataset/weights/";
84 TString prefix = "TMVAMulticlass";
85
86 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) {
87 if (it->second) {
88 TString methodName = TString(it->first) + TString(" method");
89 TString weightfile = dir + prefix + TString("_") + TString(it->first) + TString(".weights.xml");
90 // check if file existing (i.e. method has been trained)
91 if (!gSystem->AccessPathName( weightfile ))
92 // file exists
93 reader->BookMVA( methodName, weightfile );
94 else {
95 std::cout << "TMVAMultiClassApplication: Skip " << methodName << " since it has not been trained !" << std::endl;
96 it->second = 0;
97 }
98 }
99 }
100
101 // book output histograms
102 UInt_t nbin = 100;
103 TH1F *histMLP_signal(0), *histBDTG_signal(0), *histFDAGA_signal(0), *histPDEFoam_signal(0);
104 TH1F *histDLCPU_signal(0), *histDLGPU_signal(0);
105 if (Use["MLP"])
106 histMLP_signal = new TH1F( "MVA_MLP_signal", "MVA_MLP_signal", nbin, 0., 1.1 );
107 if (Use["BDTG"])
108 histBDTG_signal = new TH1F( "MVA_BDTG_signal", "MVA_BDTG_signal", nbin, 0., 1.1 );
109 if (Use["DL_CPU"])
110 histDLCPU_signal = new TH1F("MVA_DLCPU_signal", "MVA_DLCPU_signal", nbin, 0., 1.1);
111 if (Use["DL_GPU"])
112 histDLGPU_signal = new TH1F("MVA_DLGPU_signal", "MVA_DLGPU_signal", nbin, 0., 1.1);
113 if (Use["FDA_GA"])
114 histFDAGA_signal = new TH1F( "MVA_FDA_GA_signal", "MVA_FDA_GA_signal", nbin, 0., 1.1 );
115 if (Use["PDEFoam"])
116 histPDEFoam_signal = new TH1F( "MVA_PDEFoam_signal", "MVA_PDEFoam_signal", nbin, 0., 1.1 );
117
118
119 TFile *input(0);
120 TString fname = "./tmva_example_multiclass.root";
121 if (!gSystem->AccessPathName( fname )) {
122 input = TFile::Open( fname ); // check if file in local directory exists
123 }
124 else {
126 input = TFile::Open("http://root.cern/files/tmva_multiclass_example.root", "CACHEREAD");
127 }
128 if (!input) {
129 std::cout << "ERROR: could not open data file" << std::endl;
130 exit(1);
131 }
132 std::cout << "--- TMVAMulticlassApp : Using input file: " << input->GetName() << std::endl;
133
134 // prepare the tree
135 // - here the variable names have to corresponds to your tree
136 // - you can use the same variables as above which is slightly faster,
137 // but of course you can use different ones and copy the values inside the event loop
138
139 TTree* theTree = (TTree*)input->Get("TreeS");
140 std::cout << "--- Select signal sample" << std::endl;
141 theTree->SetBranchAddress( "var1", &var1 );
142 theTree->SetBranchAddress( "var2", &var2 );
143 theTree->SetBranchAddress( "var3", &var3 );
144 theTree->SetBranchAddress( "var4", &var4 );
145
146 std::cout << "--- Processing: " << theTree->GetEntries() << " events" << std::endl;
147 TStopwatch sw;
148 sw.Start();
149
150 for (Long64_t ievt=0; ievt<theTree->GetEntries();ievt++) {
151 if (ievt%1000 == 0){
152 std::cout << "--- ... Processing event: " << ievt << std::endl;
153 }
154 theTree->GetEntry(ievt);
155
156 if (Use["MLP"])
157 histMLP_signal->Fill((reader->EvaluateMulticlass( "MLP method" ))[0]);
158 if (Use["BDTG"])
159 histBDTG_signal->Fill((reader->EvaluateMulticlass( "BDTG method" ))[0]);
160 if (Use["DL_CPU"])
161 histDLCPU_signal->Fill((reader->EvaluateMulticlass("DL_CPU method"))[0]);
162 if (Use["DL_GPU"])
163 histDLGPU_signal->Fill((reader->EvaluateMulticlass("DL_GPU method"))[0]);
164 if (Use["FDA_GA"])
165 histFDAGA_signal->Fill((reader->EvaluateMulticlass( "FDA_GA method" ))[0]);
166 if (Use["PDEFoam"])
167 histPDEFoam_signal->Fill((reader->EvaluateMulticlass( "PDEFoam method" ))[0]);
168
169 }
170
171 // get elapsed time
172 sw.Stop();
173 std::cout << "--- End of event loop: "; sw.Print();
174
175 TFile *target = new TFile( "TMVAMulticlassApp.root","RECREATE" );
176 if (Use["MLP"])
177 histMLP_signal->Write();
178 if (Use["BDTG"])
179 histBDTG_signal->Write();
180 if (Use["DL_CPU"])
181 histDLCPU_signal->Write();
182 if (Use["DL_GPU"])
183 histDLGPU_signal->Write();
184 if (Use["FDA_GA"])
185 histFDAGA_signal->Write();
186 if (Use["PDEFoam"])
187 histPDEFoam_signal->Write();
188
189 target->Close();
190 std::cout << "--- Created root file: \"TMVMulticlassApp.root\" containing the MVA output histograms" << std::endl;
191
192 delete reader;
193
194 std::cout << "==> TMVAMulticlassApp is done!" << std::endl << std::endl;
195}
196
197int main( int argc, char** argv )
198{
199 // Select methods (don't look at this code - not of interest)
200 TString methodList;
201 for (int i=1; i<argc; i++) {
202 TString regMethod(argv[i]);
203 if(regMethod=="-b" || regMethod=="--batch") continue;
204 if (!methodList.IsNull()) methodList += TString(",");
205 methodList += regMethod;
206 }
207 TMVAMulticlassApplication(methodList);
208 return 0;
209}
int main()
Definition Prototype.cxx:12
unsigned int UInt_t
Definition RtypesCore.h:46
float Float_t
Definition RtypesCore.h:57
long long Long64_t
Definition RtypesCore.h:80
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t target
R__EXTERN TSystem * gSystem
Definition TSystem.h:555
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4089
static Bool_t SetCacheFileDir(std::string_view cacheDir, Bool_t operateDisconnected=kTRUE, Bool_t forceCacheread=kFALSE)
Sets the directory where to locally stage/cache remote files.
Definition TFile.cxx:4626
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:621
The Reader class serves to use the MVAs in a specific analysis context.
Definition Reader.h:64
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition Reader.cxx:368
const std::vector< Float_t > & EvaluateMulticlass(const TString &methodTag, Double_t aux=0)
evaluates MVA for given set of input variables
Definition Reader.cxx:630
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition Reader.cxx:303
static Tools & Instance()
Definition Tools.cxx:71
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition Tools.cxx:1199
Stopwatch class.
Definition TStopwatch.h:28
void Start(Bool_t reset=kTRUE)
Start the stopwatch.
void Stop()
Stop the stopwatch.
void Print(Option_t *option="") const override
Print the real and cpu time passed between the start and stop events.
Basic string class.
Definition TString.h:139
Bool_t IsNull() const
Definition TString.h:414
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1296
A TTree represents a columnar dataset.
Definition TTree.h:79
virtual Int_t GetEntry(Long64_t entry, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition TTree.cxx:5638
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=nullptr)
Change branch address, dealing with clone trees properly.
Definition TTree.cxx:8380
virtual Long64_t GetEntries() const
Definition TTree.h:463
create variable transformations
Tools & gTools()