Logo ROOT  
Reference Guide
Factory.cxx
Go to the documentation of this file.
1// @(#)Root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3// Updated by: Omar Zapata, Kim Albertsson
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Factory *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors : *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16 * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
17 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
20 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
21 * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
22 * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
23 * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
24 * Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
25 * *
26 * Copyright (c) 2005-2015: *
27 * CERN, Switzerland *
28 * U. of Victoria, Canada *
29 * MPI-K Heidelberg, Germany *
30 * U. of Bonn, Germany *
31 * UdeA/ITM, Colombia *
32 * U. of Florida, USA *
33 * *
34 * Redistribution and use in source and binary forms, with or without *
35 * modification, are permitted according to the terms listed in LICENSE *
36 * (http://tmva.sourceforge.net/LICENSE) *
37 **********************************************************************************/
38
39/*! \class TMVA::Factory
40\ingroup TMVA
41
42This is the main MVA steering class.
43It creates all MVA methods, and guides them through the training, testing and
44evaluation phases.
45*/
46
47#include "TMVA/Factory.h"
48
50#include "TMVA/Config.h"
51#include "TMVA/Configurable.h"
52#include "TMVA/Tools.h"
53#include "TMVA/Ranking.h"
54#include "TMVA/DataSet.h"
55#include "TMVA/IMethod.h"
56#include "TMVA/MethodBase.h"
58#include "TMVA/DataSetManager.h"
59#include "TMVA/DataSetInfo.h"
60#include "TMVA/DataLoader.h"
61#include "TMVA/MethodBoost.h"
62#include "TMVA/MethodCategory.h"
63#include "TMVA/ROCCalc.h"
64#include "TMVA/ROCCurve.h"
65#include "TMVA/MsgLogger.h"
66
67#include "TMVA/VariableInfo.h"
69
70#include "TMVA/Results.h"
74#include <list>
75#include <bitset>
76#include <set>
77
78#include "TMVA/Types.h"
79
80#include "TROOT.h"
81#include "TFile.h"
82#include "TTree.h"
83#include "TLeaf.h"
84#include "TEventList.h"
85#include "TH2.h"
86#include "TText.h"
87#include "TLegend.h"
88#include "TGraph.h"
89#include "TStyle.h"
90#include "TMatrixF.h"
91#include "TMatrixDSym.h"
92#include "TMultiGraph.h"
93#include "TPrincipal.h"
94#include "TMath.h"
95#include "TSystem.h"
96#include "TCanvas.h"
97
99//const Int_t MinNoTestEvents = 1;
100
102
103#define READXML kTRUE
104
105//number of bits for bitset
106#define VIBITS 32
107
108
109
110////////////////////////////////////////////////////////////////////////////////
111/// Standard constructor.
112///
113/// - jobname : this name will appear in all weight file names produced by the MVAs
114/// - theTargetFile : output ROOT file; the test tree and all evaluation plots
115/// will be stored here
116/// - theOption : option string; currently: "V" for verbose
///
/// Besides "V", theOption is parsed (ParseOptions) for every option declared in
/// the body: VerboseLevel (Debug/Verbose/Info), Color, Transformations,
/// Correlations, ROC, Silent, DrawProgressBar, ModelPersistence and
/// AnalysisType (Classification/Regression/Multiclass/Auto).
/// fSilentFile is initialised from (theTargetFile == nullptr), so constructing
/// with a null target file marks the factory as having no output file.
117
118TMVA::Factory::Factory( TString jobName, TFile* theTargetFile, TString theOption )
119: Configurable ( theOption ),
120 fTransformations ( "I" ),
121 fVerbose ( kFALSE ),
122 fVerboseLevel ( kINFO ),
123 fCorrelations ( kFALSE ),
124 fROC ( kTRUE ),
125 fSilentFile ( theTargetFile == nullptr ),
126 fJobName ( jobName ),
127 fAnalysisType ( Types::kClassification ),
128 fModelPersistence (kTRUE)
129{
130 fName = "Factory";
131 fgTargetFile = theTargetFile;
133
134 // render silent
135 if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput(); // make sure is silent if wanted to
136
137
138 // init configurable
139 SetConfigDescription( "Configuration options for Factory running" );
141
142 // histograms are not automatically associated with the current
143 // directory and hence don't go out of scope when closing the file
144 // TH1::AddDirectory(kFALSE);
 // Local option values (not Factory members): ParseOptions() below fills them
 // through the DeclareOptionRef bindings, after which they are copied into the
 // global gConfig() settings.
145 Bool_t silent = kFALSE;
146#ifdef WIN32
147 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these (would need different sequences..)
148 Bool_t color = kFALSE;
149 Bool_t drawProgressBar = kFALSE;
150#else
151 Bool_t color = !gROOT->IsBatch();
152 Bool_t drawProgressBar = kTRUE;
153#endif
154 DeclareOptionRef( fVerbose, "V", "Verbose flag" );
155 DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
156 AddPreDefVal(TString("Debug"));
157 AddPreDefVal(TString("Verbose"));
158 AddPreDefVal(TString("Info"));
159 DeclareOptionRef( color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)" );
160 DeclareOptionRef( fTransformations, "Transformations", "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations" );
161 DeclareOptionRef( fCorrelations, "Correlations", "boolean to show correlation in output" );
162 DeclareOptionRef( fROC, "ROC", "boolean to show ROC in output" );
163 DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
164 DeclareOptionRef( drawProgressBar,
165 "DrawProgressBar", "Draw progress bar to display training, testing and evaluation schedule (default: True)" );
 // NOTE(review): the next two lines are the trailing arguments of an option
 // declaration whose first line is not shown in this listing (presumably
 // DeclareOptionRef( fModelPersistence, ... ) -- confirm against the original.
167 "ModelPersistence",
168 "Option to save the trained model in xml file or using serialization");
169
170 TString analysisType("Auto");
171 DeclareOptionRef( analysisType,
172 "AnalysisType", "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
173 AddPreDefVal(TString("Classification"));
174 AddPreDefVal(TString("Regression"));
175 AddPreDefVal(TString("Multiclass"));
176 AddPreDefVal(TString("Auto"));
177
178 ParseOptions();
180
 // Logger threshold. NOTE(review): fVerboseLevel is assigned "Info" in its
 // DeclareOptionRef call above, so unless "VerboseLevel" is given explicitly
 // the "Info" line below resets the level to kINFO and undoes the kVERBOSE set
 // by the "V" flag on the first line. Confirm whether "V" alone should win.
181 if (Verbose()) fLogger->SetMinType( kVERBOSE );
182 if (fVerboseLevel.CompareTo("Debug") ==0) fLogger->SetMinType( kDEBUG );
183 if (fVerboseLevel.CompareTo("Verbose") ==0) fLogger->SetMinType( kVERBOSE );
184 if (fVerboseLevel.CompareTo("Info") ==0) fLogger->SetMinType( kINFO );
185
186 // global settings
187 gConfig().SetUseColor( color );
188 gConfig().SetSilent( silent );
189 gConfig().SetDrawProgressBar( drawProgressBar );
190
 // Case-insensitive mapping of the AnalysisType option; "auto" leaves the type
 // as kNoAnalysisType so that BookMethod infers it from the dataset.
191 analysisType.ToLower();
192 if ( analysisType == "classification" ) fAnalysisType = Types::kClassification;
193 else if( analysisType == "regression" ) fAnalysisType = Types::kRegression;
194 else if( analysisType == "multiclass" ) fAnalysisType = Types::kMulticlass;
195 else if( analysisType == "auto" ) fAnalysisType = Types::kNoAnalysisType;
196
 // The welcome banner is not printed by this constructor.
197// Greetings();
198}
199
200////////////////////////////////////////////////////////////////////////////////
201/// Constructor.
///
/// Variant without an output file: fgTargetFile is set to nullptr and
/// fSilentFile to kTRUE. It declares and parses the same options as the
/// constructor taking a TFile*, and prints the welcome banner via Greetings()
/// at the end.
///
/// NOTE(review): the signature line of this constructor is not present in this
/// listing; the initializer list uses jobName and theOption -- confirm.
/// NOTE(review): the body duplicates the TFile* constructor almost line for
/// line; a shared private init helper would keep the two from drifting apart.
202
204: Configurable ( theOption ),
205 fTransformations ( "I" ),
206 fVerbose ( kFALSE ),
207 fCorrelations ( kFALSE ),
208 fROC ( kTRUE ),
209 fSilentFile ( kTRUE ),
210 fJobName ( jobName ),
211 fAnalysisType ( Types::kClassification ),
212 fModelPersistence (kTRUE)
213{
214 fName = "Factory";
215 fgTargetFile = nullptr;
217
218
219 // render silent
220 if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput(); // make sure is silent if wanted to
221
222
223 // init configurable
224 SetConfigDescription( "Configuration options for Factory running" );
226
227 // histograms are not automatically associated with the current
228 // directory and hence don't go out of scope when closing the file
 // Local option values (not Factory members): ParseOptions() below fills them
 // through the DeclareOptionRef bindings, after which they are copied into the
 // global gConfig() settings.
230 Bool_t silent = kFALSE;
231#ifdef WIN32
232 // under Windows, switch progress bar and color off by default, as the typical windows shell doesn't handle these (would need different sequences..)
233 Bool_t color = kFALSE;
234 Bool_t drawProgressBar = kFALSE;
235#else
236 Bool_t color = !gROOT->IsBatch();
237 Bool_t drawProgressBar = kTRUE;
238#endif
239 DeclareOptionRef( fVerbose, "V", "Verbose flag" );
240 DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
241 AddPreDefVal(TString("Debug"));
242 AddPreDefVal(TString("Verbose"));
243 AddPreDefVal(TString("Info"));
244 DeclareOptionRef( color, "Color", "Flag for coloured screen output (default: True, if in batch mode: False)" );
245 DeclareOptionRef( fTransformations, "Transformations", "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations" );
246 DeclareOptionRef( fCorrelations, "Correlations", "boolean to show correlation in output" );
247 DeclareOptionRef( fROC, "ROC", "boolean to show ROC in output" );
248 DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
249 DeclareOptionRef( drawProgressBar,
250 "DrawProgressBar", "Draw progress bar to display training, testing and evaluation schedule (default: True)" );
 // NOTE(review): the next two lines are the trailing arguments of an option
 // declaration whose first line is not shown in this listing (presumably
 // DeclareOptionRef( fModelPersistence, ... ) -- confirm against the original.
252 "ModelPersistence",
253 "Option to save the trained model in xml file or using serialization");
254
255 TString analysisType("Auto");
256 DeclareOptionRef( analysisType,
257 "AnalysisType", "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
258 AddPreDefVal(TString("Classification"));
259 AddPreDefVal(TString("Regression"));
260 AddPreDefVal(TString("Multiclass"));
261 AddPreDefVal(TString("Auto"));
262
263 ParseOptions();
265
 // Logger threshold. NOTE(review): fVerboseLevel is assigned "Info" in its
 // DeclareOptionRef call above, so unless "VerboseLevel" is given explicitly
 // the "Info" line below resets the level to kINFO and undoes the kVERBOSE set
 // by the "V" flag on the first line. Confirm whether "V" alone should win.
266 if (Verbose()) fLogger->SetMinType( kVERBOSE );
267 if (fVerboseLevel.CompareTo("Debug") ==0) fLogger->SetMinType( kDEBUG );
268 if (fVerboseLevel.CompareTo("Verbose") ==0) fLogger->SetMinType( kVERBOSE );
269 if (fVerboseLevel.CompareTo("Info") ==0) fLogger->SetMinType( kINFO );
270
271 // global settings
272 gConfig().SetUseColor( color );
273 gConfig().SetSilent( silent );
274 gConfig().SetDrawProgressBar( drawProgressBar );
275
 // Case-insensitive mapping of the AnalysisType option; "auto" leaves the type
 // as kNoAnalysisType so that BookMethod infers it from the dataset.
276 analysisType.ToLower();
277 if ( analysisType == "classification" ) fAnalysisType = Types::kClassification;
278 else if( analysisType == "regression" ) fAnalysisType = Types::kRegression;
279 else if( analysisType == "multiclass" ) fAnalysisType = Types::kMulticlass;
280 else if( analysisType == "auto" ) fAnalysisType = Types::kNoAnalysisType;
281
 // Unlike the TFile* constructor, this one prints the welcome banner.
282 Greetings();
283}
284
285////////////////////////////////////////////////////////////////////////////////
286/// Print welcome message.
287/// Options are: kLogoWelcomeMsg, kIsometricWelcomeMsg, kLeanWelcomeMsg
///
/// The listed names are, presumably, the welcome-message styles understood by
/// gTools().TMVAWelcomeMessage; this function always requests kLogoWelcomeMsg,
/// then prints the TMVA version message followed by Endl.
288
290{
292 gTools().TMVAWelcomeMessage( Log(), gTools().kLogoWelcomeMsg );
293 gTools().TMVAVersionMessage( Log() ); Log() << Endl;
294}
295
296////////////////////////////////////////////////////////////////////////////////
297/// Destructor.
298
300{
301 std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
302 for (;trfIt != fDefaultTrfs.end(); ++trfIt) delete (*trfIt);
303
304 this->DeleteAllMethods();
305
306
307 // problem with call of REGISTER_METHOD macro ...
308 // ClassifierFactory::DestroyInstance();
309 // Types::DestroyInstance();
310 //Tools::DestroyInstance();
311 //Config::DestroyInstance();
312}
313
314////////////////////////////////////////////////////////////////////////////////
315/// Delete methods.
316
318{
319 std::map<TString,MVector*>::iterator itrMap;
320
321 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
322 {
323 MVector *methods=itrMap->second;
324 // delete methods
325 MVector::iterator itrMethod = methods->begin();
326 for (; itrMethod != methods->end(); ++itrMethod) {
327 Log() << kDEBUG << "Delete method: " << (*itrMethod)->GetName() << Endl;
328 delete (*itrMethod);
329 }
330 methods->clear();
331 delete methods;
332 }
333}
334
335////////////////////////////////////////////////////////////////////////////////
336
338{
339 fVerbose = v;
340}
341
342////////////////////////////////////////////////////////////////////////////////
343/// Book a classifier or regression method.
344
345TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption )
346{
347 if(fModelPersistence) gSystem->MakeDirectory(loader->GetName());//creating directory for DataLoader output
348
349 TString datasetname=loader->GetName();
350
351 if( fAnalysisType == Types::kNoAnalysisType ){
352 if( loader->GetDataSetInfo().GetNClasses()==2
353 && loader->GetDataSetInfo().GetClassInfo("Signal") != NULL
354 && loader->GetDataSetInfo().GetClassInfo("Background") != NULL
355 ){
356 fAnalysisType = Types::kClassification; // default is classification
357 } else if( loader->GetDataSetInfo().GetNClasses() >= 2 ){
358 fAnalysisType = Types::kMulticlass; // if two classes, but not named "Signal" and "Background"
359 } else
360 Log() << kFATAL << "No analysis type for " << loader->GetDataSetInfo().GetNClasses() << " classes and "
361 << loader->GetDataSetInfo().GetNTargets() << " regression targets." << Endl;
362 }
363
364 // booking via name; the names are translated into enums and the
365 // corresponding overloaded BookMethod is called
366
367 if(fMethodsMap.find(datasetname)!=fMethodsMap.end())
368 {
369 if (GetMethod( datasetname,methodTitle ) != 0) {
370 Log() << kFATAL << "Booking failed since method with title <"
371 << methodTitle <<"> already exists "<< "in with DataSet Name <"<< loader->GetName()<<"> "
372 << Endl;
373 }
374 }
375
376
377 Log() << kHEADER << "Booking method: " << gTools().Color("bold") << methodTitle
378 // << gTools().Color("reset")<<" DataSet Name: "<<gTools().Color("bold")<<loader->GetName()
379 << gTools().Color("reset") << Endl << Endl;
380
381 // interpret option string with respect to a request for boosting (i.e., BostNum > 0)
382 Int_t boostNum = 0;
383 TMVA::Configurable* conf = new TMVA::Configurable( theOption );
384 conf->DeclareOptionRef( boostNum = 0, "Boost_num",
385 "Number of times the classifier will be boosted" );
386 conf->ParseOptions();
387 delete conf;
388 // this is name of weight file directory
389 TString fileDir;
390 if(fModelPersistence)
391 {
392 // find prefix in fWeightFileDir;
394 fileDir = prefix;
395 if (!prefix.IsNull())
396 if (fileDir[fileDir.Length()-1] != '/') fileDir += "/";
397 fileDir += loader->GetName();
398 fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
399 }
400 // initialize methods
401 IMethod* im;
402 if (!boostNum) {
403 im = ClassifierFactory::Instance().Create(theMethodName.Data(), fJobName, methodTitle,
404 loader->GetDataSetInfo(), theOption);
405 }
406 else {
407 // boosted classifier, requires a specific definition, making it transparent for the user
408 Log() << kDEBUG <<"Boost Number is " << boostNum << " > 0: train boosted classifier" << Endl;
409 im = ClassifierFactory::Instance().Create("Boost", fJobName, methodTitle, loader->GetDataSetInfo(), theOption);
410 MethodBoost *methBoost = dynamic_cast<MethodBoost *>(im); // DSMTEST divided into two lines
411 if (!methBoost) { // DSMTEST
412 Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
413 return nullptr;
414 }
415 if (fModelPersistence) methBoost->SetWeightFileDir(fileDir);
416 methBoost->SetModelPersistence(fModelPersistence);
417 methBoost->SetBoostedMethodName(theMethodName); // DSMTEST divided into two lines
418 methBoost->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
419 methBoost->SetFile(fgTargetFile);
420 methBoost->SetSilentFile(IsSilentFile());
421 }
422
423 MethodBase *method = dynamic_cast<MethodBase*>(im);
424 if (method==0) return 0; // could not create method
425
426 // set fDataSetManager if MethodCategory (to enable Category to create datasetinfo objects) // DSMTEST
427 if (method->GetMethodType() == Types::kCategory) { // DSMTEST
428 MethodCategory *methCat = (dynamic_cast<MethodCategory*>(im)); // DSMTEST
429 if (!methCat) {// DSMTEST
430 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST
431 return nullptr;
432 }
433 if(fModelPersistence) methCat->SetWeightFileDir(fileDir);
434 methCat->SetModelPersistence(fModelPersistence);
435 methCat->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
436 methCat->SetFile(fgTargetFile);
437 methCat->SetSilentFile(IsSilentFile());
438 } // DSMTEST
439
440
441 if (!method->HasAnalysisType( fAnalysisType,
442 loader->GetDataSetInfo().GetNClasses(),
443 loader->GetDataSetInfo().GetNTargets() )) {
444 Log() << kWARNING << "Method " << method->GetMethodTypeName() << " is not capable of handling " ;
445 if (fAnalysisType == Types::kRegression) {
446 Log() << "regression with " << loader->GetDataSetInfo().GetNTargets() << " targets." << Endl;
447 }
448 else if (fAnalysisType == Types::kMulticlass ) {
449 Log() << "multiclass classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
450 }
451 else {
452 Log() << "classification with " << loader->GetDataSetInfo().GetNClasses() << " classes." << Endl;
453 }
454 return 0;
455 }
456
457 if(fModelPersistence) method->SetWeightFileDir(fileDir);
458 method->SetModelPersistence(fModelPersistence);
459 method->SetAnalysisType( fAnalysisType );
460 method->SetupMethod();
461 method->ParseOptions();
462 method->ProcessSetup();
463 method->SetFile(fgTargetFile);
464 method->SetSilentFile(IsSilentFile());
465
466 // check-for-unused-options is performed; may be overridden by derived classes
467 method->CheckSetup();
468
469 if(fMethodsMap.find(datasetname)==fMethodsMap.end())
470 {
471 MVector *mvector=new MVector;
472 fMethodsMap[datasetname]=mvector;
473 }
474 fMethodsMap[datasetname]->push_back( method );
475 return method;
476}
477
478////////////////////////////////////////////////////////////////////////////////
479/// Books MVA method. The option configuration string is custom for each MVA
480/// the TString field "theNameAppendix" serves to define (and distinguish)
481/// several instances of a given MVA, eg, when one wants to compare the
482/// performance of various configurations
483
485{
486 return BookMethod(loader, Types::Instance().GetMethodName( theMethod ), methodTitle, theOption );
487}
488
489////////////////////////////////////////////////////////////////////////////////
490/// Adds an already constructed method to be managed by this factory.
491///
492/// \note Private.
493/// \note Know what you are doing when using this method. The method that you
494/// are loading could be trained already.
495///
///
/// - loader     : DataLoader whose dataset info and name are used
/// - methodType : method type enum, translated to its registered name
/// - weightfile : passed to ClassifierFactory::Create; the weight file is then
///                read by ReadStateFromFile()
///
/// Returns the booked method (stored in fMethodsMap under the loader's name) or
/// nullptr if the created object is not a MethodBase.
496
498{
499 TString datasetname = loader->GetName();
500 std::string methodTypeName = std::string(Types::Instance().GetMethodName(methodType).Data());
501 DataSetInfo &dsi = loader->GetDataSetInfo();
502
503 IMethod *im = ClassifierFactory::Instance().Create(methodTypeName, dsi, weightfile );
504 MethodBase *method = (dynamic_cast<MethodBase*>(im));
505
506 if (method == nullptr) return nullptr;
507
 // NOTE(review): category methods are only reported with kERROR; booking then
 // continues exactly as for any other method.
508 if( method->GetMethodType() == Types::kCategory ){
509 Log() << kERROR << "Cannot handle category methods for now." << Endl;
510 }
511
512 TString fileDir;
513 if(fModelPersistence) {
514 // find prefix in fWeightFileDir;
 // NOTE(review): the declaration of `prefix` is not shown in this listing;
 // presumably the configured weight-file directory prefix -- confirm.
516 fileDir = prefix;
517 if (!prefix.IsNull())
518 if (fileDir[fileDir.Length() - 1] != '/')
519 fileDir += "/";
 // NOTE(review): this assignment discards the prefix built just above, so a
 // configured prefix is ignored here. BookMethod builds the same path with
 // `fileDir += loader->GetName();` -- this should most likely be `+=` too.
520 fileDir=loader->GetName();
521 fileDir+="/"+gConfig().GetIONames().fWeightFileDir;
522 }
523
524 if(fModelPersistence) method->SetWeightFileDir(fileDir);
525 method->SetModelPersistence(fModelPersistence);
526 method->SetAnalysisType( fAnalysisType );
527 method->SetupMethod();
528 method->SetFile(fgTargetFile);
529 method->SetSilentFile(IsSilentFile());
530
532
533 // read weight file
534 method->ReadStateFromFile();
535
536 //method->CheckSetup();
537
 // The title comes from the loaded method itself, so the duplicate check can
 // only run after the weight file has been read.
538 TString methodTitle = method->GetName();
539 if (HasMethod(datasetname, methodTitle) != 0) {
540 Log() << kFATAL << "Booking failed since method with title <"
541 << methodTitle <<"> already exists "<< "in with DataSet Name <"<< loader->GetName()<<"> "
542 << Endl;
543 }
544
545 Log() << kINFO << "Booked classifier \"" << method->GetMethodName()
546 << "\" of type: \"" << method->GetMethodTypeName() << "\"" << Endl;
547
548 if(fMethodsMap.count(datasetname) == 0) {
549 MVector *mvector = new MVector;
550 fMethodsMap[datasetname] = mvector;
551 }
552
553 fMethodsMap[datasetname]->push_back( method );
554
555 return method;
556}
557
558////////////////////////////////////////////////////////////////////////////////
559/// Returns pointer to MVA that corresponds to given method title.
560
561TMVA::IMethod* TMVA::Factory::GetMethod(const TString& datasetname, const TString &methodTitle ) const
562{
563 if(fMethodsMap.find(datasetname)==fMethodsMap.end()) return 0;
564
565 MVector *methods=fMethodsMap.find(datasetname)->second;
566
567 MVector::const_iterator itrMethod;
568 //
569 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
570 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
571 if ( (mva->GetMethodName())==methodTitle ) return mva;
572 }
573 return 0;
574}
575
576////////////////////////////////////////////////////////////////////////////////
577/// Checks whether a given method name is defined for a given dataset.
578
579Bool_t TMVA::Factory::HasMethod(const TString& datasetname, const TString &methodTitle ) const
580{
581 if(fMethodsMap.find(datasetname)==fMethodsMap.end()) return 0;
582
583 std::string methodName = methodTitle.Data();
584 auto isEqualToMethodName = [&methodName](TMVA::IMethod * m) {
585 return ( 0 == methodName.compare( m->GetName() ) );
586 };
587
588 TMVA::Factory::MVector * methods = this->fMethodsMap.at(datasetname);
589 Bool_t isMethodNameExisting = std::any_of( methods->begin(), methods->end(), isEqualToMethodName);
590
591 return isMethodNameExisting;
592}
593
594////////////////////////////////////////////////////////////////////////////////
/// Write dataset-level information for fDataSetInfo into the output file.
///
/// Creates a directory named after the dataset under RootBaseDir() and writes
/// correlation-matrix histograms into it: one per class for multiclass,
/// otherwise signal, background and regression. It then evaluates the
/// transformations listed in fTransformations; per the comment in the loop,
/// their variable distributions are saved to the file.
/// Returns immediately if the directory already exists, so the information is
/// written only once per dataset.
595
597{
598 RootBaseDir()->cd();
599
600 if(!RootBaseDir()->GetDirectory(fDataSetInfo.GetName())) RootBaseDir()->mkdir(fDataSetInfo.GetName());
601 else return; //loader is now in the output file, we dont need to save again
602
603 RootBaseDir()->cd(fDataSetInfo.GetName());
604 fDataSetInfo.GetDataSet(); // builds dataset (including calculation of correlation matrix)
605
606
607 // correlation matrix of the default DS
608 const TMatrixD* m(0);
609 const TH2* h(0);
610
611 if(fAnalysisType == Types::kMulticlass){
612 for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses() ; cls++) {
613 m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
614 h = fDataSetInfo.CreateCorrelationMatrixHist(m, TString("CorrelationMatrix")+fDataSetInfo.GetClassInfo(cls)->GetName(),
615 TString("Correlation Matrix (")+ fDataSetInfo.GetClassInfo(cls)->GetName() +TString(")"));
616 if (h!=0) {
617 h->Write();
618 delete h;
619 }
620 }
621 }
622 else{
 // All three names are requested; a null histogram (presumably for a class
 // that does not exist in this dataset) is simply not written.
623 m = fDataSetInfo.CorrelationMatrix( "Signal" );
624 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixS", "Correlation Matrix (signal)");
625 if (h!=0) {
626 h->Write();
627 delete h;
628 }
629
630 m = fDataSetInfo.CorrelationMatrix( "Background" );
631 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrixB", "Correlation Matrix (background)");
632 if (h!=0) {
633 h->Write();
634 delete h;
635 }
636
637 m = fDataSetInfo.CorrelationMatrix( "Regression" );
638 h = fDataSetInfo.CreateCorrelationMatrixHist(m, "CorrelationMatrix", "Correlation Matrix");
639 if (h!=0) {
640 h->Write();
641 delete h;
642 }
643 }
644
645 // some default transformations to evaluate
646 // NOTE: all transformations are destroyed after this test
 // NOTE(review): the "I" assigned on the next line is overwritten by
 // fTransformations right after, so it has no effect.
