Logo ROOT  
Reference Guide
efficienciesMulticlass.cxx
Go to the documentation of this file.
1// @(#)Root/tmva $Id$
2// Author: Kim Albertsson
3/**********************************************************************************
4 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
5 * Package: TMVAGUI *
6 * Web : http://tmva.sourceforge.net *
7 * *
8 * Description: *
9 * Implementation (see header for description) *
10 * *
11 * Authors : *
12 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
13 * *
14 * Copyright (c) 2005-2017: *
15 * CERN, Switzerland *
16 * LTU, Sweden *
17 * *
18 * Redistribution and use in source and binary forms, with or without *
19 * modification, are permitted according to the terms listed in LICENSE *
20 * (http://tmva.sourceforge.net/LICENSE) *
21 **********************************************************************************/
22
24
25// TMVA
26#include "TMVA/Config.h"
28#include "TMVA/tmvaglob.h"
29
30// ROOT
31#include "TControlBar.h"
32#include "TFile.h"
33#include "TGraph.h"
34#include "TH2F.h"
35#include "TIterator.h"
36#include "TKey.h"
37#include "TROOT.h"
38
39// STL
40#include <iostream>
41
42////////////////////////////////////////////////////////////////////////////////
43///
44/// Note: This file assumes a certain structure on the input file. The structure
45/// is as follows:
46///
47/// - dataset (TDirectory)
48/// - ... some variables, plots ...
49/// - Method_XXX (TDirectory)
50/// + XXX (TDirectory)
51/// * ... some plots ...
52/// * MVA_Method_XXX_Test_#classname#
53/// * MVA_Method_XXX_Train_#classname#
54/// * ... some plots ...
55/// - Method_YYY (TDirectory)
56/// + YYY (TDirectory)
57/// * ... some plots ...
58/// * MVA_Method_YYY_Test_#classname#
59/// * MVA_Method_YYY_Train_#classname#
60/// * ... some plots ...
61/// - TestTree (TTree)
62/// + ... data...
63/// - TrainTree (TTree)
64/// + ... data...
65///
66/// Keeping this in mind makes the main loop in getRocCurves easier to follow :)
67///
68
69////////////////////////////////////////////////////////////////////////////////
70/// Private class that simplify drawing plots combining information from
71/// several methods.
72///
73/// Each wrapper will manage a canvas and a legend and provide convenience
74/// functions to add data to these. It also provides a save function for
75/// saving an image representation to disk.
76///
77/// Feel free to extend this class as you see fit. It is intended as a
78/// convenience when showing multiclass roccurves, not a fully general tool.
79///
80/// Usage:
81/// auto p = new EfficiencyPlotWrapper(name, title, dataset, i):
82/// for (TGraph * g : listOfGraphs) {
83/// p->AddGraph(g);
84/// p->AddLegendEntry(methodName);
85/// }
86/// p->save();
87///
88
89class EfficiencyPlotWrapper {
90public:
91 TCanvas *fCanvas;
92 TLegend *fLegend;
93
94 TString fDataset;
95
96 Int_t fColor;
97
98 UInt_t fNumMethods;
99
100 EfficiencyPlotWrapper(TString name, TString title, TString dataset, size_t i);
101
102 Int_t addGraph(TGraph *graph);
103 void addLegendEntry(TString methodTitle, TGraph *graph);
104
105 void save();
106
107private:
108 Float_t fx0L;
109 Float_t fdxL;
110 Float_t fy0H;
111 Float_t fdyH;
112
113 TCanvas *newEfficiencyCanvas(TString name, TString title, size_t i);
114 TLegend *newEfficiencyLegend();
115};
116
117using classcanvasmap_t = std::map<TString, EfficiencyPlotWrapper *>;
118using roccurvelist_t = std::vector<std::tuple<TString, TString, TGraph *>>;
119
120// Constants
121const char *BUTTON_TYPE = "button";
122
123// Private functions
124namespace TMVA {
125std::vector<TString> getclassnames(TString dataset, TString fin);
126roccurvelist_t getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef);
128}
129
130////////////////////////////////////////////////////////////////////////////////
131/// Private (helper) functions - Implementation
132////////////////////////////////////////////////////////////////////////////////
133
134////////////////////////////////////////////////////////////////////////////////
135///
136
137std::vector<TString> TMVA::getclassnames(TString dataset, TString fin)
138{
140 TDirectory *dir = (TDirectory *)file->GetDirectory(dataset)->GetDirectory("InputVariables_Id");
141 if (!dir) {
142 std::cout << "Could not locate directory '" << dataset << "/InputVariables_Id' in file: " << fin << std::endl;
143 return {};
144 }
145
146 auto classnames = TMVA::TMVAGlob::GetClassNames(dir);
147 return classnames;
148}
149
150////////////////////////////////////////////////////////////////////////////////
151///
152
153roccurvelist_t TMVA::getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef)
154{
155 roccurvelist_t rocCurves;
156
157 TList methods;
158 UInt_t nm = TMVAGlob::GetListOfMethods(methods, binDir);
159 if (nm == 0) {
160 cout << "ups .. no methods found in to plot ROC curve for ... give up" << endl;
161 return rocCurves;
162 }
163 // TIter next(file->GetListOfKeys());
164 TIter next(&methods);
165
166 // Loop over all method categories
167 TKey *key;
168 while ((key = (TKey *)next())) {
169 TDirectory *mDir = (TDirectory *)key->ReadObj();
170 TList titles;
171 TMVAGlob::GetListOfTitles(mDir, titles);
172
173 // Loop over each method within a category
174 TIter nextTitle(&titles);
175 TKey *titkey;
176 TDirectory *titDir;
177 while ((titkey = TMVAGlob::NextKey(nextTitle, "TDirectory"))) {
178 titDir = (TDirectory *)titkey->ReadObj();
179 TString methodTitle;
180 TMVAGlob::GetMethodTitle(methodTitle, titDir);
181
182 // Loop through all plots for the method
183 TIter nextKey(titDir->GetListOfKeys());
184 TKey *hkey2;
185 while ((hkey2 = TMVAGlob::NextKey(nextKey, "TGraph"))) {
186
187 TGraph *h = (TGraph *)hkey2->ReadObj();
188 TString hname = h->GetName();
189 if (hname.Contains(graphNameRef) && hname.BeginsWith(methodPrefix) && !hname.Contains("Train")) {
190
191 // Extract classname from plot name
192 UInt_t index = hname.Last('_');
193 TString classname = hname(index + 1, hname.Length() - (index + 1));
194
195 rocCurves.push_back(std::make_tuple(methodTitle, classname, h));
196 }
197 }
198 }
199 }
200 return rocCurves;
201}
202
203////////////////////////////////////////////////////////////////////////////////
204/// Public functions - Implementation
205////////////////////////////////////////////////////////////////////////////////
206
207////////////////////////////////////////////////////////////////////////////////
208/// Private convenience function.
209///
210/// Adds a given a list of roc curves provided as n-tuple on the form
211/// (methodname, classname, graph)
212/// to the canvas corresponding to the classname.
213///
214
216{
217 for (auto &item : rocCurves) {
218
219 TString methodTitle = std::get<0>(item);
220 TString classname = std::get<1>(item);
221 TGraph *h = std::get<2>(item);
222
223 try {
224 EfficiencyPlotWrapper *plotWrapper = classCanvasMap.at(classname);
225 plotWrapper->addGraph(h);
226 plotWrapper->addLegendEntry(methodTitle, h);
227 } catch (const std::out_of_range &oor) {
228 cout << Form("ERROR: Class %s discovered among plots but was not found by TMVAMulticlassGui. Skipping.",
229 classname.Data())
230 << endl;
231 }
232 }
233}
234
235////////////////////////////////////////////////////////////////////////////////
236/// Entry point. Called from the TMVAMulticlassGui Buttons
237///
238/// \param dataset Dataset to operate on. Should be created by the TMVA Multiclass Factory.
239/// \param filename_input Name of the input file procuded by a TMVA Multiclass Factory.
240/// \param plotType Specified what kind of ROC curve to draw. Currently only rejB vs. effS is supported.
241
242void TMVA::efficienciesMulticlass1vsRest(TString dataset, TString filename_input, EEfficiencyPlotType plotType,
243 Bool_t useTMVAStyle)
244{
245 // set style and remove existing canvas'
246 TMVAGlob::Initialize(useTMVAStyle);
247 plotEfficienciesMulticlass1vsRest(dataset, plotType, filename_input);
248 return;
249}
250
251////////////////////////////////////////////////////////////////////////////////
252/// Work horse function. Will operate on the currently open file (opened by
253/// efficienciesMulticlass).
254///
255/// \param plotType See effcienciesMulticlass.
256/// \param binDir Directory in the file on which to operate.
257
258void TMVA::plotEfficienciesMulticlass1vsRest(TString dataset, EEfficiencyPlotType plotType, TString filename_input)
259{
260 // The current multiclass version implements only type 2 - rejB vs effS
261 if (plotType != EEfficiencyPlotType::kRejBvsEffS) {
262 std::cout << "For multiclass, only rejB vs effS is currently implemented.";
263 return;
264 }
265
266 // checks if filename_input is already open, and if not opens one
267 TFile *file = TMVAGlob::OpenFile(filename_input);
268 if (file == nullptr) {
269 std::cout << "ERROR: filename \"" << filename_input << "\" is not found.";
270 return;
271 }
272 auto binDir = file->GetDirectory(dataset.Data());
273
274 size_t iPlot = 0;
275 auto classnames = getclassnames(dataset, filename_input);
276 TString methodPrefix = "MVA_";
277 TString graphNameRef = "_rejBvsS_";
278
279 classcanvasmap_t classCanvasMap;
280 for (auto &classname : classnames) {
281 TString name = Form("roc_%s_vs_rest", classname.Data());
282 TString title = Form("ROC Curve %s vs rest", classname.Data());
283 EfficiencyPlotWrapper *plotWrapper = new EfficiencyPlotWrapper(name, title, dataset, iPlot++);
284 classCanvasMap.emplace(classname.Data(), plotWrapper);
285 }
286
287 roccurvelist_t rocCurves = getRocCurves(binDir, methodPrefix, graphNameRef);
288 plotEfficienciesMulticlass(rocCurves, classCanvasMap);
289
290 for (auto const &item : classCanvasMap) {
291 auto plotWrapper = item.second;
292 plotWrapper->save();
293 }
294}
295
296////////////////////////////////////////////////////////////////////////////////
297/// Entry point. Called from the TMVAMulticlassGui Buttons
298///
299/// \param dataset
300/// \param fin
301
303{
304 std::cout << "--- Running Roc1v1Gui for input file: " << fin << std::endl;
305
307
308 // create the control bar
309 TString title = "1v1 ROC curve comparison";
310 TControlBar *cbar = new TControlBar("vertical", title, 50, 50);
311
312 gDirectory->pwd();
313 auto classnames = getclassnames(dataset, fin);
314
315 // configure buttons
316 for (auto &classname : classnames) {
317 cbar->AddButton(Form("Class: %s", classname.Data()),
318 Form("TMVA::plotEfficienciesMulticlass1vs1(\"%s\", \"%s\", \"%s\")", dataset.Data(), fin.Data(),
319 classname.Data()),
321 }
322
323 cbar->SetTextColor("blue");
324 cbar->Show();
325
326 gROOT->SaveContext();
327}
328
329////////////////////////////////////////////////////////////////////////////////
330/// Generates K-1 plots comparing a given base class against all others (except
331/// itself). For each plot, the base class is considered signal and the other
332/// class is considered background.
333///
334/// Given 3 classes in the dataset and providing "Class 0" as the base class
335/// this would generate 2 plots comparing
336/// - Class 0 vs Class 1, and
337/// - Class 0 vs Class 2.
338/// For the "Class 0 vs Class 1" plot, events from Class 2 are ignored. For the
339/// "Class 0 vs Class 2" plot, events from Class 1 are ignored.
340///
341/// \param dataset
342/// \param fin
343/// \param baseClassname name of the class which will be considered signal
344
345void TMVA::plotEfficienciesMulticlass1vs1(TString dataset, TString fin, TString baseClassname)
346{
347
349
350 auto classnames = getclassnames(dataset, fin);
351 size_t iPlot = 0;
352
353 TString methodPrefix = "MVA_";
354 TString graphNameRef = Form("_1v1rejBvsS_%s_vs_", baseClassname.Data());
355
357 if (file == nullptr) {
358 std::cout << "ERROR: filename \"" << fin << "\" is not found.";
359 return;
360 }
361 auto binDir = file->GetDirectory(dataset.Data());
362
363 classcanvasmap_t classCanvasMap;
364 for (auto &classname : classnames) {
365
366 if (classname == baseClassname) {
367 continue;
368 }
369
370 TString name = Form("1v1roc_%s_vs_%s", baseClassname.Data(), classname.Data());
371 TString title = Form("ROC Curve %s (Sig) vs %s (Bkg)", baseClassname.Data(), classname.Data());
372 EfficiencyPlotWrapper *plotWrapper = new EfficiencyPlotWrapper(name, title, dataset, iPlot++);
373 classCanvasMap.emplace(classname.Data(), plotWrapper);
374 }
375
376 roccurvelist_t rocCurves = getRocCurves(binDir, methodPrefix, graphNameRef);
377 plotEfficienciesMulticlass(rocCurves, classCanvasMap);
378
379 for (auto const &item : classCanvasMap) {
380 auto plotWrapper = item.second;
381 plotWrapper->save();
382 }
383}
384
385////////////////////////////////////////////////////////////////////////////////
386/// Private class EfficiencyPlotWrapper - Implementation
387////////////////////////////////////////////////////////////////////////////////
388
389////////////////////////////////////////////////////////////////////////////////
390/// Constructs a new canvas + auxiliary data for showing an efficiency plot.
391///
392
393EfficiencyPlotWrapper::EfficiencyPlotWrapper(TString name, TString title, TString dataset, size_t i)
394{
395 // Legend extents (init before calling newEfficiencyLegend...)
396 fx0L = 0.107;
397 fy0H = 0.899;
398 fdxL = 0.457 - fx0L;
399 fdyH = 0.22;
400 fx0L = 0.15;
401 fy0H = 1 - fy0H + fdyH + 0.07;
402
403 fColor = 1;
404 fNumMethods = 0;
405
406 fDataset = dataset;
407
408 fCanvas = newEfficiencyCanvas(name, title, i);
409 fLegend = newEfficiencyLegend();
410}
411
412////////////////////////////////////////////////////////////////////////////////
413/// Adds a new graph to the plot. The added graph should contain a single ROC
414/// curve.
415///
416
417Int_t EfficiencyPlotWrapper::addGraph(TGraph *graph)
418{
419 graph->SetLineWidth(3);
420 graph->SetLineColor(fColor);
421 fColor++;
422 if (fColor == 5 || fColor == 10 || fColor == 11) {
423 fColor++;
424 }
425
426 fCanvas->cd();
427 graph->DrawClone("");
428 fCanvas->Update();
429
430 ++fNumMethods;
431
432 return fColor;
433}
434
435////////////////////////////////////////////////////////////////////////////////
436/// WARNING: Uses the current color, thus the correct call ordering is:
437/// plotWrapper->addGraph(...);
438/// plotWrapper->addLegendEntry(...);
439///
440
441void EfficiencyPlotWrapper::addLegendEntry(TString methodTitle, TGraph *graph)
442{
443 fLegend->AddEntry(graph, methodTitle, "l");
444
445 Float_t dyH_local = fdyH * (Float_t(TMath::Min((UInt_t)10, fNumMethods) - 3.0) / 4.0);
446 fLegend->SetY2(fy0H + dyH_local);
447
448 fLegend->Paint();
449 fCanvas->Update();
450}
451
452////////////////////////////////////////////////////////////////////////////////
453/// Helper to create new Canvas
454///
455/// \param name Name...
456/// \param title Title to be displayed on canvas
457/// \param i Index to offset a collection of canvases from each other
458///
459
460TCanvas *EfficiencyPlotWrapper::newEfficiencyCanvas(TString name, TString title, size_t i)
461{
462 TCanvas *c = new TCanvas(name, title, 200 + i * 50, 0 + i * 50, 650, 500);
463 // global style settings
464 c->SetGrid();
465 c->SetTicks();
466
467 // Frame
468 TString xtit = "Signal Efficiency";
469 TString ytit = "Background Rejection (1 - eff)";
470 Double_t x1 = 0.0;
471 Double_t x2 = 1.0;
472 Double_t y1 = 0.0;
473 Double_t y2 = 1.0;
474
475 TH2F *frame = new TH2F(Form("%s_%s", title.Data(), "frame"), title, 500, x1, x2, 500, y1, y2);
476 frame->GetXaxis()->SetTitle(xtit);
477 frame->GetYaxis()->SetTitle(ytit);
479 frame->DrawClone();
480
481 return c;
482}
483
484////////////////////////////////////////////////////////////////////////////////
485/// Helper to create new legend.
486
487TLegend *EfficiencyPlotWrapper::newEfficiencyLegend()
488{
489 TLegend *legend = new TLegend(fx0L, fy0H - fdyH, fx0L + fdxL, fy0H);
490 // legend->SetTextSize( 0.05 );
491 legend->SetHeader("MVA Method:");
492 legend->SetMargin(0.4);
493 legend->Draw("");
494
495 return legend;
496}
497
498////////////////////////////////////////////////////////////////////////////////
499/// Saves the current state of the plot to disk.
500///
501
502void EfficiencyPlotWrapper::save()
503{
504 TString fname = fDataset + "/plots/" + fCanvas->GetName();
505 TMVA::TMVAGlob::imgconv(fCanvas, fname);
506}
#define c(i)
Definition: RSha256.hxx:101
#define h(i)
Definition: RSha256.hxx:106
static const double x2[5]
static const double x1[5]
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
#define gDirectory
Definition: TDirectory.h:223
char name[80]
Definition: TGX11.cxx:109
#define gROOT
Definition: TROOT.h:415
char * Form(const char *fmt,...)
The Canvas class.
Definition: TCanvas.h:31
A Control Bar is a fully user configurable tool which provides fast access to frequently used operati...
Definition: TControlBar.h:22
void Show()
Show control bar.
void AddButton(TControlBarButton *button)
Add button.
void SetTextColor(const char *colorName)
Sets text color for control bar buttons, e.g.
Describe directory structure in memory.
Definition: TDirectory.h:34
virtual TList * GetListOfKeys() const
Definition: TDirectory.h:160
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:316
TAxis * GetYaxis()
Definition: TH1.h:317
2-D histogram with a float per channel (see TH1 documentation)}
Definition: TH2.h:251
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition: TKey.h:24
virtual TObject * ReadObj()
To read a TObject* from the file.
Definition: TKey.cxx:729
This class displays a legend box (TPaveText) containing several legend entries.
Definition: TLegend.h:23
virtual void SetHeader(const char *header="", Option_t *option="")
Sets the header, which is the "title" that appears at the top of the legend.
Definition: TLegend.cxx:1099
virtual void Draw(Option_t *option="")
Draw this legend with its current attributes.
Definition: TLegend.cxx:423
void SetMargin(Float_t margin)
Definition: TLegend.h:69
A doubly linked list.
Definition: TList.h:44
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition: TNamed.cxx:164
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad for instance with: gROOT->SetSelectedPad(gPad...
Definition: TObject.cxx:219
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
const char * Data() const
Definition: TString.h:364
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:892
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:610
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
std::vector< std::tuple< TString, TString, TGraph * > > roccurvelist_t
const char * BUTTON_TYPE
std::map< TString, EfficiencyPlotWrapper * > classcanvasmap_t
static constexpr double nm
UInt_t GetListOfTitles(TDirectory *rfdir, TList &titles)
Definition: tmvaglob.cxx:636
void Initialize(Bool_t useTMVAStyle=kTRUE)
Definition: tmvaglob.cxx:176
TKey * NextKey(TIter &keyIter, TString className)
Definition: tmvaglob.cxx:357
void GetMethodTitle(TString &name, TKey *ikey)
Definition: tmvaglob.cxx:341
TFile * OpenFile(const TString &fin)
Definition: tmvaglob.cxx:192
void SetFrameStyle(TH1 *frame, Float_t scale=1.0)
Definition: tmvaglob.cxx:77
UInt_t GetListOfMethods(TList &methods, TDirectory *dir=0)
Definition: tmvaglob.cxx:583
std::vector< TString > GetClassNames(TDirectory *dir)
Definition: tmvaglob.cxx:462
void imgconv(TCanvas *c, const TString &fname)
Definition: tmvaglob.cxx:212
create variable transformations
roccurvelist_t getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef)
void efficienciesMulticlass1vs1(TString dataset, TString fin)
std::vector< TString > getclassnames(TString dataset, TString fin)
void plotEfficienciesMulticlass1vs1(TString dataset, TString fin, TString baseClassname)
void plotEfficienciesMulticlass1vsRest(TString dataset, EEfficiencyPlotType plotType=EEfficiencyPlotType::kRejBvsEffS, TString filename_input="TMVAMulticlass.root")
void plotEfficienciesMulticlass(roccurvelist_t rocCurves, classcanvasmap_t classCanvasMap)
void efficienciesMulticlass1vsRest(TString dataset, TString filename_input="TMVAMulticlass.root", EEfficiencyPlotType plotType=EEfficiencyPlotType::kRejBvsEffS, Bool_t useTMVAStyle=kTRUE)
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:180
Definition: file.py:1
Definition: graph.py:1