647 TString processTrfs = "I"; //"I;N;D;P;U;G,D;"
648
649 // plus some user defined transformations
650 processTrfs = fTransformations;
651
652 // remove any trace of identity transform - if given (avoid to apply it twice)
 // NOTE(review): nothing is removed below; the loop only remembers the handler
 // whose definition starts with 'I' so that its variable ranking is printed.
653 std::vector<TMVA::TransformationHandler*> trfs;
654 TransformationHandler* identityTrHandler = 0;
655
656 std::vector<TString> trfsDef = gTools().SplitString(processTrfs,';');
657 std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
658 for (; trfsDefIt!=trfsDef.end(); ++trfsDefIt) {
659 trfs.push_back(new TMVA::TransformationHandler(fDataSetInfo, "Factory"));
660 TString trfS = (*trfsDefIt);
661
662 //Log() << kINFO << Endl;
663 Log() << kDEBUG << "current transformation string: '" << trfS.Data() << "'" << Endl;
 // NOTE(review): the call taking the following arguments is not shown in this
 // listing; presumably it builds the transformations described by trfS into
 // the handler trfs.back() -- confirm.
665 fDataSetInfo,
666 *(trfs.back()),
667 Log() );
668
669 if (trfS.BeginsWith('I')) identityTrHandler = trfs.back();
670 }
671
672 const std::vector<Event*>& inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
673
674 // apply all transformations
675 std::vector<TMVA::TransformationHandler*>::iterator trfIt = trfs.begin();
676
677 for (;trfIt != trfs.end(); ++trfIt) {
678 // setting a Root dir causes the variables distributions to be saved to the root file
679 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName()));// every dataloader have its own dir
680 (*trfIt)->CalcTransformations(inputEvents);
681 }
682 if(identityTrHandler) identityTrHandler->PrintVariableRanking();
683
684 // clean up
 // The handlers were allocated with new in the loop above and are owned here.
685 for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt) delete *trfIt;
686}
687
688////////////////////////////////////////////////////////////////////////////////
689/// Iterates through all booked methods and sees if they use parameter tuning and if so..
690/// does just that i.e. calls "Method::Train()" for different parameter settings and
691/// keeps in mind the "optimal one"... and that's the one that will later on be used
692/// in the main training loop.
///
/// - fomType, fitType : passed unchanged to MethodBase::OptimizeTuningParameters()
///                      for every booked method of every dataset
///
/// Returns the tuned-parameter map of the last method that was optimized.
/// NOTE(review): TunedParameters is reassigned for each method, so the results
/// of all earlier methods (including those of other datasets) are discarded.
/// Confirm whether callers expect a merged or per-method result.
693
694std::map<TString,Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TString fitType)
695{
696
697 std::map<TString,MVector*>::iterator itrMap;
698 std::map<TString,Double_t> TunedParameters;
699 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
700 {
701 MVector *methods=itrMap->second;
702
703 MVector::iterator itrMethod;
704
705 // iterate over methods and optimize
706 for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
 // Booked methods are expected to be MethodBase instances; a failed cast is
 // reported as kFATAL.
708 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
709 if (!mva) {
710 Log() << kFATAL << "Dynamic cast to MethodBase failed" <<Endl;
711 return TunedParameters;
712 }
713
 // NOTE(review): the condition guarding the following warning is not shown in
 // this listing; judging by the message it skips methods whose training sample
 // has fewer than MinNoTrainingEvents events -- confirm.
715 Log() << kWARNING << "Method " << mva->GetMethodName()
716 << " not trained (training tree has less entries ["
717 << mva->Data()->GetNTrainingEvents()
718 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
719 continue;
720 }
721
722 Log() << kINFO << "Optimize method: " << mva->GetMethodName() << " for "
723 << (fAnalysisType == Types::kRegression ? "Regression" :
724 (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl;
725
726 TunedParameters = mva->OptimizeTuningParameters(fomType,fitType);
727 Log() << kINFO << "Optimization of tuning parameters finished for Method:"<<mva->GetName() << Endl;
728 }
729 }
730
731 return TunedParameters;
732
733}
734
735////////////////////////////////////////////////////////////////////////////////
736/// Private method to generate a ROCCurve instance for a given method.
737/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
738/// understands.
739///
740/// \note You own the retured pointer.
741///
742
745{
746 return GetROC((TString)loader->GetName(), theMethodName, iClass, type);
747}
748
749////////////////////////////////////////////////////////////////////////////////
750/// Private method to generate a ROCCurve instance for a given method.
751/// Handles the conversion from TMVA ResultSet to a format the ROCCurve class
752/// understands.
753///
754/// \note You own the retured pointer.
755///
756
758{
759 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
760 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
761 return nullptr;
762 }
763
764 if (!this->HasMethod(datasetname, theMethodName)) {
765 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
766 << Endl;
767 return nullptr;
768 }
769
770 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
771 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
772 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.")
773 << Endl;
774 return nullptr;
775 }
776
777 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(this->GetMethod(datasetname, theMethodName));
778 TMVA::DataSet *dataset = method->Data();
779 dataset->SetCurrentType(type);
780 TMVA::Results *results = dataset->GetResults(theMethodName, type, this->fAnalysisType);
781
782 UInt_t nClasses = method->DataInfo().GetNClasses();
783 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
784 Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
785 iClass, nClasses)
786 << Endl;
787 return nullptr;
788 }
789
790 TMVA::ROCCurve *rocCurve = nullptr;
791 if (this->fAnalysisType == Types::kClassification) {
792
793 std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
794 std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
795 std::vector<Float_t> mvaResWeights;
796
797 auto eventCollection = dataset->GetEventCollection(type);
798 mvaResWeights.reserve(eventCollection.size());
799 for (auto ev : eventCollection) {
800 mvaResWeights.push_back(ev->GetWeight());
801 }
802
803 rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
804
805 } else if (this->fAnalysisType == Types::kMulticlass) {
806 std::vector<Float_t> mvaRes;
807 std::vector<Bool_t> mvaResTypes;
808 std::vector<Float_t> mvaResWeights;
809
810 std::vector<std::vector<Float_t>> *rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();
811
812 // Vector transpose due to values being stored as
813 // [ [0, 1, 2], [0, 1, 2], ... ]
814 // in ResultsMulticlass::GetValueVector.
815 mvaRes.reserve(rawMvaRes->size());
816 for (auto item : *rawMvaRes) {
817 mvaRes.push_back(item[iClass]);
818 }
819
820 auto eventCollection = dataset->GetEventCollection(type);
821 mvaResTypes.reserve(eventCollection.size());
822 mvaResWeights.reserve(eventCollection.size());
823 for (auto ev : eventCollection) {
824 mvaResTypes.push_back(ev->GetClass() == iClass);
825 mvaResWeights.push_back(ev->GetWeight());
826 }
827
828 rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
829 }
830
831 return rocCurve;
832}
833
834////////////////////////////////////////////////////////////////////////////////
835/// Calculate the integral of the ROC curve, also known as the area under curve
836/// (AUC), for a given method.
837///
838/// Argument iClass specifies the class to generate the ROC curve in a
839/// multiclass setting. It is ignored for binary classification.
840///
841
843{
844 return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass);
845}
846
847////////////////////////////////////////////////////////////////////////////////
848/// Calculate the integral of the ROC curve, also known as the area under curve
849/// (AUC), for a given method.
850///
851/// Argument iClass specifies the class to generate the ROC curve in a
852/// multiclass setting. It is ignored for binary classification.
853///
///
/// Returns 0 (after logging an error) when the dataset or method is unknown or
/// the analysis type is neither classification nor multiclass. NOTE(review):
/// 0 is also a value a real AUC can take, so callers cannot tell the two apart.
/// GetROC is called without a tree type, so GetROC's default tree type applies.
854
856{
857 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
858 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
859 return 0;
860 }
861
862 if ( ! this->HasMethod(datasetname, theMethodName) ) {
863 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
864 return 0;
865 }
866
867 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
868 if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
869 Log() << kERROR << Form("Can only generate ROC integral for analysis type kClassification. and kMulticlass.")
870 << Endl;
871 return 0;
872 }
873
 // GetROC returns an owned pointer (or nullptr on error); it is deleted below.
874 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
875 if (!rocCurve) {
876 Log() << kFATAL << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ",
877 theMethodName.Data(), datasetname.Data())
878 << Endl;
879 return 0;
880 }
881
 // NOTE(review): `npoints` is declared on a line not shown in this listing;
 // presumably the number of points used to sample the ROC curve -- confirm.
883 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
884 delete rocCurve;
885
886 return rocIntegral;
887}
888
889////////////////////////////////////////////////////////////////////////////////
890/// Argument iClass specifies the class to generate the ROC curve in a
891/// multiclass setting. It is ignored for binary classification.
892///
893/// Returns a ROC graph for a given method, or nullptr on error.
894///
895/// Note: Evaluation of the given method must have been run prior to ROC
896/// generation through Factory::EvaluateAllMetods.
897///
898/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
899/// and the others considered background. This is ok in binary classification
900/// but in in multi class classification, the ROC surface is an N dimensional
901/// shape, where N is number of classes - 1.
902
903TGraph* TMVA::Factory::GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles, UInt_t iClass)
904{
905 return GetROCCurve( (TString)loader->GetName(), theMethodName, setTitles, iClass );
906}
907
908////////////////////////////////////////////////////////////////////////////////
909/// Argument iClass specifies the class to generate the ROC curve in a
910/// multiclass setting. It is ignored for binary classification.
911///
912/// Returns a ROC graph for a given method, or nullptr on error.
913///
914/// Note: Evaluation of the given method must have been run prior to ROC
915/// generation through Factory::EvaluateAllMetods.
916///
917/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
918/// and the others considered background. This is ok in binary classification
919/// but in in multi class classification, the ROC surface is an N dimensional
920/// shape, where N is number of classes - 1.
921
922TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles, UInt_t iClass)
923{
924 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
925 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
926 return nullptr;
927 }
928
929 if ( ! this->HasMethod(datasetname, theMethodName) ) {
930 Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
931 return nullptr;
932 }
933
934 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
935 if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
936 Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.") << Endl;
937 return nullptr;
938 }
939
940 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
941 TGraph *graph = nullptr;
942
943 if ( ! rocCurve ) {
944 Log() << kFATAL << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
945 return nullptr;
946 }
947
948 graph = (TGraph *)rocCurve->GetROCCurve()->Clone();
949 delete rocCurve;
950
951 if(setTitles) {
952 graph->GetYaxis()->SetTitle("Background rejection (Specificity)");
953 graph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
954 graph->SetTitle(Form("Signal efficiency vs. Background rejection (%s)", theMethodName.Data()));
955 }
956
957 return graph;
958}
959
960////////////////////////////////////////////////////////////////////////////////
961/// Generate a collection of graphs, for all methods for a given class. Suitable
962/// for comparing method performance.
963///
964/// Argument iClass specifies the class to generate the ROC curve in a
965/// multiclass setting. It is ignored for binary classification.
966///
967/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
968/// and the others considered background. This is ok in binary classification
969/// but in in multi class classification, the ROC surface is an N dimensional
970/// shape, where N is number of classes - 1.
971
973{
974 return GetROCCurveAsMultiGraph((TString)loader->GetName(), iClass);
975}
976
977////////////////////////////////////////////////////////////////////////////////
978/// Generate a collection of graphs, for all methods for a given class. Suitable
979/// for comparing method performance.
980///
981/// Argument iClass specifies the class to generate the ROC curve in a
982/// multiclass setting. It is ignored for binary classification.
983///
984/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
985/// and the others considered background. This is ok in binary classification
986/// but in in multi class classification, the ROC surface is an N dimensional
987/// shape, where N is number of classes - 1.
988
990{
991 UInt_t line_color = 1;
992
993 TMultiGraph *multigraph = new TMultiGraph();
994
995 MVector *methods = fMethodsMap[datasetname.Data()];
996 for (auto * method_raw : *methods) {
997 TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>(method_raw);
998 if (method == nullptr) { continue; }
999
1000 TString methodName = method->GetMethodName();
1001 UInt_t nClasses = method->DataInfo().GetNClasses();
1002
1003 if ( this->fAnalysisType == Types::kMulticlass && iClass >= nClasses ) {
1004 Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass, nClasses) << Endl;
1005 continue;
1006 }
1007
1008 TString className = method->DataInfo().GetClassInfo(iClass)->GetName();
1009
1010 TGraph *graph = this->GetROCCurve(datasetname, methodName, false, iClass);
1011 graph->SetTitle(methodName);
1012
1013 graph->SetLineWidth(2);
1014 graph->SetLineColor(line_color++);
1015 graph->SetFillColor(10);
1016
1017 multigraph->Add(graph);
1018 }
1019
1020 if ( multigraph->GetListOfGraphs() == nullptr ) {
1021 Log() << kERROR << Form("No metohds have class %i defined.", iClass) << Endl;
1022 return nullptr;
1023 }
1024
1025 return multigraph;
1026}
1027
1028////////////////////////////////////////////////////////////////////////////////
1029/// Draws ROC curves for all methods booked with the factory for a given class
1030/// onto a canvas.
1031///
1032/// Argument iClass specifies the class to generate the ROC curve in a
1033/// multiclass setting. It is ignored for binary classification.
1034///
1035/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1036/// and the others considered background. This is ok in binary classification
1037/// but in in multi class classification, the ROC surface is an N dimensional
1038/// shape, where N is number of classes - 1.
1039
1041{
1042 return GetROCCurve((TString)loader->GetName(), iClass);
1043}
1044
1045////////////////////////////////////////////////////////////////////////////////
1046/// Draws ROC curves for all methods booked with the factory for a given class.
1047///
1048/// Argument iClass specifies the class to generate the ROC curve in a
1049/// multiclass setting. It is ignored for binary classification.
1050///
1051/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
1052/// and the others considered background. This is ok in binary classification
1053/// but in in multi class classification, the ROC surface is an N dimensional
1054/// shape, where N is number of classes - 1.
1055
1057{
1058 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
1059 Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
1060 return 0;
1061 }
1062
1063 TString name = Form("ROCCurve %s class %i", datasetname.Data(), iClass);
1064 TCanvas *canvas = new TCanvas(name, "ROC Curve", 200, 10, 700, 500);
1065 canvas->SetGrid();
1066
1067 TMultiGraph *multigraph = this->GetROCCurveAsMultiGraph(datasetname, iClass);
1068
1069 if ( multigraph ) {
1070 multigraph->Draw("AL");
1071
1072 multigraph->GetYaxis()->SetTitle("Background rejection (Specificity)");
1073 multigraph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
1074
1075 TString titleString = Form("Signal efficiency vs. Background rejection");
1076 if (this->fAnalysisType == Types::kMulticlass) {
1077 titleString = Form("%s (Class=%i)", titleString.Data(), iClass);
1078 }
1079
1080 // Workaround for TMultigraph not drawing title correctly.
1081 multigraph->GetHistogram()->SetTitle( titleString );
1082 multigraph->SetTitle( titleString );
1083
1084 canvas->BuildLegend(0.15, 0.15, 0.35, 0.3, "MVA Method");
1085 }
1086
1087 return canvas;
1088}
1089
1090////////////////////////////////////////////////////////////////////////////////
1091/// Iterates through all booked methods and calls their training methods.
1092
1094{
// Train every method booked in every dataset. For each dataset this:
//   1. checks the input (data present, targets/classes defined) and trains
//      each method with enough training events,
//   2. prints method-specific variable rankings (non-regression analyses only),
//   3. saves each method's training history unless running with a silent file,
//   4. if model persistence is enabled, deletes every trained method and
//      recreates it from its weight file, so testing uses the persisted state.
1095 Log() << kHEADER << gTools().Color("bold") << "Train all methods" << gTools().Color("reset") << Endl;
1096 // iterates over all MVAs that have been booked, and calls their training methods
1097
1098
1099 // don't do anything if no method booked
1100 if (fMethodsMap.empty()) {
1101 Log() << kINFO << "...nothing found to train" << Endl;
1102 return;
1103 }
1104
1105 // here the training starts
1106 //Log() << kINFO << " " << Endl;
1107 Log() << kDEBUG << "Train all methods for "
1108 << (fAnalysisType == Types::kRegression ? "Regression" :
1109 (fAnalysisType == Types::kMulticlass ? "Multiclass" : "Classification") ) << " ..." << Endl;
1110
1111 std::map<TString,MVector*>::iterator itrMap;
1112
// Outer loop: one pass per dataset; all steps below operate on that dataset's methods.
1113 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1114 {
1115 MVector *methods=itrMap->second;
1116 MVector::iterator itrMethod;
1117
1118 // iterate over methods and train
1119 for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
// NOTE(review): original line 1120 is absent from this listing. Confirm what it contained.
1121 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
1122
1123 if(mva==0) continue;
1124
// Fatal input checks: no data, regression without targets, or
// classification with fewer than two classes.
1125 if(mva->DataInfo().GetDataSetManager()->DataInput().GetEntries() <=1) { // 0 entries --> 0 events, 1 entry --> dynamical dataset (or one entry)
1126 Log() << kFATAL << "No input data for the training provided!" << Endl;
1127 }
1128
1129 if(fAnalysisType == Types::kRegression && mva->DataInfo().GetNTargets() < 1 )
1130 Log() << kFATAL << "You want to do regression training without specifying a target." << Endl;
1131 else if( (fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification)
1132 && mva->DataInfo().GetNClasses() < 2 )
1133 Log() << kFATAL << "You want to do classification training, but specified less than two classes." << Endl;
1134
1135 // first print some information about the default dataset
1136 if(!IsSilentFile()) WriteDataInformation(mva->fDataSetInfo);
1137
1138
// NOTE(review): original line 1139 is absent from this listing. Judging by the
// `continue;` and the otherwise unmatched `}` at 1145, it presumably opened an
// `if` that skips methods with fewer than MinNoTrainingEvents training events.
// Confirm against the full source.
1140 Log() << kWARNING << "Method " << mva->GetMethodName()
1141 << " not trained (training tree has less entries ["
1142 << mva->Data()->GetNTrainingEvents()
1143 << "] than required [" << MinNoTrainingEvents << "]" << Endl;
1144 continue;
1145 }
1146
1147 Log() << kHEADER << "Train method: " << mva->GetMethodName() << " for "
1148 << (fAnalysisType == Types::kRegression ? "Regression" :
1149 (fAnalysisType == Types::kMulticlass ? "Multiclass classification" : "Classification")) << Endl << Endl;
1150 mva->TrainMethod();
1151 Log() << kHEADER << "Training finished" << Endl << Endl;
1152 }
1153
// Variable ranking is only produced for non-regression analyses, and only for
// methods that had enough training events.
1154 if (fAnalysisType != Types::kRegression) {
1155
1156 // variable ranking
1157 //Log() << Endl;
1158 Log() << kINFO << "Ranking input variables (method specific)..." << Endl;
1159 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1160 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
1161 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
1162
1163 // create and print ranking
1164 const Ranking* ranking = (*itrMethod)->CreateRanking();
1165 if (ranking != 0) ranking->Print();
1166 else Log() << kINFO << "No variable ranking supplied by classifier: "
1167 << dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
1168 }
1169 }
1170 }
1171
1172 // save training history in case we are not in the silent mode
1173 if (!IsSilentFile()) {
1174 for (UInt_t i=0; i<methods->size(); i++) {
1175 MethodBase* m = dynamic_cast<MethodBase*>((*methods)[i]);
1176 if(m==0) continue;
1177 m->BaseDir()->cd();
1178 m->fTrainHistory.SaveHistory(m->GetMethodName());
1179 }
1180 }
1181
1182 // delete all methods and recreate them from weight file - this ensures that the application
1183 // of the methods (in TMVAClassificationApplication) is consistent with the results obtained
1184 // in the testing
1185 //Log() << Endl;
1186 if (fModelPersistence) {
1187
1188 Log() << kHEADER << "=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;
1189
1190 if(!IsSilentFile())RootBaseDir()->cd();
1191
1192 // iterate through all booked methods
1193 for (UInt_t i=0; i<methods->size(); i++) {
1194
1195 MethodBase *m = dynamic_cast<MethodBase *>((*methods)[i]);
1196 if (m == nullptr)
1197 continue;
1198
// Remember everything needed to rebuild the method before deleting it.
1199 TMVA::Types::EMVA methodType = m->GetMethodType();
1200 TString weightfile = m->GetWeightFileName();
1201
1202 // decide if .txt or .xml file should be read:
1203 if (READXML)
1204 weightfile.ReplaceAll(".txt", ".xml");
1205
// NOTE(review): `dataSetInfo` is used after `delete m`, so this assumes the
// DataSetInfo is not owned by the method. TODO confirm.
1206 DataSetInfo &dataSetInfo = m->DataInfo();
1207 TString testvarName = m->GetTestvarName();
1208 delete m; // itrMethod[i];
1209
1210 // recreate
// NOTE(review): the result of Create/dynamic_cast is dereferenced below
// without a null check. Consider guarding against nullptr.
1211 m = dynamic_cast<MethodBase *>(ClassifierFactory::Instance().Create(
1212 Types::Instance().GetMethodName(methodType).Data(), dataSetInfo, weightfile));
1213 if (m->GetMethodType() == Types::kCategory) {
1214 MethodCategory *methCat = (dynamic_cast<MethodCategory *>(m));
1215 if (!methCat)
1216 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
1217 else
1218 methCat->fDataSetManager = m->DataInfo().GetDataSetManager();
1219 }
1220 // ToDo, Do we need to fill the DataSetManager of MethodBoost here too?
1221
// Reconfigure the recreated method like the factory's other methods and load
// its trained state from the weight file.
1222 TString wfileDir = m->DataInfo().GetName();
1223 wfileDir += "/" + gConfig().GetIONames().fWeightFileDir;
1224 m->SetWeightFileDir(wfileDir);
1225 m->SetModelPersistence(fModelPersistence);
1226 m->SetSilentFile(IsSilentFile());
1227 m->SetAnalysisType(fAnalysisType);
1228 m->SetupMethod();
1229 m->ReadStateFromFile();
1230 m->SetTestvarName(testvarName);
1231
1232 // replace trained method by newly created one (from weight file) in methods vector
1233 (*methods)[i] = m;
1234 }
1235 }
1236 }
1237}
1238
1239////////////////////////////////////////////////////////////////////////////////
1240/// Evaluates all booked methods on the testing data and adds the output to the
1241/// Results in the corresponding DataSet.
1242///
1243
1245{
// Run every booked method, in every dataset, over its testing sample by calling
// MethodBase::AddOutput(Types::kTesting, <method's analysis type>).
1246 Log() << kHEADER << gTools().Color("bold") << "Test all methods" << gTools().Color("reset") << Endl;
1247
1248 // don't do anything if no method booked
1249 if (fMethodsMap.empty()) {
1250 Log() << kINFO << "...nothing found to test" << Endl;
1251 return;
1252 }
1253 std::map<TString,MVector*>::iterator itrMap;
1254
1255 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1256 {
1257 MVector *methods=itrMap->second;
1258 MVector::iterator itrMethod;
1259
1260 // iterate over methods and test
1261 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
// NOTE(review): original line 1262 is absent from this listing. Confirm what it contained.
1263 MethodBase *mva = dynamic_cast<MethodBase *>(*itrMethod);
1264 if (mva == 0)
1265 continue;
// The analysis type is taken from the method itself, not from the factory.
1266 Types::EAnalysisType analysisType = mva->GetAnalysisType();
1267 Log() << kHEADER << "Test method: " << mva->GetMethodName() << " for "
1268 << (analysisType == Types::kRegression
1269 ? "Regression"
1270 : (analysisType == Types::kMulticlass ? "Multiclass classification" : "Classification"))
1271 << " performance" << Endl << Endl;
1272 mva->AddOutput(Types::kTesting, analysisType);
1273 }
1274 }
1275}
1276
1277////////////////////////////////////////////////////////////////////////////////
1278
1279void TMVA::Factory::MakeClass(const TString& datasetname , const TString& methodTitle ) const
1280{
1281 if (methodTitle != "") {
1282 IMethod* method = GetMethod(datasetname, methodTitle);
1283 if (method) method->MakeClass();
1284 else {
1285 Log() << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle
1286 << "\" in list" << Endl;
1287 }
1288 }
1289 else {
1290
1291 // no classifier specified, print all help messages
1292 MVector *methods=fMethodsMap.find(datasetname)->second;
1293 MVector::const_iterator itrMethod;
1294 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1295 MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
1296 if(method==0) continue;
1297 Log() << kINFO << "Make response class for classifier: " << method->GetMethodName() << Endl;
1298 method->MakeClass();
1299 }
1300 }
1301}
1302
1303////////////////////////////////////////////////////////////////////////////////
1304/// Print predefined help message of classifier.
1305/// Iterate over methods and test.
1306
1307void TMVA::Factory::PrintHelpMessage(const TString& datasetname , const TString& methodTitle ) const
1308{
1309 if (methodTitle != "") {
1310 IMethod* method = GetMethod(datasetname , methodTitle );
1311 if (method) method->PrintHelpMessage();
1312 else {
1313 Log() << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle
1314 << "\" in list" << Endl;
1315 }
1316 }
1317 else {
1318
1319 // no classifier specified, print all help messages
1320 MVector *methods=fMethodsMap.find(datasetname)->second;
1321 MVector::const_iterator itrMethod ;
1322 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1323 MethodBase* method = dynamic_cast<MethodBase*>(*itrMethod);
1324 if(method==0) continue;
1325 Log() << kINFO << "Print help message for classifier: " << method->GetMethodName() << Endl;
1326 method->PrintHelpMessage();
1327 }
1328 }
1329}
1330
1331////////////////////////////////////////////////////////////////////////////////
1332/// Iterates over all MVA input variables and evaluates them.
1333
1335{
// Book one "Variable" pseudo-method per input variable of the loader's dataset.
// EvaluateAllMethods groups methods whose type name contains "Variable"
// separately from the real methods.
// Note: options.Contains("V") matches any 'V' in `options`, so e.g. "!V" also
// appends ":V".
1336 Log() << kINFO << "Evaluating all variables..." << Endl;
// NOTE(review): original lines 1337 and 1340 are absent from this listing.
// `s` used below is presumably declared on line 1340 (likely derived from
// variable i). Confirm against the full source.
1338
1339 for (UInt_t i=0; i<loader->GetDataSetInfo().GetNVariables(); i++) {
1341 if (options.Contains("V")) s += ":V";
1342 this->BookMethod(loader, "Variable", s );
1343 }
1344}
1345
1346////////////////////////////////////////////////////////////////////////////////
1347/// Iterates over all MVAs that have been booked, and calls their evaluation methods.
1348
1350{
1351 Log() << kHEADER << gTools().Color("bold") << "Evaluate all methods" << gTools().Color("reset") << Endl;
1352
1353 // don't do anything if no method booked
1354 if (fMethodsMap.empty()) {
1355 Log() << kINFO << "...nothing found to evaluate" << Endl;
1356 return;
1357 }
1358 std::map<TString,MVector*>::iterator itrMap;
1359
1360 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1361 {
1362 MVector *methods=itrMap->second;
1363
1364 // -----------------------------------------------------------------------
1365 // First part of evaluation process
1366 // --> compute efficiencies, and other separation estimators
1367 // -----------------------------------------------------------------------
1368
1369 // although equal, we now want to separate the output for the variables
1370 // and the real methods
1371 Int_t isel; // will be 0 for a Method; 1 for a Variable
1372 Int_t nmeth_used[2] = {0,0}; // 0 Method; 1 Variable
1373
1374 std::vector<std::vector<TString> > mname(2);
1375 std::vector<std::vector<Double_t> > sig(2), sep(2), roc(2);
1376 std::vector<std::vector<Double_t> > eff01(2), eff10(2), eff30(2), effArea(2);
1377 std::vector<std::vector<Double_t> > eff01err(2), eff10err(2), eff30err(2);
1378 std::vector<std::vector<Double_t> > trainEff01(2), trainEff10(2), trainEff30(2);
1379
1380 std::vector<std::vector<Float_t> > multiclass_testEff;
1381 std::vector<std::vector<Float_t> > multiclass_trainEff;
1382 std::vector<std::vector<Float_t> > multiclass_testPur;
1383 std::vector<std::vector<Float_t> > multiclass_trainPur;
1384
1385 std::vector<std::vector<Float_t> > train_history;
1386
1387 // Multiclass confusion matrices.
1388 std::vector<TMatrixD> multiclass_trainConfusionEffB01;
1389 std::vector<TMatrixD> multiclass_trainConfusionEffB10;
1390 std::vector<TMatrixD> multiclass_trainConfusionEffB30;
1391 std::vector<TMatrixD> multiclass_testConfusionEffB01;
1392 std::vector<TMatrixD> multiclass_testConfusionEffB10;
1393 std::vector<TMatrixD> multiclass_testConfusionEffB30;
1394
1395 std::vector<std::vector<Double_t> > biastrain(1); // "bias" of the regression on the training data
1396 std::vector<std::vector<Double_t> > biastest(1); // "bias" of the regression on test data
1397 std::vector<std::vector<Double_t> > devtrain(1); // "dev" of the regression on the training data
1398 std::vector<std::vector<Double_t> > devtest(1); // "dev" of the regression on test data
1399 std::vector<std::vector<Double_t> > rmstrain(1); // "rms" of the regression on the training data
1400 std::vector<std::vector<Double_t> > rmstest(1); // "rms" of the regression on test data
1401 std::vector<std::vector<Double_t> > minftrain(1); // "minf" of the regression on the training data
1402 std::vector<std::vector<Double_t> > minftest(1); // "minf" of the regression on test data
1403 std::vector<std::vector<Double_t> > rhotrain(1); // correlation of the regression on the training data
1404 std::vector<std::vector<Double_t> > rhotest(1); // correlation of the regression on test data
1405
1406 // same as above but for 'truncated' quantities (computed for events within 2sigma of RMS)
1407 std::vector<std::vector<Double_t> > biastrainT(1);
1408 std::vector<std::vector<Double_t> > biastestT(1);
1409 std::vector<std::vector<Double_t> > devtrainT(1);
1410 std::vector<std::vector<Double_t> > devtestT(1);
1411 std::vector<std::vector<Double_t> > rmstrainT(1);
1412 std::vector<std::vector<Double_t> > rmstestT(1);
1413 std::vector<std::vector<Double_t> > minftrainT(1);
1414 std::vector<std::vector<Double_t> > minftestT(1);
1415
1416 // following vector contains all methods - with the exception of Cuts, which are special
1417 MVector methodsNoCuts;
1418
1419 Bool_t doRegression = kFALSE;
1420 Bool_t doMulticlass = kFALSE;
1421
1422 // iterate over methods and evaluate
1423 for (MVector::iterator itrMethod =methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1425 MethodBase* theMethod = dynamic_cast<MethodBase*>(*itrMethod);
1426 if(theMethod==0) continue;
1427 theMethod->SetFile(fgTargetFile);
1428 theMethod->SetSilentFile(IsSilentFile());
1429 if (theMethod->GetMethodType() != Types::kCuts) methodsNoCuts.push_back( *itrMethod );
1430
1431 if (theMethod->DoRegression()) {
1432 doRegression = kTRUE;
1433
1434 Log() << kINFO << "Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1435 Double_t bias, dev, rms, mInf;
1436 Double_t biasT, devT, rmsT, mInfT;
1437 Double_t rho;
1438
1439 Log() << kINFO << "TestRegression (testing)" << Endl;
1440 theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting );
1441 biastest[0] .push_back( bias );
1442 devtest[0] .push_back( dev );
1443 rmstest[0] .push_back( rms );
1444 minftest[0] .push_back( mInf );
1445 rhotest[0] .push_back( rho );
1446 biastestT[0] .push_back( biasT );
1447 devtestT[0] .push_back( devT );
1448 rmstestT[0] .push_back( rmsT );
1449 minftestT[0] .push_back( mInfT );
1450
1451 Log() << kINFO << "TestRegression (training)" << Endl;
1452 theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining );
1453 biastrain[0] .push_back( bias );
1454 devtrain[0] .push_back( dev );
1455 rmstrain[0] .push_back( rms );
1456 minftrain[0] .push_back( mInf );
1457 rhotrain[0] .push_back( rho );
1458 biastrainT[0].push_back( biasT );
1459 devtrainT[0] .push_back( devT );
1460 rmstrainT[0] .push_back( rmsT );
1461 minftrainT[0].push_back( mInfT );
1462
1463 mname[0].push_back( theMethod->GetMethodName() );
1464 nmeth_used[0]++;
1465 if (!IsSilentFile()) {
1466 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1469 }
1470 } else if (theMethod->DoMulticlass()) {
1471 // ====================================================================
1472 // === Multiclass evaluation
1473 // ====================================================================
1474 doMulticlass = kTRUE;
1475 Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1476
1477 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1478 // This is why it is disabled for now.
1479 // Find approximate optimal working point w.r.t. signalEfficiency * signalPurity.
1480 // theMethod->TestMulticlass(); // This is where the actual GA calc is done
1481 // multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));
1482
1483 theMethod->TestMulticlass();
1484
1485 // Confusion matrix at three background efficiency levels
1486 multiclass_trainConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTraining));
1487 multiclass_trainConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTraining));
1488 multiclass_trainConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTraining));
1489
1490 multiclass_testConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTesting));
1491 multiclass_testConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTesting));
1492 multiclass_testConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTesting));
1493
1494 if (!IsSilentFile()) {
1495 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1498 }
1499
1500 nmeth_used[0]++;
1501 mname[0].push_back(theMethod->GetMethodName());
1502 } else {
1503
1504 Log() << kHEADER << "Evaluate classifier: " << theMethod->GetMethodName() << Endl << Endl;
1505 isel = (theMethod->GetMethodTypeName().Contains("Variable")) ? 1 : 0;
1506
1507 // perform the evaluation
1508 theMethod->TestClassification();
1509
1510 // evaluate the classifier
1511 mname[isel].push_back(theMethod->GetMethodName());
1512 sig[isel].push_back(theMethod->GetSignificance());
1513 sep[isel].push_back(theMethod->GetSeparation());
1514 roc[isel].push_back(theMethod->GetROCIntegral());
1515
1516 Double_t err;
1517 eff01[isel].push_back(theMethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
1518 eff01err[isel].push_back(err);
1519 eff10[isel].push_back(theMethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
1520 eff10err[isel].push_back(err);
1521 eff30[isel].push_back(theMethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
1522 eff30err[isel].push_back(err);
1523 effArea[isel].push_back(theMethod->GetEfficiency("", Types::kTesting, err)); // computes the area (average)
1524
1525 trainEff01[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.01")); // the first pass takes longer
1526 trainEff10[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.10"));
1527 trainEff30[isel].push_back(theMethod->GetTrainingEfficiency("Efficiency:0.30"));
1528
1529 nmeth_used[isel]++;
1530
1531 if (!IsSilentFile()) {
1532 Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
1535 }
1536 }
1537 }
1538 if (doRegression) {
1539
1540 std::vector<TString> vtemps = mname[0];
1541 std::vector< std::vector<Double_t> > vtmp;
1542 vtmp.push_back( devtest[0] ); // this is the vector that is ranked
1543 vtmp.push_back( devtrain[0] );
1544 vtmp.push_back( biastest[0] );
1545 vtmp.push_back( biastrain[0] );
1546 vtmp.push_back( rmstest[0] );
1547 vtmp.push_back( rmstrain[0] );
1548 vtmp.push_back( minftest[0] );
1549 vtmp.push_back( minftrain[0] );
1550 vtmp.push_back( rhotest[0] );
1551 vtmp.push_back( rhotrain[0] );
1552 vtmp.push_back( devtestT[0] ); // this is the vector that is ranked
1553 vtmp.push_back( devtrainT[0] );
1554 vtmp.push_back( biastestT[0] );
1555 vtmp.push_back( biastrainT[0]);
1556 vtmp.push_back( rmstestT[0] );
1557 vtmp.push_back( rmstrainT[0] );
1558 vtmp.push_back( minftestT[0] );
1559 vtmp.push_back( minftrainT[0]);
1560 gTools().UsefulSortAscending( vtmp, &vtemps );
1561 mname[0] = vtemps;
1562 devtest[0] = vtmp[0];
1563 devtrain[0] = vtmp[1];
1564 biastest[0] = vtmp[2];
1565 biastrain[0] = vtmp[3];
1566 rmstest[0] = vtmp[4];
1567 rmstrain[0] = vtmp[5];
1568 minftest[0] = vtmp[6];
1569 minftrain[0] = vtmp[7];
1570 rhotest[0] = vtmp[8];
1571 rhotrain[0] = vtmp[9];
1572 devtestT[0] = vtmp[10];
1573 devtrainT[0] = vtmp[11];
1574 biastestT[0] = vtmp[12];
1575 biastrainT[0] = vtmp[13];
1576 rmstestT[0] = vtmp[14];
1577 rmstrainT[0] = vtmp[15];
1578 minftestT[0] = vtmp[16];
1579 minftrainT[0] = vtmp[17];
1580 } else if (doMulticlass) {
1581 // TODO: fill in something meaningful
1582 // If there is some ranking of methods to be done it should be done here.
1583 // However, this is not so easy to define for multiclass so it is left out for now.
1584
1585 }
1586 else {
1587 // now sort the variables according to the best 'eff at Beff=0.10'
1588 for (Int_t k=0; k<2; k++) {
1589 std::vector< std::vector<Double_t> > vtemp;
1590 vtemp.push_back( effArea[k] ); // this is the vector that is ranked
1591 vtemp.push_back( eff10[k] );
1592 vtemp.push_back( eff01[k] );
1593 vtemp.push_back( eff30[k] );
1594 vtemp.push_back( eff10err[k] );
1595 vtemp.push_back( eff01err[k] );
1596 vtemp.push_back( eff30err[k] );
1597 vtemp.push_back( trainEff10[k] );
1598 vtemp.push_back( trainEff01[k] );
1599 vtemp.push_back( trainEff30[k] );
1600 vtemp.push_back( sig[k] );
1601 vtemp.push_back( sep[k] );
1602 vtemp.push_back( roc[k] );
1603 std::vector<TString> vtemps = mname[k];
1604 gTools().UsefulSortDescending( vtemp, &vtemps );
1605 effArea[k] = vtemp[0];
1606 eff10[k] = vtemp[1];
1607 eff01[k] = vtemp[2];
1608 eff30[k] = vtemp[3];
1609 eff10err[k] = vtemp[4];
1610 eff01err[k] = vtemp[5];
1611 eff30err[k] = vtemp[6];
1612 trainEff10[k] = vtemp[7];
1613 trainEff01[k] = vtemp[8];
1614 trainEff30[k] = vtemp[9];
1615 sig[k] = vtemp[10];
1616 sep[k] = vtemp[11];
1617 roc[k] = vtemp[12];
1618 mname[k] = vtemps;
1619 }
1620 }
1621
1622 // -----------------------------------------------------------------------
1623 // Second part of evaluation process
1624 // --> compute correlations among MVAs
1625 // --> compute correlations between input variables and MVA (determines importance)
1626 // --> count overlaps
1627 // -----------------------------------------------------------------------
1628 if(fCorrelations)
1629 {
1630 const Int_t nmeth = methodsNoCuts.size();
1631 MethodBase* method = dynamic_cast<MethodBase*>(methods[0][0]);
1632 const Int_t nvar = method->fDataSetInfo.GetNVariables();
1633 if (!doRegression && !doMulticlass ) {
1634
1635 if (nmeth > 0) {
1636
1637 // needed for correlations
1638 Double_t *dvec = new Double_t[nmeth+nvar];
1639 std::vector<Double_t> rvec;
1640
1641 // for correlations
1642 TPrincipal* tpSig = new TPrincipal( nmeth+nvar, "" );
1643 TPrincipal* tpBkg = new TPrincipal( nmeth+nvar, "" );
1644
1645 // set required tree branch references
1646 Int_t ivar = 0;
1647 std::vector<TString>* theVars = new std::vector<TString>;
1648 std::vector<ResultsClassification*> mvaRes;
1649 for (MVector::iterator itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end(); ++itrMethod, ++ivar) {
1650 MethodBase* m = dynamic_cast<MethodBase*>(*itrMethod);
1651 if(m==0) continue;
1652 theVars->push_back( m->GetTestvarName() );
1653 rvec.push_back( m->GetSignalReferenceCut() );
1654 theVars->back().ReplaceAll( "MVA_", "" );
1655 mvaRes.push_back( dynamic_cast<ResultsClassification*>( m->Data()->GetResults( m->GetMethodName(),
1658 }
1659
1660 // for overlap study
1661 TMatrixD* overlapS = new TMatrixD( nmeth, nmeth );
1662 TMatrixD* overlapB = new TMatrixD( nmeth, nmeth );
1663 (*overlapS) *= 0; // init...
1664 (*overlapB) *= 0; // init...
1665
1666 // loop over test tree
1667 DataSet* defDs = method->fDataSetInfo.GetDataSet();
1669 for (Int_t ievt=0; ievt<defDs->GetNEvents(); ievt++) {
1670 const Event* ev = defDs->GetEvent(ievt);
1671
1672 // for correlations
1673 TMatrixD* theMat = 0;
1674 for (Int_t im=0; im<nmeth; im++) {
1675 // check for NaN value
1676 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1677 if (TMath::IsNaN(retval)) {
1678 Log() << kWARNING << "Found NaN return value in event: " << ievt
1679 << " for method \"" << methodsNoCuts[im]->GetName() << "\"" << Endl;
1680 dvec[im] = 0;
1681 }
1682 else dvec[im] = retval;
1683 }
1684 for (Int_t iv=0; iv<nvar; iv++) dvec[iv+nmeth] = (Double_t)ev->GetValue(iv);
1685 if (method->fDataSetInfo.IsSignal(ev)) { tpSig->AddRow( dvec ); theMat = overlapS; }
1686 else { tpBkg->AddRow( dvec ); theMat = overlapB; }
1687
1688 // count overlaps
1689 for (Int_t im=0; im<nmeth; im++) {
1690 for (Int_t jm=im; jm<nmeth; jm++) {
1691 if ((dvec[im] - rvec[im])*(dvec[jm] - rvec[jm]) > 0) {
1692 (*theMat)(im,jm)++;
1693 if (im != jm) (*theMat)(jm,im)++;
1694 }
1695 }
1696 }
1697 }
1698
1699 // renormalise overlap matrix
1700 (*overlapS) *= (1.0/defDs->GetNEvtSigTest()); // init...
1701 (*overlapB) *= (1.0/defDs->GetNEvtBkgdTest()); // init...
1702
1703 tpSig->MakePrincipals();
1704 tpBkg->MakePrincipals();
1705
1706 const TMatrixD* covMatS = tpSig->GetCovarianceMatrix();
1707 const TMatrixD* covMatB = tpBkg->GetCovarianceMatrix();
1708
1709 const TMatrixD* corrMatS = gTools().GetCorrelationMatrix( covMatS );
1710 const TMatrixD* corrMatB = gTools().GetCorrelationMatrix( covMatB );
1711
1712 // print correlation matrices
1713 if (corrMatS != 0 && corrMatB != 0) {
1714
1715 // extract MVA matrix
1716 TMatrixD mvaMatS(nmeth,nmeth);
1717 TMatrixD mvaMatB(nmeth,nmeth);
1718 for (Int_t im=0; im<nmeth; im++) {
1719 for (Int_t jm=0; jm<nmeth; jm++) {
1720 mvaMatS(im,jm) = (*corrMatS)(im,jm);
1721 mvaMatB(im,jm) = (*corrMatB)(im,jm);
1722 }
1723 }
1724
1725 // extract variables - to MVA matrix
1726 std::vector<TString> theInputVars;
1727 TMatrixD varmvaMatS(nvar,nmeth);
1728 TMatrixD varmvaMatB(nvar,nmeth);
1729 for (Int_t iv=0; iv<nvar; iv++) {
1730 theInputVars.push_back( method->fDataSetInfo.GetVariableInfo( iv ).GetLabel() );
1731 for (Int_t jm=0; jm<nmeth; jm++) {
1732 varmvaMatS(iv,jm) = (*corrMatS)(nmeth+iv,jm);
1733 varmvaMatB(iv,jm) = (*corrMatB)(nmeth+iv,jm);
1734 }
1735 }
1736
1737 if (nmeth > 1) {
1738 Log() << kINFO << Endl;
1739 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA correlation matrix (signal):" << Endl;
1740 gTools().FormattedOutput( mvaMatS, *theVars, Log() );
1741 Log() << kINFO << Endl;
1742
1743 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA correlation matrix (background):" << Endl;
1744 gTools().FormattedOutput( mvaMatB, *theVars, Log() );
1745 Log() << kINFO << Endl;
1746 }
1747
1748 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Correlations between input variables and MVA response (signal):" << Endl;
1749 gTools().FormattedOutput( varmvaMatS, theInputVars, *theVars, Log() );
1750 Log() << kINFO << Endl;
1751
1752 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Correlations between input variables and MVA response (background):" << Endl;
1753 gTools().FormattedOutput( varmvaMatB, theInputVars, *theVars, Log() );
1754 Log() << kINFO << Endl;
1755 }
1756 else Log() << kWARNING <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "<TestAllMethods> cannot compute correlation matrices" << Endl;
1757
1758 // print overlap matrices
1759 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1760 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
1761 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
1762 gTools().FormattedOutput( rvec, *theVars, "Method" , "Cut value", Log() );
1763 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1764
1765 // give notice that cut method has been excluded from this test
1766 if (nmeth != (Int_t)methods->size())
1767 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Note: no correlations and overlap with cut method are provided at present" << Endl;
1768
1769 if (nmeth > 1) {
1770 Log() << kINFO << Endl;
1771 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA overlap matrix (signal):" << Endl;
1772 gTools().FormattedOutput( *overlapS, *theVars, Log() );
1773 Log() << kINFO << Endl;
1774
1775 Log() << kINFO <<Form("Dataset[%s] : ",method->fDataSetInfo.GetName())<< "Inter-MVA overlap matrix (background):" << Endl;
1776 gTools().FormattedOutput( *overlapB, *theVars, Log() );
1777 }
1778
1779 // cleanup
1780 delete tpSig;
1781 delete tpBkg;
1782 delete corrMatS;
1783 delete corrMatB;
1784 delete theVars;
1785 delete overlapS;
1786 delete overlapB;
1787 delete [] dvec;
1788 }
1789 }
1790 }
1791 // -----------------------------------------------------------------------
1792 // Third part of evaluation process
1793 // --> output
1794 // -----------------------------------------------------------------------
1795
1796 if (doRegression) {
1797
1798 Log() << kINFO << Endl;
1799 TString hLine = "--------------------------------------------------------------------------------------------------";
1800 Log() << kINFO << "Evaluation results ranked by smallest RMS on test sample:" << Endl;
1801 Log() << kINFO << "(\"Bias\" quotes the mean deviation of the regression from true target." << Endl;
1802 Log() << kINFO << " \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1803 Log() << kINFO << " Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1804 Log() << kINFO << " tained when removing events deviating more than 2sigma from average.)" << Endl;
1805 Log() << kINFO << hLine << Endl;
1806 //Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1807 Log() << kINFO << hLine << Endl;
1808
1809 for (Int_t i=0; i<nmeth_used[0]; i++) {
1810 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
1811 if(theMethod==0) continue;
1812
1813 Log() << kINFO << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1814 theMethod->fDataSetInfo.GetName(),
1815 (const char*)mname[0][i],
1816 biastest[0][i], biastestT[0][i],
1817 rmstest[0][i], rmstestT[0][i],
1818 minftest[0][i], minftestT[0][i] )
1819 << Endl;
1820 }
1821 Log() << kINFO << hLine << Endl;
1822 Log() << kINFO << Endl;
1823 Log() << kINFO << "Evaluation results ranked by smallest RMS on training sample:" << Endl;
1824 Log() << kINFO << "(overtraining check)" << Endl;
1825 Log() << kINFO << hLine << Endl;
1826 Log() << kINFO << "DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1827 Log() << kINFO << hLine << Endl;
1828
1829 for (Int_t i=0; i<nmeth_used[0]; i++) {
1830 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
1831 if(theMethod==0) continue;
1832 Log() << kINFO << Form("%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1833 theMethod->fDataSetInfo.GetName(),
1834 (const char*)mname[0][i],
1835 biastrain[0][i], biastrainT[0][i],
1836 rmstrain[0][i], rmstrainT[0][i],
1837 minftrain[0][i], minftrainT[0][i] )
1838 << Endl;
1839 }
1840 Log() << kINFO << hLine << Endl;
1841 Log() << kINFO << Endl;
1842 } else if (doMulticlass) {
1843 // ====================================================================
1844 // === Multiclass Output
1845 // ====================================================================
1846
1847 TString hLine =
1848 "-------------------------------------------------------------------------------------------------------";
1849
1850 // This part uses a genetic alg. to evaluate the optimal sig eff * sig pur.
1851 // This is why it is disabled for now.
1852 //
1853 // // --- Acheivable signal efficiency * signal purity
1854 // // --------------------------------------------------------------------
1855 // Log() << kINFO << Endl;
1856 // Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
1857 // Log() << kINFO << hLine << Endl;
1858
1859 // // iterate over methods and evaluate
1860 // for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); itrMethod++) {
1861 // MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
1862 // if (theMethod == 0) {
1863 // continue;
1864 // }
1865
1866 // TString header = "DataSet Name MVA Method ";
1867 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1868 // header += Form("%-12s ", theMethod->fDataSetInfo.GetClassInfo(icls)->GetName());
1869 // }
1870
1871 // Log() << kINFO << header << Endl;
1872 // Log() << kINFO << hLine << Endl;
1873 // for (Int_t i = 0; i < nmeth_used[0]; i++) {
1874 // TString res = Form("[%-14s] %-15s", theMethod->fDataSetInfo.GetName(), (const char *)mname[0][i]);
1875 // for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
1876 // res += Form("%#1.3f ", (multiclass_testEff[i][icls]) * (multiclass_testPur[i][icls]));
1877 // }
1878 // Log() << kINFO << res << Endl;
1879 // }
1880
1881 // Log() << kINFO << hLine << Endl;
1882 // Log() << kINFO << Endl;
1883 // }
1884
1885 // --- 1 vs Rest ROC AUC, signal efficiency @ given background efficiency
1886 // --------------------------------------------------------------------
1887 TString header1 = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "Dataset", "MVA Method", "ROC AUC", "Sig eff@B=0.01",
1888 "Sig eff@B=0.10", "Sig eff@B=0.30");
1889 TString header2 = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "Name:", "/ Class:", "test (train)", "test (train)",
1890 "test (train)", "test (train)");
1891 Log() << kINFO << Endl;
1892 Log() << kINFO << "1-vs-rest performance metrics per class" << Endl;
1893 Log() << kINFO << hLine << Endl;
1894 Log() << kINFO << Endl;
1895 Log() << kINFO << "Considers the listed class as signal and the other classes" << Endl;
1896 Log() << kINFO << "as background, reporting the resulting binary performance." << Endl;
1897 Log() << kINFO << "A score of 0.820 (0.850) means 0.820 was acheived on the" << Endl;
1898 Log() << kINFO << "test set and 0.850 on the training set." << Endl;
1899
1900 Log() << kINFO << Endl;
1901 Log() << kINFO << header1 << Endl;
1902 Log() << kINFO << header2 << Endl;
1903 for (Int_t k = 0; k < 2; k++) {
1904 for (Int_t i = 0; i < nmeth_used[k]; i++) {
1905 if (k == 1) {
1906 mname[k][i].ReplaceAll("Variable_", "");
1907 }
1908
1909 const TString datasetName = itrMap->first;
1910 const TString mvaName = mname[k][i];
1911
1912 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, mvaName));
1913 if (theMethod == 0) {
1914 continue;
1915 }
1916
1917 Log() << kINFO << Endl;
1918 TString row = Form("%-15s%-15s", datasetName.Data(), mvaName.Data());
1919 Log() << kINFO << row << Endl;
1920 Log() << kINFO << "------------------------------" << Endl;
1921
1922 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
1923 for (UInt_t iClass = 0; iClass < numClasses; ++iClass) {
1924
1925 ROCCurve *rocCurveTrain = GetROC(datasetName, mvaName, iClass, Types::kTraining);
1926 ROCCurve *rocCurveTest = GetROC(datasetName, mvaName, iClass, Types::kTesting);
1927
1928 const TString className = theMethod->DataInfo().GetClassInfo(iClass)->GetName();
1929 const Double_t rocaucTrain = rocCurveTrain->GetROCIntegral();
1930 const Double_t effB01Train = rocCurveTrain->GetEffSForEffB(0.01);
1931 const Double_t effB10Train = rocCurveTrain->GetEffSForEffB(0.10);
1932 const Double_t effB30Train = rocCurveTrain->GetEffSForEffB(0.30);
1933 const Double_t rocaucTest = rocCurveTest->GetROCIntegral();
1934 const Double_t effB01Test = rocCurveTest->GetEffSForEffB(0.01);
1935 const Double_t effB10Test = rocCurveTest->GetEffSForEffB(0.10);
1936 const Double_t effB30Test = rocCurveTest->GetEffSForEffB(0.30);
1937 const TString rocaucCmp = Form("%5.3f (%5.3f)", rocaucTest, rocaucTrain);
1938 const TString effB01Cmp = Form("%5.3f (%5.3f)", effB01Test, effB01Train);
1939 const TString effB10Cmp = Form("%5.3f (%5.3f)", effB10Test, effB10Train);
1940 const TString effB30Cmp = Form("%5.3f (%5.3f)", effB30Test, effB30Train);
1941 row = Form("%-15s%-15s%-15s%-15s%-15s%-15s", "", className.Data(), rocaucCmp.Data(), effB01Cmp.Data(),
1942 effB10Cmp.Data(), effB30Cmp.Data());
1943 Log() << kINFO << row << Endl;
1944
1945 delete rocCurveTrain;
1946 delete rocCurveTest;
1947 }
1948 }
1949 }
1950 Log() << kINFO << Endl;
1951 Log() << kINFO << hLine << Endl;
1952 Log() << kINFO << Endl;
1953
1954 // --- Confusion matrices
1955 // --------------------------------------------------------------------
1956 auto printMatrix = [](TMatrixD const &matTraining, TMatrixD const &matTesting, std::vector<TString> classnames,
1957 UInt_t numClasses, MsgLogger &stream) {
1958 // assert (classLabledWidth >= valueLabelWidth + 2)
1959 // if (...) {Log() << kWARN << "..." << Endl; }
1960
1961 // TODO: Ensure matrices are same size.
1962
1963 TString header = Form(" %-14s", " ");
1964 TString headerInfo = Form(" %-14s", " ");
1965 ;
1966 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
1967 header += Form(" %-14s", classnames[iCol].Data());
1968 headerInfo += Form(" %-14s", " test (train)");
1969 }
1970 stream << kINFO << header << Endl;
1971 stream << kINFO << headerInfo << Endl;
1972
1973 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
1974 stream << kINFO << Form(" %-14s", classnames[iRow].Data());
1975
1976 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
1977 if (iCol == iRow) {
1978 stream << kINFO << Form(" %-14s", "-");
1979 } else {
1980 Double_t trainValue = matTraining[iRow][iCol];
1981 Double_t testValue = matTesting[iRow][iCol];
1982 TString entry = Form("%-5.3f (%-5.3f)", testValue, trainValue);
1983 stream << kINFO << Form(" %-14s", entry.Data());
1984 }
1985 }
1986 stream << kINFO << Endl;
1987 }
1988 };
1989
1990 Log() << kINFO << Endl;
1991 Log() << kINFO << "Confusion matrices for all methods" << Endl;
1992 Log() << kINFO << hLine << Endl;
1993 Log() << kINFO << Endl;
1994 Log() << kINFO << "Does a binary comparison between the two classes given by a " << Endl;
1995 Log() << kINFO << "particular row-column combination. In each case, the class " << Endl;
1996 Log() << kINFO << "given by the row is considered signal while the class given " << Endl;
1997 Log() << kINFO << "by the column index is considered background." << Endl;
1998 Log() << kINFO << Endl;
1999 for (UInt_t iMethod = 0; iMethod < methods->size(); ++iMethod) {
2000 MethodBase *theMethod = dynamic_cast<MethodBase *>(methods->at(iMethod));
2001 if (theMethod == nullptr) {
2002 continue;
2003 }
2004 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
2005
2006 std::vector<TString> classnames;
2007 for (UInt_t iCls = 0; iCls < numClasses; ++iCls) {
2008 classnames.push_back(theMethod->fDataSetInfo.GetClassInfo(iCls)->GetName());
2009 }
2010 Log() << kINFO
2011 << "=== Showing confusion matrix for method : " << Form("%-15s", (const char *)mname[0][iMethod])
2012 << Endl;
2013 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.01%)" << Endl;
2014 Log() << kINFO << "---------------------------------------------------" << Endl;
2015 printMatrix(multiclass_testConfusionEffB01[iMethod], multiclass_trainConfusionEffB01[iMethod], classnames,
2016 numClasses, Log());
2017 Log() << kINFO << Endl;
2018
2019 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.10%)" << Endl;
2020 Log() << kINFO << "---------------------------------------------------" << Endl;
2021 printMatrix(multiclass_testConfusionEffB10[iMethod], multiclass_trainConfusionEffB10[iMethod], classnames,
2022 numClasses, Log());
2023 Log() << kINFO << Endl;
2024
2025 Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.30%)" << Endl;
2026 Log() << kINFO << "---------------------------------------------------" << Endl;
2027 printMatrix(multiclass_testConfusionEffB30[iMethod], multiclass_trainConfusionEffB30[iMethod], classnames,
2028 numClasses, Log());
2029 Log() << kINFO << Endl;
2030 }
2031 Log() << kINFO << hLine << Endl;
2032 Log() << kINFO << Endl;
2033
2034 } else {
2035 // Binary classification
2036 if (fROC) {
2037 Log().EnableOutput();
2039 Log() << Endl;
2040 TString hLine = "------------------------------------------------------------------------------------------"
2041 "-------------------------";
2042 Log() << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
2043 Log() << kINFO << hLine << Endl;
2044 Log() << kINFO << "DataSet MVA " << Endl;
2045 Log() << kINFO << "Name: Method: ROC-integ" << Endl;
2046
2047 // Log() << kDEBUG << "DataSet MVA Signal efficiency at bkg eff.(error):
2048 // | Sepa- Signifi- " << Endl; Log() << kDEBUG << "Name: Method: @B=0.01
2049 // @B=0.10 @B=0.30 ROC-integ ROCCurve| ration: cance: " << Endl;
2050 Log() << kDEBUG << hLine << Endl;
2051 for (Int_t k = 0; k < 2; k++) {
2052 if (k == 1 && nmeth_used[k] > 0) {
2053 Log() << kINFO << hLine << Endl;
2054 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2055 }
2056 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2057 TString datasetName = itrMap->first;
2058 TString methodName = mname[k][i];
2059
2060 if (k == 1) {
2061 methodName.ReplaceAll("Variable_", "");
2062 }
2063
2064 MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, methodName));
2065 if (theMethod == 0) {
2066 continue;
2067 }
2068
2069 TMVA::DataSet *dataset = theMethod->Data();
2070 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
2071 std::vector<Bool_t> *mvaResType =
2072 dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
2073
2074 Double_t rocIntegral = 0.0;
2075 if (mvaResType->size() != 0) {
2076 rocIntegral = GetROCIntegral(datasetName, methodName);
2077 }
2078
2079 if (sep[k][i] < 0 || sig[k][i] < 0) {
2080 // cannot compute separation/significance -> no MVA (usually for Cuts)
2081 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), effArea[k][i])
2082 << Endl;
2083
2084 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2085 // %#1.3f %#1.3f | -- --",
2086 // datasetName.Data(),
2087 // methodName.Data(),
2088 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2089 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2090 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2091 // effArea[k][i],rocIntegral) << Endl;
2092 } else {
2093 Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
2094 << Endl;
2095 // Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
2096 // %#1.3f %#1.3f | %#1.3f %#1.3f",
2097 // datasetName.Data(),
2098 // methodName.Data(),
2099 // eff01[k][i], Int_t(1000*eff01err[k][i]),
2100 // eff10[k][i], Int_t(1000*eff10err[k][i]),
2101 // eff30[k][i], Int_t(1000*eff30err[k][i]),
2102 // effArea[k][i],rocIntegral,
2103 // sep[k][i], sig[k][i]) << Endl;
2104 }
2105 }
2106 }
2107 Log() << kINFO << hLine << Endl;
2108 Log() << kINFO << Endl;
2109 Log() << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
2110 Log() << kINFO << hLine << Endl;
2111 Log() << kINFO
2112 << "DataSet MVA Signal efficiency: from test sample (from training sample) "
2113 << Endl;
2114 Log() << kINFO << "Name: Method: @B=0.01 @B=0.10 @B=0.30 "
2115 << Endl;
2116 Log() << kINFO << hLine << Endl;
2117 for (Int_t k = 0; k < 2; k++) {
2118 if (k == 1 && nmeth_used[k] > 0) {
2119 Log() << kINFO << hLine << Endl;
2120 Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
2121 }
2122 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2123 if (k == 1) mname[k][i].ReplaceAll("Variable_", "");
2124 MethodBase *theMethod = dynamic_cast<MethodBase *>((*methods)[i]);
2125 if (theMethod == 0) continue;
2126
2127 Log() << kINFO << Form("%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
2128 theMethod->fDataSetInfo.GetName(), (const char *)mname[k][i], eff01[k][i],
2129 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
2130 << Endl;
2131 }
2132 }
2133 Log() << kINFO << hLine << Endl;
2134 Log() << kINFO << Endl;
2135
2136 if (gTools().CheckForSilentOption(GetOptions())) Log().InhibitOutput();
2137 } // end fROC
2138 }
2139 if(!IsSilentFile())
2140 {
2141 std::list<TString> datasets;
2142 for (Int_t k=0; k<2; k++) {
2143 for (Int_t i=0; i<nmeth_used[k]; i++) {
2144 MethodBase* theMethod = dynamic_cast<MethodBase*>((*methods)[i]);
2145 if(theMethod==0) continue;
2146 // write test/training trees
2147 RootBaseDir()->cd(theMethod->fDataSetInfo.GetName());
2148 if(std::find(datasets.begin(), datasets.end(), theMethod->fDataSetInfo.GetName()) == datasets.end())
2149 {
2152 datasets.push_back(theMethod->fDataSetInfo.GetName());
2153 }
2154 }
2155 }
2156 }
2157 }//end for MethodsMap
2158 // references for citation
2160}
2161
2162////////////////////////////////////////////////////////////////////////////////
2163/// Evaluate Variable Importance
2164
2165TH1F* TMVA::Factory::EvaluateImportance(DataLoader *loader,VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2166{
2167 fModelPersistence=kFALSE;
2168 fSilentFile=kTRUE;//we need silent file here because we need fast classification results
2169
2170 //getting number of variables and variable names from loader
2171 const int nbits = loader->GetDataSetInfo().GetNVariables();
2172 if(vitype==VIType::kShort)
2173 return EvaluateImportanceShort(loader,theMethod,methodTitle,theOption);
2174 else if(vitype==VIType::kAll)
2175 return EvaluateImportanceAll(loader,theMethod,methodTitle,theOption);
2176 else if(vitype==VIType::kRandom&&nbits>10)
2177 {
2178 return EvaluateImportanceRandom(loader,pow(2,nbits),theMethod,methodTitle,theOption);
2179 }else
2180 {
2181 std::cerr<<"Error in Variable Importance: Random mode require more that 10 variables in the dataset."<<std::endl;
2182 return nullptr;
2183 }
2184}
2185
2186////////////////////////////////////////////////////////////////////////////////
2187
2188TH1F* TMVA::Factory::EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2189{
2190
2191 uint64_t x = 0;
2192 uint64_t y = 0;
2193
2194 //getting number of variables and variable names from loader
2195 const int nbits = loader->GetDataSetInfo().GetNVariables();
2196 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2197
2198 uint64_t range = pow(2, nbits);
2199
2200 //vector to save importances
2201 std::vector<Double_t> importances(nbits);
2202 //vector to save ROC
2203 std::vector<Double_t> ROC(range);
2204 ROC[0]=0.5;
2205 for (int i = 0; i < nbits; i++)importances[i] = 0;
2206
2207 Double_t SROC, SSROC; //computed ROC value
2208 for ( x = 1; x <range ; x++) {
2209
2210 std::bitset<VIBITS> xbitset(x);
2211 if (x == 0) continue; //data loader need at least one variable
2212
2213 //creating loader for seed
2214 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2215
2216 //adding variables from seed
2217 for (int index = 0; index < nbits; index++) {
2218 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2219 }
2220
2221 DataLoaderCopy(seedloader,loader);
2222 seedloader->PrepareTrainingAndTestTree(loader->GetDataSetInfo().GetCut("Signal"), loader->GetDataSetInfo().GetCut("Background"), loader->GetDataSetInfo().GetSplitOptions());
2223
2224 //Booking Seed
2225 BookMethod(seedloader, theMethod, methodTitle, theOption);
2226
2227 //Train/Test/Evaluation
2228 TrainAllMethods();
2229 TestAllMethods();
2230 EvaluateAllMethods();
2231
2232 //getting ROC
2233 ROC[x] = GetROCIntegral(xbitset.to_string(), methodTitle);
2234
2235 //cleaning information to process sub-seeds
2236 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2238 delete sresults;
2239 delete seedloader;
2240 this->DeleteAllMethods();
2241
2242 fMethodsMap.clear();
2243 //removing global result because it is requiring a lot of RAM for all seeds
2244 }
2245
2246
2247 for ( x = 0; x <range ; x++)
2248 {
2249 SROC=ROC[x];
2250 for (uint32_t i = 0; i < VIBITS; ++i) {
2251 if (x & (uint64_t(1) << i)) {
2252 y = x & ~(1 << i);
2253 std::bitset<VIBITS> ybitset(y);
2254 //need at least one variable
2255 //NOTE: if sub-seed is zero then is the special case
2256 //that count in xbitset is 1
2257 Double_t ny = log(x - y) / 0.693147;
2258 if (y == 0) {
2259 importances[ny] = SROC - 0.5;
2260 continue;
2261 }
2262
2263 //getting ROC
2264 SSROC = ROC[y];
2265 importances[ny] += SROC - SSROC;
2266 //cleaning information
2267 }
2268
2269 }
2270 }
2271 std::cout<<"--- Variable Importance Results (All)"<<std::endl;
2272 return GetImportance(nbits,importances,varNames);
2273}
2274
// Returns 2^0 + 2^1 + ... + 2^(i-1) == 2^i - 1, i.e. a bitmask with the lowest i bits
// set (0 when i <= 0). Uses exact integer arithmetic instead of accumulating
// floating-point pow() results into an integer.
static long int sum(long int i)
{
   long int mask = 0;
   for (long int n = 0; n < i; ++n) mask = 2 * mask + 1;
   return mask;
}
2281
2282////////////////////////////////////////////////////////////////////////////////
2283
2284TH1F* TMVA::Factory::EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2285{
2286 uint64_t x = 0;
2287 uint64_t y = 0;
2288
2289 //getting number of variables and variable names from loader
2290 const int nbits = loader->GetDataSetInfo().GetNVariables();
2291 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2292
2293 long int range = sum(nbits);
2294// std::cout<<range<<std::endl;
2295 //vector to save importances
2296 std::vector<Double_t> importances(nbits);
2297 for (int i = 0; i < nbits; i++)importances[i] = 0;
2298
2299 Double_t SROC, SSROC; //computed ROC value
2300
2301 x = range;
2302
2303 std::bitset<VIBITS> xbitset(x);
2304 if (x == 0) Log()<<kFATAL<<"Error: need at least one variable."; //data loader need at least one variable
2305
2306
2307 //creating loader for seed
2308 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2309
2310 //adding variables from seed
2311 for (int index = 0; index < nbits; index++) {
2312 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2313 }
2314
2315 //Loading Dataset
2316 DataLoaderCopy(seedloader,loader);
2317
2318 //Booking Seed
2319 BookMethod(seedloader, theMethod, methodTitle, theOption);
2320
2321 //Train/Test/Evaluation
2322 TrainAllMethods();
2323 TestAllMethods();
2324 EvaluateAllMethods();
2325
2326 //getting ROC
2327 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2328
2329 //cleaning information to process sub-seeds
2330 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2332 delete sresults;
2333 delete seedloader;
2334 this->DeleteAllMethods();
2335 fMethodsMap.clear();
2336
2337 //removing global result because it is requiring a lot of RAM for all seeds
2338
2339 for (uint32_t i = 0; i < VIBITS; ++i) {
2340 if (x & (1 << i)) {
2341 y = x & ~(uint64_t(1) << i);
2342 std::bitset<VIBITS> ybitset(y);
2343 //need at least one variable
2344 //NOTE: if sub-seed is zero then is the special case
2345 //that count in xbitset is 1
2346 Double_t ny = log(x - y) / 0.693147;
2347 if (y == 0) {
2348 importances[ny] = SROC - 0.5;
2349 continue;
2350 }
2351
2352 //creating loader for sub-seed
2353 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2354 //adding variables from sub-seed
2355 for (int index = 0; index < nbits; index++) {
2356 if (ybitset[index]) subseedloader->AddVariable(varNames[index], 'F');
2357 }
2358
2359 //Loading Dataset
2360 DataLoaderCopy(subseedloader,loader);
2361
2362 //Booking SubSeed
2363 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2364
2365 //Train/Test/Evaluation
2366 TrainAllMethods();
2367 TestAllMethods();
2368 EvaluateAllMethods();
2369
2370 //getting ROC
2371 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2372 importances[ny] += SROC - SSROC;
2373
2374 //cleaning information
2375 TMVA::MethodBase *ssmethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2377 delete ssresults;
2378 delete subseedloader;
2379 this->DeleteAllMethods();
2380 fMethodsMap.clear();
2381 }
2382 }
2383 std::cout<<"--- Variable Importance Results (Short)"<<std::endl;
2384 return GetImportance(nbits,importances,varNames);
2385}
2386
2387////////////////////////////////////////////////////////////////////////////////
2388
2389TH1F* TMVA::Factory::EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption)
2390{
2391 TRandom3 *rangen = new TRandom3(0); //Random Gen.
2392
2393 uint64_t x = 0;
2394 uint64_t y = 0;
2395
2396 //getting number of variables and variable names from loader
2397 const int nbits = loader->GetDataSetInfo().GetNVariables();
2398 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
2399
2400 long int range = pow(2, nbits);
2401
2402 //vector to save importances
2403 std::vector<Double_t> importances(nbits);
2404 Double_t importances_norm = 0;
2405 for (int i = 0; i < nbits; i++)importances[i] = 0;
2406
2407 Double_t SROC, SSROC; //computed ROC value
2408 for (UInt_t n = 0; n < nseeds; n++) {
2409 x = rangen -> Integer(range);
2410
2411 std::bitset<32> xbitset(x);
2412 if (x == 0) continue; //data loader need at least one variable
2413
2414
2415 //creating loader for seed
2416 TMVA::DataLoader *seedloader = new TMVA::DataLoader(xbitset.to_string());
2417
2418 //adding variables from seed
2419 for (int index = 0; index < nbits; index++) {
2420 if (xbitset[index]) seedloader->AddVariable(varNames[index], 'F');
2421 }
2422
2423 //Loading Dataset
2424 DataLoaderCopy(seedloader,loader);
2425
2426 //Booking Seed
2427 BookMethod(seedloader, theMethod, methodTitle, theOption);
2428
2429 //Train/Test/Evaluation
2430 TrainAllMethods();
2431 TestAllMethods();
2432 EvaluateAllMethods();
2433
2434 //getting ROC
2435 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2436// std::cout << "Seed: n " << n << " x " << x << " xbitset:" << xbitset << " ROC " << SROC << std::endl;
2437
2438 //cleaning information to process sub-seeds
2439 TMVA::MethodBase *smethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2441 delete sresults;
2442 delete seedloader;
2443 this->DeleteAllMethods();
2444 fMethodsMap.clear();
2445
2446 //removing global result because it is requiring a lot of RAM for all seeds
2447
2448 for (uint32_t i = 0; i < 32; ++i) {
2449 if (x & (uint64_t(1) << i)) {
2450 y = x & ~(1 << i);
2451 std::bitset<32> ybitset(y);
2452 //need at least one variable
2453 //NOTE: if sub-seed is zero then is the special case
2454 //that count in xbitset is 1
2455 Double_t ny = log(x - y) / 0.693147;
2456 if (y == 0) {
2457 importances[ny] = SROC - 0.5;
2458 importances_norm += importances[ny];
2459 // std::cout << "SubSeed: " << y << " y:" << ybitset << "ROC " << 0.5 << std::endl;
2460 continue;
2461 }
2462
2463 //creating loader for sub-seed
2464 TMVA::DataLoader *subseedloader = new TMVA::DataLoader(ybitset.to_string());
2465 //adding variables from sub-seed
2466 for (int index = 0; index < nbits; index++) {
2467 if (ybitset[index]) subseedloader->AddVariable(varNames[index], 'F');
2468 }
2469
2470 //Loading Dataset
2471 DataLoaderCopy(subseedloader,loader);
2472
2473 //Booking SubSeed
2474 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2475
2476 //Train/Test/Evaluation
2477 TrainAllMethods();
2478 TestAllMethods();
2479 EvaluateAllMethods();
2480
2481 //getting ROC
2482 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2483 importances[ny] += SROC - SSROC;
2484 //std::cout << "SubSeed: " << y << " y:" << ybitset << " x-y " << x - y << " " << std::bitset<32>(x - y) << " ny " << ny << " SROC " << SROC << " SSROC " << SSROC << " Importance = " << importances[ny] << std::endl;
2485 //cleaning information
2486 TMVA::MethodBase *ssmethod=dynamic_cast<TMVA::MethodBase*>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2488 delete ssresults;
2489 delete subseedloader;
2490 this->DeleteAllMethods();
2491 fMethodsMap.clear();
2492 }
2493 }
2494 }
2495 std::cout<<"--- Variable Importance Results (Random)"<<std::endl;
2496 return GetImportance(nbits,importances,varNames);
2497}
2498
2499////////////////////////////////////////////////////////////////////////////////
2500
2501TH1F* TMVA::Factory::GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames)
2502{
2503 TH1F *vih1 = new TH1F("vih1", "", nbits, 0, nbits);
2504
2505 gStyle->SetOptStat(000000);
2506
2507 Float_t normalization = 0.0;
2508 for (int i = 0; i < nbits; i++) {
2509 normalization = normalization + importances[i];
2510 }
2511
2512 Float_t roc = 0.0;
2513
2514 gStyle->SetTitleXOffset(0.4);
2515 gStyle->SetTitleXOffset(1.2);
2516
2517
2518 std::vector<Double_t> x_ie(nbits), y_ie(nbits);
2519 for (Int_t i = 1; i < nbits + 1; i++) {
2520 x_ie[i - 1] = (i - 1) * 1.;
2521 roc = 100.0 * importances[i - 1] / normalization;
2522 y_ie[i - 1] = roc;
2523 std::cout<<"--- "<<varNames[i-1]<<" = "<<roc<<" %"<<std::endl;
2524 vih1->GetXaxis()->SetBinLabel(i, varNames[i - 1].Data());
2525 vih1->SetBinContent(i, roc);
2526 }
2527 TGraph *g_ie = new TGraph(nbits + 2, &x_ie[0], &y_ie[0]);
2528 g_ie->SetTitle("");
2529
2530 vih1->LabelsOption("v >", "X");
2531 vih1->SetBarWidth(0.97);
2532 Int_t ca = TColor::GetColor("#006600");
2533 vih1->SetFillColor(ca);
2534 //Int_t ci = TColor::GetColor("#990000");
2535
2536 vih1->GetYaxis()->SetTitle("Importance (%)");
2537 vih1->GetYaxis()->SetTitleSize(0.045);
2538 vih1->GetYaxis()->CenterTitle();
2539 vih1->GetYaxis()->SetTitleOffset(1.24);
2540
2541 vih1->GetYaxis()->SetRangeUser(-7, 50);
2542 vih1->SetDirectory(0);
2543
2544// vih1->Draw("B");
2545 return vih1;
2546}
#define h(i)
Definition: RSha256.hxx:106
int Int_t
Definition: RtypesCore.h:43
const Bool_t kFALSE
Definition: RtypesCore.h:90
double Double_t
Definition: RtypesCore.h:57
float Float_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:89
#define ClassImp(name)
Definition: Rtypes.h:361
char name[80]
Definition: TGX11.cxx:109
int type
Definition: TGX11.cxx:120
double pow(double, double)
double log(double)
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:22
#define gROOT
Definition: TROOT.h:406
char * Form(const char *fmt,...)
R__EXTERN TStyle * gStyle
Definition: TStyle.h:410
R__EXTERN TSystem * gSystem
Definition: TSystem.h:556
virtual void SetTitleOffset(Float_t offset=1)
Set distance between the axis and the axis title.
Definition: TAttAxis.cxx:294
virtual void SetTitleSize(Float_t size=0.04)
Set size of axis title.
Definition: TAttAxis.cxx:304
virtual void SetFillColor(Color_t fcolor)
Set the fill area color.
Definition: TAttFill.h:37
virtual void SetBinLabel(Int_t bin, const char *label)
Set label for bin.
Definition: TAxis.cxx:820
void CenterTitle(Bool_t center=kTRUE)
Center axis title.
Definition: TAxis.h:184
virtual void SetRangeUser(Double_t ufirst, Double_t ulast)
Set the viewing range for the axis from ufirst to ulast (in user coordinates).
Definition: TAxis.cxx:939
The Canvas class.
Definition: TCanvas.h:27
static Int_t GetColor(const char *hexcolor)
Static method returning color number for color specified by hex color string of form: "#rrggbb",...
Definition: TColor.cxx:1769
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:53
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
virtual void SetTitle(const char *title="")
Change (i.e. set) the title.
Definition: TGraph.cxx:2324
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:571
virtual void SetDirectory(TDirectory *dir)
By default when an histogram is created, it is added to the list of histogram objects in the current ...
Definition: TH1.cxx:8393
virtual void SetTitle(const char *title)
Change (i.e. set) the title.
Definition: TH1.cxx:6345
virtual void LabelsOption(Option_t *option="h", Option_t *axis="X")
Set option(s) to draw axis with labels.
Definition: TH1.cxx:5222
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1226
TAxis * GetXaxis()
Get the x axis of the histogram.
Definition: TH1.h:316
TAxis * GetYaxis()
Definition: TH1.h:317
virtual void SetBarWidth(Float_t width=0.5)
Definition: TH1.h:356
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition: TH1.cxx:8678
Service class for 2-Dim histogram classes.
Definition: TH2.h:30
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition: Config.h:124
TString fWeightFileDirPrefix
Definition: Config.h:123
void SetDrawProgressBar(Bool_t d)
Definition: Config.h:71
void SetUseColor(Bool_t uc)
Definition: Config.h:62
class TMVA::Config::VariablePlotting fVariablePlotting
void SetSilent(Bool_t s)
Definition: Config.h:65
IONames & GetIONames()
Definition: Config.h:100
void SetConfigDescription(const char *d)
Definition: Configurable.h:64
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void AddPreDefVal(const T &)
Definition: Configurable.h:168
void SetConfigName(const char *n)
Definition: Configurable.h:63
virtual void ParseOptions()
options parser
const TString & GetOptions() const
Definition: Configurable.h:84
MsgLogger & Log() const
Definition: Configurable.h:122
MsgLogger * fLogger
Definition: Configurable.h:128
void CheckForUnusedOptions() const
checks for unused options in option string
UInt_t GetEntries(const TString &name) const
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:633
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:139
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:486
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t GetNVariables() const
Definition: DataSetInfo.h:125
virtual const char * GetName() const
Returns name of object.
Definition: DataSetInfo.h:69
const TMatrixD * CorrelationMatrix(const TString &className) const
UInt_t GetNClasses() const
Definition: DataSetInfo.h:153
const TString & GetSplitOptions() const
Definition: DataSetInfo.h:184
UInt_t GetNTargets() const
Definition: DataSetInfo.h:126
DataSet * GetDataSet() const
returns data set
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
std::vector< TString > GetListOfVariables() const
returns list of variables
ClassInfo * GetClassInfo(Int_t clNum) const
const TCut & GetCut(Int_t i) const
Definition: DataSetInfo.h:166
VariableInfo & GetVariableInfo(Int_t i)
Definition: DataSetInfo.h:103
Bool_t IsSignal(const Event *ev) const
DataSetManager * GetDataSetManager()
Definition: DataSetInfo.h:192
DataInputHandler & DataInput()
Class that contains all the data information.
Definition: DataSet.h:69
Long64_t GetNEvtSigTest()
return number of signal test events in dataset
Definition: DataSet.cxx:426
TTree * GetTree(Types::ETreeType type)
create the test/trainings tree with all the variables, the weights, the classes, the targets,...
Definition: DataSet.cxx:608
const Event * GetEvent() const
Definition: DataSet.cxx:201
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:217
Results * GetResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Definition: DataSet.cxx:264
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:79
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:100
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:227
Long64_t GetNEvtBkgdTest()
return number of background test events in dataset
Definition: DataSet.cxx:434
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:391
This is the main MVA steering class.
Definition: Factory.h:81
void PrintHelpMessage(const TString &datasetname, const TString &methodTitle="") const
Print predefined help message of classifier.
Definition: Factory.cxx:1307
Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass=0)
Calculate the integral of the ROC curve, also known as the area under curve (AUC),...
Definition: Factory.cxx:842
Bool_t fCorrelations
enable calculation of correlations
Definition: Factory.h:212
std::vector< IMethod * > MVector
Definition: Factory.h:85
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1093
Bool_t Verbose(void) const
Definition: Factory.h:135
void WriteDataInformation(DataSetInfo &fDataSetInfo)
Definition: Factory.cxx:596
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:345
Factory(TString theJobName, TFile *theTargetFile, TString theOption="")
Standard constructor.
Definition: Factory.cxx:118
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition: Factory.cxx:1244
Bool_t fVerbose
verbose mode
Definition: Factory.h:210
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition: Factory.cxx:1349
TH1F * EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2389
TH1F * GetImportance(const int nbits, std::vector< Double_t > importances, std::vector< TString > varNames)
Definition: Factory.cxx:2501
Bool_t fROC
enable calculation of ROC values
Definition: Factory.h:213
void EvaluateAllVariables(DataLoader *loader, TString options="")
Iterates over all MVA input variables and evaluates them.
Definition: Factory.cxx:1334
TString fVerboseLevel
verbosity level, controls granularity of logging
Definition: Factory.h:211
TH1F * EvaluateImportance(DataLoader *loader, VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Evaluate Variable Importance.
Definition: Factory.cxx:2165
virtual ~Factory()
Destructor.
Definition: Factory.cxx:299
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition: Factory.cxx:903
virtual void MakeClass(const TString &datasetname, const TString &methodTitle="") const
Definition: Factory.cxx:1279
MethodBase * BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile)
Adds an already constructed method to be managed by this factory.
Definition: Factory.cxx:497
Bool_t fModelPersistence
option to save (persist) the trained model
Definition: Factory.h:219
std::map< TString, Double_t > OptimizeAllMethods(TString fomType="ROCIntegral", TString fitType="FitGA")
Iterates through all booked methods and sees if they use parameter tuning and if so.
Definition: Factory.cxx:694
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Private method to generate a ROCCurve instance for a given method.
Definition: Factory.cxx:743
TH1F * EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2284
Types::EAnalysisType fAnalysisType
the training type
Definition: Factory.h:218
Bool_t HasMethod(const TString &datasetname, const TString &title) const
Checks whether a given method name is defined for a given dataset.
Definition: Factory.cxx:579
TH1F * EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle, const char *theOption="")
Definition: Factory.cxx:2188
void SetVerbose(Bool_t v=kTRUE)
Definition: Factory.cxx:337
TFile * fgTargetFile
Definition: Factory.h:202
IMethod * GetMethod(const TString &datasetname, const TString &title) const
Returns pointer to MVA that corresponds to given method title.
Definition: Factory.cxx:561
void DeleteAllMethods(void)
Delete methods.
Definition: Factory.cxx:317
TString fTransformations
list of transformations to test
Definition: Factory.h:209
void Greetings()
Print welcome message.
Definition: Factory.cxx:289
TMultiGraph * GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass)
Generate a collection of graphs, for all methods for a given class.
Definition: Factory.cxx:972
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
virtual void PrintHelpMessage() const =0
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)=0
virtual void MakeClass(const TString &classFileName=TString("")) const =0
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void SetSilentFile(Bool_t status)
Definition: MethodBase.h:377
void SetWeightFileDir(TString fileDir)
set directory of weight file
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:979
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:598
TString GetMethodTypeName() const
Definition: MethodBase.h:331
Bool_t DoMulticlass() const
Definition: MethodBase.h:439
virtual Double_t GetSignificance() const
compute significance of mean difference
const char * GetName() const
Definition: MethodBase.h:333
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:437
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:408
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetAnalysisType(Types::EAnalysisType type)
Definition: MethodBase.h:436
const TString & GetMethodName() const
Definition: MethodBase.h:330
Bool_t DoRegression() const
Definition: MethodBase.h:438
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:425
virtual Double_t GetTrainingEfficiency(const TString &)
DataSetInfo & DataInfo() const
Definition: MethodBase.h:409
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void ReadStateFromFile()
Function to write options and weights to file.
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:625
DataSetInfo & fDataSetInfo
Definition: MethodBase.h:605
Types::EMVA GetMethodType() const
Definition: MethodBase.h:332
void SetFile(TFile *file)
Definition: MethodBase.h:374
DataSet * Data() const
Definition: MethodBase.h:408
void SetModelPersistence(Bool_t status)
Definition: MethodBase.h:381
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:435
Class for boosting a TMVA method.
Definition: MethodBoost.h:58
void SetBoostedMethodName(TString methodName)
Definition: MethodBoost.h:86
DataSetManager * fDataSetManager
Definition: MethodBoost.h:193
Class for categorizing the phase space.
DataSetManager * fDataSetManager
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
void SetMinType(EMsgType minType)
Definition: MsgLogger.h:72
void SetSource(const std::string &source)
Definition: MsgLogger.h:70
static void InhibitOutput()
Definition: MsgLogger.cxx:74
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
Definition: ROCCurve.cxx:220
Double_t GetROCIntegral(const UInt_t points=41)
Calculates the ROC integral (AUC)
Definition: ROCCurve.cxx:251
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
Definition: ROCCurve.cxx:277
Ranking for variables in method (implementation)
Definition: Ranking.h:48
virtual void Print() const
get maximum length of variable names
Definition: Ranking.cxx:111
Class that is the base-class for a vector of result.
Class which takes the results of a multiclass classification.
Class that is the base-class for a vector of result.
Definition: Results.h:57
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition: Tools.cxx:898
void UsefulSortDescending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:575
void ROOTVersionMessage(MsgLogger &logger)
prints the ROOT release number and date
Definition: Tools.cxx:1336
std::vector< TString > SplitString(const TString &theOpt, const char separator) const
splits the option string at 'separator' and fills the list 'splitV' with the primitive strings
Definition: Tools.cxx:1210
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:839
const TMatrixD * GetCorrelationMatrix(const TMatrixD *covMat)
turns covariance into correlation matrix
Definition: Tools.cxx:335
@ kHtmlLink
Definition: Tools.h:214
void TMVACitation(MsgLogger &logger, ECitation citType=kPlainText)
kinds of TMVA citation
Definition: Tools.cxx:1452
void TMVAVersionMessage(MsgLogger &logger)
prints the TMVA release number and date
Definition: Tools.cxx:1327
void TMVAWelcomeMessage()
direct output, eg, when starting ROOT session -> no use of Logger here
Definition: Tools.cxx:1313
void UsefulSortAscending(std::vector< std::vector< Double_t > > &, std::vector< TString > *vs=0)
sort 2D vector (AND in parallel a TString vector) in such a way that the "first vector is sorted" and...
Definition: Tools.cxx:549
Class that contains all the data information.
void PrintVariableRanking() const
prints ranking of input variables
Singleton class for Global types used by TMVA.
Definition: Types.h:73
static Types & Instance()
returns the single instance of "Types" if it already exists, or creates it (Singleton)
Definition: Types.cxx:70
@ kCategory
Definition: Types.h:99
@ kCuts
Definition: Types.h:80
EAnalysisType
Definition: Types.h:127
@ kMulticlass
Definition: Types.h:130
@ kNoAnalysisType
Definition: Types.h:131
@ kClassification
Definition: Types.h:128
@ kMaxAnalysisType
Definition: Types.h:132
@ kRegression
Definition: Types.h:129
@ kTraining
Definition: Types.h:144
@ kTesting
Definition: Types.h:145
const TString & GetLabel() const
Definition: VariableInfo.h:59
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:36
TList * GetListOfGraphs() const
Definition: TMultiGraph.h:70
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TH1F * GetHistogram()
Returns a pointer to the histogram used to draw the axis.
virtual void Draw(Option_t *chopt="")
Draw this multigraph with its current attributes.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition: TNamed.cxx:164
TString fName
Definition: TNamed.h:32
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition: TNamed.cxx:74
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
@ kOverwrite
overwrite existing object with same name
Definition: TObject.h:88
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
Definition: TPad.cxx:493
virtual void SetGrid(Int_t valuex=1, Int_t valuey=1)
Definition: TPad.h:330
Principal Components Analysis (PCA)
Definition: TPrincipal.h:20
virtual void AddRow(const Double_t *x)
Add a data point and update the covariance matrix.
Definition: TPrincipal.cxx:410
const TMatrixD * GetCovarianceMatrix() const
Definition: TPrincipal.h:58
virtual void MakePrincipals()
Perform the principal components analysis.
Definition: TPrincipal.cxx:869
Random number generator class based on M.
Definition: TRandom3.h:27
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1125
int CompareTo(const char *cs, ECaseCompare cmp=kExact) const
Compare a string to char *cs2.
Definition: TString.cxx:418
const char * Data() const
Definition: TString.h:364
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition: TString.h:610
Bool_t IsNull() const
Definition: TString.h:402
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:619
void SetOptStat(Int_t stat=1)
The type of information printed in the histogram statistics box can be selected via the parameter mod...
Definition: TStyle.cxx:1590
void SetTitleXOffset(Float_t offset=1)
Definition: TStyle.h:390
virtual int MakeDirectory(const char *name)
Make a directory.
Definition: TSystem.cxx:823
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TTree.cxx:9595
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
const Int_t n
Definition: legend1.C:16
RPY_EXPORTED TCppMethod_t GetMethod(TCppScope_t scope, TCppIndex_t imeth)
static constexpr double s
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:335
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Bool_t IsNaN(Double_t x)
Definition: TMath.h:882
Double_t Log(Double_t x)
Definition: TMath.h:750
Definition: graph.py:1
auto * m
Definition: textangle.C:8
#define VIBITS
Definition: Factory.cxx:106
static long int sum(long int i)
Definition: Factory.cxx:2275
const Int_t MinNoTrainingEvents
Definition: Factory.cxx:98
#define READXML
Definition: Factory.cxx:103