Logo ROOT  
Reference Guide
MethodBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
16 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * U. of Victoria, Canada *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 * *
32 **********************************************************************************/
33
34/*! \class TMVA::MethodBase
35\ingroup TMVA
36
37 Virtual base class for all MVA methods
38
39 MethodBase hosts several specific evaluation methods.
40
41 The kind of MVA that provides optimal performance in an analysis strongly
42 depends on the particular application. The evaluation factory provides a
43 number of numerical benchmark results to directly assess the performance
44 of the MVA training on the independent test sample. These are:
45
46 - The _signal efficiency_ at three representative background efficiencies
47 (which is 1 &minus; rejection).
48 - The _significance_ of an MVA estimator, defined by the difference
49 between the MVA mean values for signal and background, divided by the
50 quadratic sum of their root mean squares.
51 - The _separation_ of an MVA _x_, defined by the integral
52 \f[
53 \frac{1}{2} \int \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx
54 \f]
55 where
56 \f$ S(x) \f$ and \f$ B(x) \f$ are the signal and background distributions,
57 respectively. The separation is zero for identical signal and background MVA
58 shapes, and it is one for disjunctive shapes.
59 - The average, \f$ \int x \mu (S(x)) dx \f$, of the signal \f$ \mu_{transform} \f$.
60 The \f$ \mu_{transform} \f$ of an MVA denotes the transformation that yields
61 a uniform background distribution. In this way, the signal distributions
62 \f$ S(x) \f$ can be directly compared among the various MVAs. The stronger
63 \f$ S(x) \f$ peaks towards one, the better is the discrimination of the MVA.
64 The \f$ \mu_{transform} \f$ is
65 [documented here](http://tel.ccsd.cnrs.fr/documents/archives0/00/00/29/91/index_fr.html).
66
67 The MVA standard output also prints the linear correlation coefficients between
68 signal and background, which can be useful to eliminate variables that exhibit too
69 strong correlations.
70*/
71
72#include "TMVA/MethodBase.h"
73
74#include "TMVA/Config.h"
75#include "TMVA/Configurable.h"
76#include "TMVA/DataSetInfo.h"
77#include "TMVA/DataSet.h"
78#include "TMVA/Factory.h"
79#include "TMVA/IMethod.h"
80#include "TMVA/MsgLogger.h"
81#include "TMVA/PDF.h"
82#include "TMVA/Ranking.h"
83#include "TMVA/DataLoader.h"
84#include "TMVA/Tools.h"
85#include "TMVA/Results.h"
89#include "TMVA/RootFinder.h"
90#include "TMVA/Timer.h"
91#include "TMVA/TSpline1.h"
92#include "TMVA/Types.h"
96#include "TMVA/VariableInfo.h"
100#include "TMVA/Version.h"
101
102#include "TROOT.h"
103#include "TSystem.h"
104#include "TObjString.h"
105#include "TQObject.h"
106#include "TSpline.h"
107#include "TMatrix.h"
108#include "TMath.h"
109#include "TH1F.h"
110#include "TH2F.h"
111#include "TFile.h"
112#include "TGraph.h"
113#include "TXMLEngine.h"
114
115#include <iomanip>
116#include <iostream>
117#include <fstream>
118#include <sstream>
119#include <cstdlib>
120#include <algorithm>
121#include <limits>
122
123
125
126using std::endl;
127using std::atof;
128
129//const Int_t MethodBase_MaxIterations_ = 200;
131
132//const Int_t NBIN_HIST_PLOT = 100;
// Bin count copied into fNbinsH by the base initialisation further down;
// presumably the fine binning used for high-resolution MVA histograms — confirm.
133const Int_t NBIN_HIST_HIGH = 10000;
134
135#ifdef _WIN32
136/* Disable warning C4355: 'this' : used in base member initializer list */
137#pragma warning ( disable : 4355 )
138#endif
139
140
141#include "TMultiGraph.h"
142
143////////////////////////////////////////////////////////////////////////////////
144/// standard constructor
145
147{
// Start with no registered graphs and with the next point index at zero.
148 fNumGraphs = 0;
149 fIndex = 0;
150}
151
152////////////////////////////////////////////////////////////////////////////////
153/// standard destructor
155{
// Delete the multigraph (if one exists) and reset the pointer.
// NOTE(review): the TGraph objects stored in fGraphs are not deleted here;
// presumably the TMultiGraph they were added to takes care of them — confirm.
156 if (fMultiGraph){
157 delete fMultiGraph;
158 fMultiGraph = nullptr;
159 }
160 return;
161}
162
163////////////////////////////////////////////////////////////////////////////////
164/// This function takes a list of titles and creates a TGraph for each title.
165/// It also sets up the style for every TGraph. All graphs are added to a single TMultiGraph.
166///
167/// \param[in] graphTitles vector of titles
168
169void TMVA::IPythonInteractive::Init(std::vector<TString>& graphTitles)
170{
171 if (fNumGraphs!=0){
172 std::cerr << kERROR << "IPythonInteractive::Init: already initialized..." << std::endl;
173 return;
174 }
175 Int_t color = 2;
176 for(auto& title : graphTitles){
177 fGraphs.push_back( new TGraph() );
178 fGraphs.back()->SetTitle(title);
179 fGraphs.back()->SetName(title);
180 fGraphs.back()->SetFillColor(color);
181 fGraphs.back()->SetLineColor(color);
182 fGraphs.back()->SetMarkerColor(color);
183 fMultiGraph->Add(fGraphs.back());
184 color += 2;
185 fNumGraphs += 1;
186 }
187 return;
188}
189
190////////////////////////////////////////////////////////////////////////////////
191/// This function sets the point number to 0 for all graphs.
192
194{
// Call Set(0) on every registered graph, i.e. shrink each one to zero points.
// NOTE(review): fIndex is not reset here, so later AddPoint calls keep
// counting from the old index — confirm this is intended.
195 for(Int_t i=0; i<fNumGraphs; i++){
196 fGraphs[i]->Set(0);
197 }
198}
199
200////////////////////////////////////////////////////////////////////////////////
201/// This function is meant for the two-TGraph case only; it appends one new data point to each of the two graphs.
202///
203/// \param[in] x the x coordinate
204/// \param[in] y1 the y coordinate for the first TGraph
205/// \param[in] y2 the y coordinate for the second TGraph
206
208{
// Grow both graphs to fIndex+1 points, then fill the new last point.
// fGraphs[0] and fGraphs[1] are accessed without checking that two graphs exist.
209 fGraphs[0]->Set(fIndex+1);
210 fGraphs[1]->Set(fIndex+1);
211 fGraphs[0]->SetPoint(fIndex, x, y1);
212 fGraphs[1]->SetPoint(fIndex, x, y2);
// Advance the shared point index for the next call.
213 fIndex++;
214 return;
215}
216
217////////////////////////////////////////////////////////////////////////////////
218/// This function can add data points to as many TGraphs as we have.
219///
220/// \param[in] dat vector of data points. The dat[0] contains the x coordinate,
221/// dat[1] contains the y coordinate for first TGraph, dat[2] for second, ...
222
223void TMVA::IPythonInteractive::AddPoint(std::vector<Double_t>& dat)
224{
225 for(Int_t i=0; i<fNumGraphs;i++){
226 fGraphs[i]->Set(fIndex+1);
227 fGraphs[i]->SetPoint(fIndex, dat[0], dat[i+1]);
228 }
229 fIndex++;
230 return;
231}
232
233
234////////////////////////////////////////////////////////////////////////////////
235/// standard constructor
236
238 Types::EMVA methodType,
239 const TString& methodTitle,
240 DataSetInfo& dsi,
241 const TString& theOption) :
// Base classes: the option string is handed to Configurable.
242 IMethod(),
243 Configurable ( theOption ),
// Pointer members are initialised to null throughout this list.
244 fTmpEvent ( 0 ),
245 fRanking ( 0 ),
246 fInputVars ( 0 ),
247 fAnalysisType ( Types::kNoAnalysisType ),
248 fRegressionReturnVal ( 0 ),
249 fMulticlassReturnVal ( 0 ),
250 fDataSetInfo ( dsi ),
// Defaults for the signal reference cut and its orientation.
251 fSignalReferenceCut ( 0.5 ),
252 fSignalReferenceCutOrientation( 1. ),
253 fVariableTransformType ( Types::kSignal ),
254 fJobName ( jobName ),
255 fMethodName ( methodTitle ),
256 fMethodType ( methodType ),
257 fTestvar ( "" ),
// Training versions are taken from the version-code macros of this build.
258 fTMVATrainingVersion ( TMVA_VERSION_CODE ),
259 fROOTTrainingVersion ( ROOT_VERSION_CODE ),
// This constructor marks the object as not constructed from a weight file.
260 fConstructedFromWeightFile ( kFALSE ),
261 fBaseDir ( 0 ),
262 fMethodBaseDir ( 0 ),
263 fFile ( 0 ),
264 fSilentFile (kFALSE),
265 fModelPersistence (kTRUE),
266 fWeightFile ( "" ),
267 fEffS ( 0 ),
268 fDefaultPDF ( 0 ),
269 fMVAPdfS ( 0 ),
270 fMVAPdfB ( 0 ),
271 fSplS ( 0 ),
272 fSplB ( 0 ),
273 fSpleffBvsS ( 0 ),
274 fSplTrainS ( 0 ),
275 fSplTrainB ( 0 ),
276 fSplTrainEffBvsS ( 0 ),
277 fVarTransformString ( "None" ),
278 fTransformationPointer ( 0 ),
// The transformation handler is built from the dataset info and the method title.
279 fTransformation ( dsi, methodTitle ),
280 fVerbose ( kFALSE ),
281 fVerbosityLevelString ( "Default" ),
282 fHelp ( kFALSE ),
283 fHasMVAPdfs ( kFALSE ),
284 fIgnoreNegWeightsInTraining( kFALSE ),
285 fSignalClass ( 0 ),
286 fBackgroundClass ( 0 ),
287 fSplRefS ( 0 ),
288 fSplRefB ( 0 ),
289 fSplTrainRefS ( 0 ),
290 fSplTrainRefB ( 0 ),
// Stays kFALSE until the setup code in this file sets it to kTRUE; the destructor
// logs at kFATAL level when it is still kFALSE.
291 fSetupCompleted (kFALSE)
292{
295
296// // default extension for weight files
297}
298
299////////////////////////////////////////////////////////////////////////////////
300/// constructor used for Testing + Application of the MVA,
301/// only (no training), using given WeightFiles
302
304 DataSetInfo& dsi,
305 const TString& weightFile ) :
// Base classes: no option string is available here, so Configurable gets an empty one.
306 IMethod(),
307 Configurable(""),
// Pointer members are initialised to null throughout this list.
308 fTmpEvent ( 0 ),
309 fRanking ( 0 ),
310 fInputVars ( 0 ),
311 fAnalysisType ( Types::kNoAnalysisType ),
312 fRegressionReturnVal ( 0 ),
313 fMulticlassReturnVal ( 0 ),
314 fDataSetInfo ( dsi ),
315 fSignalReferenceCut ( 0.5 ),
// Fix: the orientation was left uninitialised in this constructor although the
// signal-like comparisons multiply by it. It now gets the same default (1.) as in
// the other constructor, at the same position in the list.
 fSignalReferenceCutOrientation( 1. ),
316 fVariableTransformType ( Types::kSignal ),
317 fJobName ( "" ),
318 fMethodName ( "MethodBase" ),
319 fMethodType ( methodType ),
320 fTestvar ( "" ),
// Training versions are zero here, unlike the version-code macros used in the other constructor.
321 fTMVATrainingVersion ( 0 ),
322 fROOTTrainingVersion ( 0 ),
// This constructor marks the object as constructed from a weight file.
323 fConstructedFromWeightFile ( kTRUE ),
324 fBaseDir ( 0 ),
325 fMethodBaseDir ( 0 ),
326 fFile ( 0 ),
327 fSilentFile (kFALSE),
328 fModelPersistence (kTRUE),
329 fWeightFile ( weightFile ),
330 fEffS ( 0 ),
331 fDefaultPDF ( 0 ),
332 fMVAPdfS ( 0 ),
333 fMVAPdfB ( 0 ),
334 fSplS ( 0 ),
335 fSplB ( 0 ),
336 fSpleffBvsS ( 0 ),
337 fSplTrainS ( 0 ),
338 fSplTrainB ( 0 ),
339 fSplTrainEffBvsS ( 0 ),
340 fVarTransformString ( "None" ),
341 fTransformationPointer ( 0 ),
342 fTransformation ( dsi, "" ),
343 fVerbose ( kFALSE ),
344 fVerbosityLevelString ( "Default" ),
345 fHelp ( kFALSE ),
346 fHasMVAPdfs ( kFALSE ),
347 fIgnoreNegWeightsInTraining( kFALSE ),
348 fSignalClass ( 0 ),
349 fBackgroundClass ( 0 ),
350 fSplRefS ( 0 ),
351 fSplRefB ( 0 ),
352 fSplTrainRefS ( 0 ),
353 fSplTrainRefB ( 0 ),
// Stays kFALSE until the setup code in this file sets it to kTRUE.
354 fSetupCompleted (kFALSE)
355{
357// // constructor used for Testing + Application of the MVA,
358// // only (no training), using given WeightFiles
359}
360
361////////////////////////////////////////////////////////////////////////////////
362/// destructor
363
365{
366 // destructor
// Destroying an object whose setup flag was never set is reported at kFATAL level.
367 if (!fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling destructor of method which got never setup" << Endl;
368
369 // destructor
// Release the list of input variable names and the ranking object.
370 if (fInputVars != 0) { fInputVars->clear(); delete fInputVars; }
371 if (fRanking != 0) delete fRanking;
372
373 // PDFs
374 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
375 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
376 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
377
378 // Splines
// NOTE(review): fSplTrainS and fSplTrainB are not deleted in this list — confirm
// whether they are released elsewhere.
379 if (fSplS) { delete fSplS; fSplS = 0; }
380 if (fSplB) { delete fSplB; fSplB = 0; }
381 if (fSpleffBvsS) { delete fSpleffBvsS; fSpleffBvsS = 0; }
382 if (fSplRefS) { delete fSplRefS; fSplRefS = 0; }
383 if (fSplRefB) { delete fSplRefB; fSplRefB = 0; }
384 if (fSplTrainRefS) { delete fSplTrainRefS; fSplTrainRefS = 0; }
385 if (fSplTrainRefB) { delete fSplTrainRefB; fSplTrainRefB = 0; }
386 if (fSplTrainEffBvsS) { delete fSplTrainEffBvsS; fSplTrainEffBvsS = 0; }
387
// Both event collection slots: delete every event, then the collection itself.
// at() is bounds-checked; the two slots are created by the resize(2) in the base
// initialisation code of this file.
388 for (Int_t i = 0; i < 2; i++ ) {
389 if (fEventCollections.at(i)) {
390 for (std::vector<Event*>::const_iterator it = fEventCollections.at(i)->begin();
391 it != fEventCollections.at(i)->end(); ++it) {
392 delete (*it);
393 }
394 delete fEventCollections.at(i);
395 fEventCollections.at(i) = 0;
396 }
397 }
398
// Release the cached regression and multiclass return-value buffers.
399 if (fRegressionReturnVal) delete fRegressionReturnVal;
400 if (fMulticlassReturnVal) delete fMulticlassReturnVal;
401}
402
403////////////////////////////////////////////////////////////////////////////////
404/// setup of methods
405
407{
408 // setup of methods
409
// A second call is reported at kFATAL level.
410 if (fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling SetupMethod for the second time" << Endl;
// Fixed order: base initialisation and base option declaration first, then
// Init() and DeclareOptions() (presumably the derived-class hooks — confirm).
411 InitBase();
412 DeclareBaseOptions();
413 Init();
414 DeclareOptions();
// Record that setup ran; the destructor checks this flag.
415 fSetupCompleted = kTRUE;
416}
417
418////////////////////////////////////////////////////////////////////////////////
419/// process all options
420/// the "CheckForUnusedOptions" is done in an independent call, since it may be overridden by derived class
421/// (sometimes, eg, fitters are used which can only be implemented during training phase)
422
424{
// The base-class options are processed before the ProcessOptions() call.
425 ProcessBaseOptions();
426 ProcessOptions();
427}
428
429////////////////////////////////////////////////////////////////////////////////
430/// check may be overridden by derived class
431/// (sometimes, eg, fitters are used which can only be implemented during training phase)
432
434{
// The default check only looks for options that were given but never used.
435 CheckForUnusedOptions();
436}
437
438////////////////////////////////////////////////////////////////////////////////
439/// default initialization called by all constructors
440
442{
// NOTE(review): this description is set again with a different text at the end of
// this block, so presumably only the later one is kept — confirm.
443 SetConfigDescription( "Configuration options for classifier architecture and tuning" );
444
// Histogram binning: the MVA output bin count comes from the global configuration,
// the high-resolution bin count from the file-level constant.
446 fNbinsMVAoutput = gConfig().fVariablePlotting.fNbinsMVAoutput;
447 fNbinsH = NBIN_HIST_HIGH;
448
// Reset spline pointers; means and RMS values start at -1.
449 fSplTrainS = 0;
450 fSplTrainB = 0;
451 fSplTrainEffBvsS = 0;
452 fMeanS = -1;
453 fMeanB = -1;
454 fRmsS = -1;
455 fRmsB = -1;
// Inverted extreme values so that the first real value replaces them in a min/max search.
456 fXmin = DBL_MAX;
457 fXmax = -DBL_MAX;
458 fTxtWeightsOnly = kTRUE;
459 fSplRefS = 0;
460 fSplRefB = 0;
461
// Train and test times start at -1.
462 fTrainTime = -1.;
463 fTestTime = -1.;
464
465 fRanking = 0;
466
467 // temporary until the move to DataSet is complete
// Copy the label of every input variable of the dataset into a newly allocated list.
468 fInputVars = new std::vector<TString>;
469 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
470 fInputVars->push_back(DataInfo().GetVariableInfo(ivar).GetLabel());
471 }
472 fRegressionReturnVal = 0;
473 fMulticlassReturnVal = 0;
474
// Two event collection slots, both empty; the destructor iterates over exactly these two.
475 fEventCollections.resize( 2 );
476 fEventCollections.at(0) = 0;
477 fEventCollections.at(1) = 0;
478
479 // retrieve signal and background class index
// The indices keep their previous values when the dataset has no class of that name.
480 if (DataInfo().GetClassInfo("Signal") != 0) {
481 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
482 }
483 if (DataInfo().GetClassInfo("Background") != 0) {
484 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
485 }
486
// Final configuration description and a name of the form "Method<type name>".
487 SetConfigDescription( "Configuration options for MVA method" );
488 SetConfigName( TString("Method") + GetMethodTypeName() );
489}
490
491////////////////////////////////////////////////////////////////////////////////
492/// define the options (their key words) that can be set in the option string
493/// here the options valid for ALL MVA methods are declared.
494///
495/// known options:
496///
497/// - VariableTransform=None,Decorrelated,PCA to use transformed variables
498/// instead of the original ones
499/// - VariableTransformType=Signal,Background which decorrelation matrix to use
500/// in the method. Only the Likelihood
501/// Method can make proper use of independent
502/// transformations of signal and background
503/// - fNbinsMVAPdf = 50 Number of bins used to create a PDF of MVA
504/// - fNsmoothMVAPdf = 2 Number of times a histogram is smoothed before creating the PDF
505/// - fHasMVAPdfs create PDFs for the MVA outputs
506/// - V for Verbose output, (!V) for non-verbose output
507/// - H for Help message
508
510{
// Boolean switch "V"; the option processing code in this file lets it override "VerbosityLevel".
511 DeclareOptionRef( fVerbose, "V", "Verbose output (short form of \"VerbosityLevel\" below - overrides the latter one)" );
512
// "VerbosityLevel" defaults to "Default"; the AddPreDefVal calls list the accepted values.
513 DeclareOptionRef( fVerbosityLevelString="Default", "VerbosityLevel", "Verbosity level" );
514 AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
515 AddPreDefVal( TString("Debug") );
516 AddPreDefVal( TString("Verbose") );
517 AddPreDefVal( TString("Info") );
518 AddPreDefVal( TString("Warning") );
519 AddPreDefVal( TString("Error") );
520 AddPreDefVal( TString("Fatal") );
521
522 // If True (default): write all training results (weights) as text files only;
523 // if False: write also in ROOT format (not available for all methods - will abort if not)
// These two members are only assigned here, not declared as options in this block.
524 fTxtWeightsOnly = kTRUE; // OBSOLETE !!!
525 fNormalise = kFALSE; // OBSOLETE !!!
526
// Comma-separated transformation list; the string is parsed later by the option processing code.
527 DeclareOptionRef( fVarTransformString, "VarTransform", "List of variable transformations performed before training, e.g., \"D_Background,P_Signal,G,N_AllClasses\" for: \"Decorrelation, PCA-transformation, Gaussianisation, Normalisation, each for the given class of events ('AllClasses' denotes all events of all classes, if no class indication is given, 'All' is assumed)\"" );
528
529 DeclareOptionRef( fHelp, "H", "Print method-specific help message" );
530
531 DeclareOptionRef( fHasMVAPdfs, "CreateMVAPdfs", "Create PDFs for classifier outputs (signal and background)" );
532
533 DeclareOptionRef( fIgnoreNegWeightsInTraining, "IgnoreNegWeightsInTraining",
534 "Events with negative weights are ignored in the training (but are included for testing and performance evaluation)" );
535}
536
537////////////////////////////////////////////////////////////////////////////////
538/// the option string is decoded, for available options see "DeclareOptions"
539
541{
// When MVA PDFs are requested, three PDF objects are created in a chain: default,
// background, signal. Each one receives the option string returned by the previous
// one, so the options are marked step by step.
542 if (HasMVAPdfs()) {
543 // setting the default bin num... maybe should be static ? ==> Please no static (JS)
544 // You can't use the logger in the constructor!!! Log() << kINFO << "Create PDFs" << Endl;
545 // reading every PDF's definition and passing the option string to the next one to be read and marked
546 fDefaultPDF = new PDF( TString(GetName())+"_PDF", GetOptions(), "MVAPdf" );
547 fDefaultPDF->DeclareOptions();
548 fDefaultPDF->ParseOptions();
549 fDefaultPDF->ProcessOptions();
// The background and signal PDFs additionally receive the default PDF as last argument.
550 fMVAPdfB = new PDF( TString(GetName())+"_PDFBkg", fDefaultPDF->GetOptions(), "MVAPdfBkg", fDefaultPDF );
551 fMVAPdfB->DeclareOptions();
552 fMVAPdfB->ParseOptions();
553 fMVAPdfB->ProcessOptions();
554 fMVAPdfS = new PDF( TString(GetName())+"_PDFSig", fMVAPdfB->GetOptions(), "MVAPdfSig", fDefaultPDF );
555 fMVAPdfS->DeclareOptions();
556 fMVAPdfS->ParseOptions();
557 fMVAPdfS->ProcessOptions();
558
559 // the final marked option string is written back to the original methodbase
560 SetOptions( fMVAPdfS->GetOptions() );
561 }
562
// Build the variable transformations described by the "VarTransform" option string.
563 TMVA::CreateVariableTransforms( fVarTransformString,
564 DataInfo(),
565 GetTransformationHandler(),
566 Log() );
567
// Without MVA PDFs, make sure no PDF objects are kept.
568 if (!HasMVAPdfs()) {
569 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
570 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
571 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
572 }
573
// Map the verbosity settings onto the logger's minimum message type. The boolean
// "V" flag wins; any string other than the listed ones and "Default" is reported
// at kFATAL level.
574 if (fVerbose) { // overwrites other settings
575 fVerbosityLevelString = TString("Verbose");
576 Log().SetMinType( kVERBOSE );
577 }
578 else if (fVerbosityLevelString == "Debug" ) Log().SetMinType( kDEBUG );
579 else if (fVerbosityLevelString == "Verbose" ) Log().SetMinType( kVERBOSE );
580 else if (fVerbosityLevelString == "Info" ) Log().SetMinType( kINFO );
581 else if (fVerbosityLevelString == "Warning" ) Log().SetMinType( kWARNING );
582 else if (fVerbosityLevelString == "Error" ) Log().SetMinType( kERROR );
583 else if (fVerbosityLevelString == "Fatal" ) Log().SetMinType( kFATAL );
584 else if (fVerbosityLevelString != "Default" ) {
585 Log() << kFATAL << "<ProcessOptions> Verbosity level type '"
586 << fVerbosityLevelString << "' unknown." << Endl;
587 }
// Forward the negative-weight setting through a static Event call, so it is not
// specific to this object.
588 Event::SetIgnoreNegWeightsInTraining(fIgnoreNegWeightsInTraining);
589}
590
591////////////////////////////////////////////////////////////////////////////////
592/// options that are used ONLY for the READER to ensure backward compatibility
593/// they are hence without any effect (the reader is only reading the training
594/// options that HAD been used at the training of the .xml weight file at hand)
595
597{
// Each call assigns the default value to the member and declares the option in the
// same expression.
598 DeclareOptionRef( fNormalise=kFALSE, "Normalise", "Normalise input variables" ); // don't change the default !!!
599 DeclareOptionRef( fUseDecorr=kFALSE, "D", "Use-decorrelated-variables flag" );
600 DeclareOptionRef( fVariableTransformTypeString="Signal", "VarTransformType",
601 "Use signal or background events to derive for variable transformation (the transformation is applied on both types of, course)" );
// Accepted values for "VarTransformType".
602 AddPreDefVal( TString("Signal") );
603 AddPreDefVal( TString("Background") );
604 DeclareOptionRef( fTxtWeightsOnly=kTRUE, "TxtWeightFilesOnly", "If True: write all training results (weights) as text files (False: some are written in ROOT format)" );
605 // Why on earth ?? was this here? Was the verbosity level option meant to 'disappear? Not a good idea i think..
606 // DeclareOptionRef( fVerbosityLevelString="Default", "VerboseLevel", "Verbosity level" );
607 // AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
608 // AddPreDefVal( TString("Debug") );
609 // AddPreDefVal( TString("Verbose") );
610 // AddPreDefVal( TString("Info") );
611 // AddPreDefVal( TString("Warning") );
612 // AddPreDefVal( TString("Error") );
613 // AddPreDefVal( TString("Fatal") );
// Binning and smoothing defaults for the classifier output PDFs (60 bins, 2 smoothing iterations).
614 DeclareOptionRef( fNbinsMVAPdf = 60, "NbinsMVAPdf", "Number of bins used for the PDFs of classifier outputs" );
615 DeclareOptionRef( fNsmoothMVAPdf = 2, "NsmoothMVAPdf", "Number of smoothing iterations for classifier PDFs" );
616}
617
618
619////////////////////////////////////////////////////////////////////////////////
620/// call the Optimizer with the set of parameters and ranges that
621/// are meant to be tuned.
622
623std::map<TString,Double_t> TMVA::MethodBase::OptimizeTuningParameters(TString /* fomType */ , TString /* fitType */)
624{
625 // this is just a dummy... needs to be implemented for each method
626 // individually (as long as we don't have it automatized via the
627 // configuration string
628
629 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Parameter optimization is not yet implemented for method "
630 << GetName() << Endl;
631 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Currently we need to set hardcoded which parameter is tuned in which ranges"<<Endl;
632
633 std::map<TString,Double_t> tunedParameters;
634 tunedParameters.size(); // just to get rid of "unused" warning
635 return tunedParameters;
636
637}
638
639////////////////////////////////////////////////////////////////////////////////
640/// set the tuning parameters according to the argument
641/// This is just a dummy .. have a look at the MethodBDT how you could
642/// perhaps implement the same thing for the other Classifiers..
643
644void TMVA::MethodBase::SetTuneParameters(std::map<TString,Double_t> /* tuneParameters */)
645{
646}
647
648////////////////////////////////////////////////////////////////////////////////
649
651{
// Switch the dataset to the training sample and flag events as being in training.
652 Data()->SetCurrentType(Types::kTraining);
653 Event::SetIsTraining(kTRUE); // used to set negative event weights to zero if chosen to do so
654
655 // train the MVA method
// Print the help message first when the help flag is set.
656 if (Help()) PrintHelpMessage();
657
658 // all histograms should be created in the method's subdirectory
659 if(!IsSilentFile()) BaseDir()->cd();
660
661 // once calculate all the transformation (e.g. the sequence of Decorr:Gauss:Decorr)
662 // needed for this classifier
663 GetTransformationHandler().CalcTransformations(Data()->GetEventCollection());
664
665 // call training of derived MVA
666 Log() << kDEBUG //<<Form("\tDataset[%s] : ",DataInfo().GetName())
667 << "Begin training" << Endl;
// The timer is created just before Train() so the elapsed time covers the training call.
668 Long64_t nEvents = Data()->GetNEvents();
669 Timer traintimer( nEvents, GetName(), kTRUE );
670 Train();
671 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName()
672 << "\tEnd of training " << Endl;
673 SetTrainTime(traintimer.ElapsedSeconds());
674 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
675 << "Elapsed time for training with " << nEvents << " events: "
676 << traintimer.GetElapsedTime() << " " << Endl;
677
// This debug message is left open on purpose; each branch below appends the rest
// of the line and ends it with Endl.
678 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
679 << "\tCreate MVA output for ";
680
681 // create PDFs for the signal and background MVA distributions (if required)
// Three cases: multiclass, two-class classification (not multiclass and not
// regression), regression.
682 if (DoMulticlass()) {
683 Log() <<Form("[%s] : ",DataInfo().GetName())<< "Multiclass classification on training sample" << Endl;
684 AddMulticlassOutput(Types::kTraining);
685 }
686 else if (!DoRegression()) {
687
688 Log() <<Form("[%s] : ",DataInfo().GetName())<< "classification on training sample" << Endl;
689 AddClassifierOutput(Types::kTraining);
// The probability output is added only after the PDFs have been created.
690 if (HasMVAPdfs()) {
691 CreateMVAPdfs();
692 AddClassifierOutputProb(Types::kTraining);
693 }
694
695 } else {
696
697 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "regression on training sample" << Endl;
698 AddRegressionOutput( Types::kTraining );
699
700 if (HasMVAPdfs() ) {
701 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create PDFs" << Endl;
702 CreateMVAPdfs();
703 }
704 }
705
706 // write the current MVA state into stream
707 // produced are one text file and one ROOT file
708 if (fModelPersistence ) WriteStateToFile();
709
710 // produce standalone make class (presently only supported for classification)
711 if ((!DoRegression()) && (fModelPersistence)) MakeClass();
712
713 // write additional monitoring histograms to main target file (not the weight file)
714 // again, make sure the histograms go into the method's subdirectory
715 if(!IsSilentFile())
716 {
717 BaseDir()->cd();
718 WriteMonitoringHistosToFile();
719 }
720}
721
722////////////////////////////////////////////////////////////////////////////////
723
725{
726 if (!DoRegression()) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Trying to use GetRegressionDeviation() with a classification job" << Endl;
727 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
729 bool truncate = false;
730 TH1F* h1 = regRes->QuadraticDeviation( tgtNum , truncate, 1.);
731 stddev = sqrt(h1->GetMean());
732 truncate = true;
733 Double_t yq[1], xq[]={0.9};
734 h1->GetQuantiles(1,yq,xq);
735 TH1F* h2 = regRes->QuadraticDeviation( tgtNum , truncate, yq[0]);
736 stddev90Percent = sqrt(h2->GetMean());
737 delete h1;
738 delete h2;
739}
740
741////////////////////////////////////////////////////////////////////////////////
742/// prepare tree branch with the method's discriminating variable
743
745{
// Select the requested sample (training or testing) on the dataset.
746 Data()->SetCurrentType(type);
747
748 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
749
// C-style cast of the results object fetched for this method; no null check follows.
750 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), type, Types::kRegression);
751
752 Long64_t nEvents = Data()->GetNEvents();
753
754 // use timer
755 Timer timer( nEvents, GetName(), kTRUE );
756 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
757 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
758
// Make room for one result entry per event.
759 regRes->Resize( nEvents );
760
761 // Drawing the progress bar every event was causing a huge slowdown in the evaluation time
762 // So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
763
764 Int_t totalProgressDraws = 100; // total number of times to update the progress bar
765 Int_t drawProgressEvery = 1; // draw every nth event such that we have a total of totalProgressDraws
// With fewer events than totalProgressDraws the step stays 1, which also keeps the
// modulo below away from a zero divisor.
766 if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
767
// NOTE(review): the loop index is Int_t while nEvents is Long64_t.
768 for (Int_t ievt=0; ievt<nEvents; ievt++) {
769
// Evaluate the regression values for the current event and store them at the same index.
770 Data()->SetCurrentEvent(ievt);
771 std::vector< Float_t > vals = GetRegressionValues();
772 regRes->SetValue( vals, ievt );
773
774 // Only draw the progress bar once in a while, doing this every event causes the evaluation to be ridiculously slow
775 if(ievt % drawProgressEvery == 0 || ievt==nEvents-1) timer.DrawProgressBar( ievt );
776 }
777
778 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
779 << "Elapsed time for evaluation of " << nEvents << " events: "
780 << timer.GetElapsedTime() << " " << Endl;
781
782 // store time used for testing
// NOTE(review): the test time is stored without any check on 'type' in the lines
// visible here.
784 SetTestTime(timer.ElapsedSeconds());
785
// Histogram names start with the test variable name followed by "train" or "test".
786 TString histNamePrefix(GetTestvarName());
787 histNamePrefix += (type==Types::kTraining?"train":"test");
788 regRes->CreateDeviationHistograms( histNamePrefix );
789}
790
791////////////////////////////////////////////////////////////////////////////////
792/// prepare tree branch with the method's discriminating variable
793
795{
// Select the requested sample (training or testing) on the dataset.
796 Data()->SetCurrentType(type);
797
798 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
799
// The dynamic_cast result is checked; a null pointer is reported at kFATAL level.
800 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
801 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in AddMulticlassOutput, exiting."<<Endl;
802
803 Long64_t nEvents = Data()->GetNEvents();
804
805 // use timer
806 Timer timer( nEvents, GetName(), kTRUE );
807
808 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Multiclass evaluation of " << GetMethodName() << " on "
809 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
810
// One result entry per event.
// NOTE(review): the progress bar is drawn for every event here, unlike the throttled
// drawing in the regression output code of this file; the loop index is Int_t while
// nEvents is Long64_t.
811 resMulticlass->Resize( nEvents );
812 for (Int_t ievt=0; ievt<nEvents; ievt++) {
813 Data()->SetCurrentEvent(ievt);
814 std::vector< Float_t > vals = GetMulticlassValues();
815 resMulticlass->SetValue( vals, ievt );
816 timer.DrawProgressBar( ievt );
817 }
818
819 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
820 << "Elapsed time for evaluation of " << nEvents << " events: "
821 << timer.GetElapsedTime() << " " << Endl;
822
823 // store time used for testing
825 SetTestTime(timer.ElapsedSeconds());
826
// Histogram names start with the test variable name followed by "_Train" or "_Test".
827 TString histNamePrefix(GetTestvarName());
828 histNamePrefix += (type==Types::kTraining?"_Train":"_Test");
829
830 resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbinsMVAoutput, fNbinsH );
831 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefix);
832}
833
834////////////////////////////////////////////////////////////////////////////////
835
836void TMVA::MethodBase::NoErrorCalc(Double_t* const err, Double_t* const errUpper) {
837 if (err) *err=-1;
838 if (errUpper) *errUpper=-1;
839}
840
841////////////////////////////////////////////////////////////////////////////////
842
843Double_t TMVA::MethodBase::GetMvaValue( const Event* const ev, Double_t* err, Double_t* errUpper ) {
844 fTmpEvent = ev;
845 Double_t val = GetMvaValue(err, errUpper);
846 fTmpEvent = 0;
847 return val;
848}
849
850////////////////////////////////////////////////////////////////////////////////
851/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
852/// for a quick determination if an event would be selected as signal or background
853
// Both sides are multiplied by the cut orientation: a positive orientation gives
// "value > cut", a negative one flips the comparison to "value < cut".
855 return GetMvaValue()*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
856}
857////////////////////////////////////////////////////////////////////////////////
858/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
859/// for a quick determination if an event with this mva output value would be selected as signal or background
860
// Same comparison as above it in this file, applied to the supplied mvaVal:
// multiplying both sides by the cut orientation flips the comparison when the
// orientation is negative.
862 return mvaVal*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
863}
864
865////////////////////////////////////////////////////////////////////////////////
866/// prepare tree branch with the method's discriminating variable
867
869{
// Select the requested sample (training or testing) on the dataset.
870 Data()->SetCurrentType(type);
871
// NOTE(review): the right-hand side of this assignment is not visible in this listing.
872 ResultsClassification* clRes =
874
// One result entry per event.
875 Long64_t nEvents = Data()->GetNEvents();
876 clRes->Resize( nEvents );
877
878 // use timer
// All MVA values are computed in one call; the third argument enables progress logging.
879 Timer timer( nEvents, GetName(), kTRUE );
880 std::vector<Double_t> mvaValues = GetMvaValues(0, nEvents, true);
881
882 // store time used for testing
884 SetTestTime(timer.ElapsedSeconds());
885
886 // load mva values and type to results object
887 for (Int_t ievt = 0; ievt < nEvents; ievt++) {
888 // note we do not need the transformed event to get the signal/background information
889 // by calling Data()->GetEvent instead of this->GetEvent we access the untransformed one
890 auto ev = Data()->GetEvent(ievt);
// Store the value together with the signal/background flag of the event.
891 clRes->SetValue(mvaValues[ievt], ievt, DataInfo().IsSignal(ev));
892 }
893}
894
895////////////////////////////////////////////////////////////////////////////////
896/// get all the MVA values for the events of the current Data type
897std::vector<Double_t> TMVA::MethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
898{
899
900 Long64_t nEvents = Data()->GetNEvents();
901 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
902 if (firstEvt < 0) firstEvt = 0;
903 std::vector<Double_t> values(lastEvt-firstEvt);
904 // log in case of looping on all the events
905 nEvents = values.size();
906
907 // use timer
908 Timer timer( nEvents, GetName(), kTRUE );
909
910 if (logProgress)
911 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
912 << "Evaluation of " << GetMethodName() << " on "
913 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
914 << " sample (" << nEvents << " events)" << Endl;
915
916 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
917 Data()->SetCurrentEvent(ievt);
918 values[ievt] = GetMvaValue();
919
920 // print progress
921 if (logProgress) {
922 Int_t modulo = Int_t(nEvents/100);
923 if (modulo <= 0 ) modulo = 1;
924 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
925 }
926 }
927 if (logProgress) {
928 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
929 << "Elapsed time for evaluation of " << nEvents << " events: "
930 << timer.GetElapsedTime() << " " << Endl;
931 }
932
933 return values;
934}
935
936////////////////////////////////////////////////////////////////////////////////
937/// get all the MVA values for the events of the given Data type
938// (this is used by Method Category and it does not need to be re-implmented by derived classes )
939std::vector<Double_t> TMVA::MethodBase::GetDataMvaValues(DataSet * data, Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
940{
941 fTmpData = data;
942 auto result = GetMvaValues(firstEvt, lastEvt, logProgress);
943 fTmpData = nullptr;
944 return result;
945}
946
947////////////////////////////////////////////////////////////////////////////////
948/// prepare tree branch with the method's discriminating variable
949
951{
952 Data()->SetCurrentType(type);
953
954 ResultsClassification* mvaProb =
955 (ResultsClassification*)Data()->GetResults(TString("prob_")+GetMethodName(), type, Types::kClassification );
956
957 Long64_t nEvents = Data()->GetNEvents();
958
959 // use timer
960 Timer timer( nEvents, GetName(), kTRUE );
961
962 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
963 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
964
965 mvaProb->Resize( nEvents );
966 for (Int_t ievt=0; ievt<nEvents; ievt++) {
967
968 Data()->SetCurrentEvent(ievt);
969 Float_t proba = ((Float_t)GetProba( GetMvaValue(), 0.5 ));
970 if (proba < 0) break;
971 mvaProb->SetValue( proba, ievt, DataInfo().IsSignal( Data()->GetEvent()) );
972
973 // print progress
974 Int_t modulo = Int_t(nEvents/100);
975 if (modulo <= 0 ) modulo = 1;
976 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
977 }
978
979 Log() << kDEBUG <<Form("Dataset[%s] : ",DataInfo().GetName())
980 << "Elapsed time for evaluation of " << nEvents << " events: "
981 << timer.GetElapsedTime() << " " << Endl;
982}
983
984////////////////////////////////////////////////////////////////////////////////
985/// calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
986///
987/// - bias = average deviation
988/// - dev = average absolute deviation
989/// - rms = rms of deviation
990
992 Double_t& dev, Double_t& devT,
993 Double_t& rms, Double_t& rmsT,
994 Double_t& mInf, Double_t& mInfT,
995 Double_t& corr,
997{
998 Types::ETreeType savedType = Data()->GetCurrentType();
999 Data()->SetCurrentType(type);
1000
1001 bias = 0; biasT = 0; dev = 0; devT = 0; rms = 0; rmsT = 0;
1002 Double_t sumw = 0;
1003 Double_t m1 = 0, m2 = 0, s1 = 0, s2 = 0, s12 = 0; // for correlation
1004 const Int_t nevt = GetNEvents();
1005 Float_t* rV = new Float_t[nevt];
1006 Float_t* tV = new Float_t[nevt];
1007 Float_t* wV = new Float_t[nevt];
1008 Float_t xmin = 1e30, xmax = -1e30;
1009 Log() << kINFO << "Calculate regression for all events" << Endl;
1010 Timer timer( nevt, GetName(), kTRUE );
1011 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1012
1013 const Event* ev = Data()->GetEvent(ievt); // NOTE: need untransformed event here !
1014 Float_t t = ev->GetTarget(0);
1015 Float_t w = ev->GetWeight();
1016 Float_t r = GetRegressionValues()[0];
1017 Float_t d = (r-t);
1018
1019 // find min/max
1022
1023 // store for truncated RMS computation
1024 rV[ievt] = r;
1025 tV[ievt] = t;
1026 wV[ievt] = w;
1027
1028 // compute deviation-squared
1029 sumw += w;
1030 bias += w * d;
1031 dev += w * TMath::Abs(d);
1032 rms += w * d * d;
1033
1034 // compute correlation between target and regression estimate
1035 m1 += t*w; s1 += t*t*w;
1036 m2 += r*w; s2 += r*r*w;
1037 s12 += t*r;
1038 // print progress
1039 Long64_t modulo = Long64_t(nevt / 100);
1040 if (ievt % modulo == 0)
1041 timer.DrawProgressBar(ievt);
1042 }
1043 timer.DrawProgressBar(nevt - 1);
1044 Log() << kINFO << "Elapsed time for evaluation of " << nevt << " events: "
1045 << timer.GetElapsedTime() << " " << Endl;
1046
1047 // standard quantities
1048 bias /= sumw;
1049 dev /= sumw;
1050 rms /= sumw;
1051 rms = TMath::Sqrt(rms - bias*bias);
1052
1053 // correlation
1054 m1 /= sumw;
1055 m2 /= sumw;
1056 corr = s12/sumw - m1*m2;
1057 corr /= TMath::Sqrt( (s1/sumw - m1*m1) * (s2/sumw - m2*m2) );
1058
1059 // create histogram required for computation of mutual information
1060 TH2F* hist = new TH2F( "hist", "hist", 150, xmin, xmax, 100, xmin, xmax );
1061 TH2F* histT = new TH2F( "histT", "histT", 150, xmin, xmax, 100, xmin, xmax );
1062
1063 // compute truncated RMS and fill histogram
1064 Double_t devMax = bias + 2*rms;
1065 Double_t devMin = bias - 2*rms;
1066 sumw = 0;
1067 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1068 Float_t d = (rV[ievt] - tV[ievt]);
1069 hist->Fill( rV[ievt], tV[ievt], wV[ievt] );
1070 if (d >= devMin && d <= devMax) {
1071 sumw += wV[ievt];
1072 biasT += wV[ievt] * d;
1073 devT += wV[ievt] * TMath::Abs(d);
1074 rmsT += wV[ievt] * d * d;
1075 histT->Fill( rV[ievt], tV[ievt], wV[ievt] );
1076 }
1077 }
1078 biasT /= sumw;
1079 devT /= sumw;
1080 rmsT /= sumw;
1081 rmsT = TMath::Sqrt(rmsT - biasT*biasT);
1082 mInf = gTools().GetMutualInformation( *hist );
1083 mInfT = gTools().GetMutualInformation( *histT );
1084
1085 delete hist;
1086 delete histT;
1087
1088 delete [] rV;
1089 delete [] tV;
1090 delete [] wV;
1091
1092 Data()->SetCurrentType(savedType);
1093}
1094
1095
1096////////////////////////////////////////////////////////////////////////////////
1097/// test multiclass classification
1098
1100{
1101 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
1102 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in TestMulticlass, exiting."<<Endl;
1103
1104 // GA evaluation of best cut for sig eff * sig pur. Slow, disabled for now.
1105 // Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for test
1106 // data..." << Endl; for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
1107 // resMulticlass->GetBestMultiClassCuts(icls);
1108 // }
1109
1110 // Create histograms for use in TMVA GUI
1111 TString histNamePrefix(GetTestvarName());
1112 TString histNamePrefixTest{histNamePrefix + "_Test"};
1113 TString histNamePrefixTrain{histNamePrefix + "_Train"};
1114
1115 resMulticlass->CreateMulticlassHistos(histNamePrefixTest, fNbinsMVAoutput, fNbinsH);
1116 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTest);
1117
1118 resMulticlass->CreateMulticlassHistos(histNamePrefixTrain, fNbinsMVAoutput, fNbinsH);
1119 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTrain);
1120}
1121
1122
1123////////////////////////////////////////////////////////////////////////////////
1124/// initialization
1125
1127{
1128 Data()->SetCurrentType(Types::kTesting);
1129
1130 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
1131 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
1132
1133 // sanity checks: tree must exist, and theVar must be in tree
1134 if (0==mvaRes && !(GetMethodTypeName().Contains("Cuts"))) {
1135 Log()<<Form("Dataset[%s] : ",DataInfo().GetName()) << "mvaRes " << mvaRes << " GetMethodTypeName " << GetMethodTypeName()
1136 << " contains " << !(GetMethodTypeName().Contains("Cuts")) << Endl;
1137 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<TestInit> Test variable " << GetTestvarName()
1138 << " not found in tree" << Endl;
1139 }
1140
1141 // basic statistics operations are made in base class
1142 gTools().ComputeStat( GetEventCollection(Types::kTesting), mvaRes->GetValueVector(),
1143 fMeanS, fMeanB, fRmsS, fRmsB, fXmin, fXmax, fSignalClass );
1144
1145 // choose reasonable histogram ranges, by removing outliers
1146 Double_t nrms = 10;
1147 fXmin = TMath::Max( TMath::Min( fMeanS - nrms*fRmsS, fMeanB - nrms*fRmsB ), fXmin );
1148 fXmax = TMath::Min( TMath::Max( fMeanS + nrms*fRmsS, fMeanB + nrms*fRmsB ), fXmax );
1149
1150 // determine cut orientation
1151 fCutOrientation = (fMeanS > fMeanB) ? kPositive : kNegative;
1152
1153 // fill 2 types of histograms for the various analyses
1154 // this one is for actual plotting
1155
1156 Double_t sxmax = fXmax+0.00001;
1157
1158 // classifier response distributions for training sample
1159 // MVA plots used for graphics representation (signal)
1160 TString TestvarName;
1161 if(IsSilentFile())
1162 {
1163 TestvarName=Form("[%s]%s",DataInfo().GetName(),GetTestvarName().Data());
1164 }else
1165 {
1166 TestvarName=GetTestvarName();
1167 }
1168 TH1* mva_s = new TH1D( TestvarName + "_S",TestvarName + "_S", fNbinsMVAoutput, fXmin, sxmax );
1169 TH1* mva_b = new TH1D( TestvarName + "_B",TestvarName + "_B", fNbinsMVAoutput, fXmin, sxmax );
1170 mvaRes->Store(mva_s, "MVA_S");
1171 mvaRes->Store(mva_b, "MVA_B");
1172 mva_s->Sumw2();
1173 mva_b->Sumw2();
1174
1175 TH1* proba_s = 0;
1176 TH1* proba_b = 0;
1177 TH1* rarity_s = 0;
1178 TH1* rarity_b = 0;
1179 if (HasMVAPdfs()) {
1180 // P(MVA) plots used for graphics representation
1181 proba_s = new TH1D( TestvarName + "_Proba_S", TestvarName + "_Proba_S", fNbinsMVAoutput, 0.0, 1.0 );
1182 proba_b = new TH1D( TestvarName + "_Proba_B", TestvarName + "_Proba_B", fNbinsMVAoutput, 0.0, 1.0 );
1183 mvaRes->Store(proba_s, "Prob_S");
1184 mvaRes->Store(proba_b, "Prob_B");
1185 proba_s->Sumw2();
1186 proba_b->Sumw2();
1187
1188 // R(MVA) plots used for graphics representation
1189 rarity_s = new TH1D( TestvarName + "_Rarity_S", TestvarName + "_Rarity_S", fNbinsMVAoutput, 0.0, 1.0 );
1190 rarity_b = new TH1D( TestvarName + "_Rarity_B", TestvarName + "_Rarity_B", fNbinsMVAoutput, 0.0, 1.0 );
1191 mvaRes->Store(rarity_s, "Rar_S");
1192 mvaRes->Store(rarity_b, "Rar_B");
1193 rarity_s->Sumw2();
1194 rarity_b->Sumw2();
1195 }
1196
1197 // MVA plots used for efficiency calculations (large number of bins)
1198 TH1* mva_eff_s = new TH1D( TestvarName + "_S_high", TestvarName + "_S_high", fNbinsH, fXmin, sxmax );
1199 TH1* mva_eff_b = new TH1D( TestvarName + "_B_high", TestvarName + "_B_high", fNbinsH, fXmin, sxmax );
1200 mvaRes->Store(mva_eff_s, "MVA_HIGHBIN_S");
1201 mvaRes->Store(mva_eff_b, "MVA_HIGHBIN_B");
1202 mva_eff_s->Sumw2();
1203 mva_eff_b->Sumw2();
1204
1205 // fill the histograms
1206
1207 ResultsClassification* mvaProb = dynamic_cast<ResultsClassification*>
1208 (Data()->GetResults( TString("prob_")+GetMethodName(), Types::kTesting, Types::kMaxAnalysisType ) );
1209
1210 Log() << kHEADER <<Form("[%s] : ",DataInfo().GetName())<< "Loop over test events and fill histograms with classifier response..." << Endl << Endl;
1211 if (mvaProb) Log() << kINFO << "Also filling probability and rarity histograms (on request)..." << Endl;
1212 //std::vector<Bool_t>* mvaResTypes = mvaRes->GetValueVectorTypes();
1213
1214 //LM: this is needed to avoid crashes in ROOCCURVE
1215 if ( mvaRes->GetSize() != GetNEvents() ) {
1216 Log() << kFATAL << TString::Format("Inconsistent result size %lld with number of events %u ", mvaRes->GetSize() , GetNEvents() ) << Endl;
1217 assert(mvaRes->GetSize() == GetNEvents());
1218 }
1219
1220 for (Long64_t ievt=0; ievt<GetNEvents(); ievt++) {
1221
1222 const Event* ev = GetEvent(ievt);
1223 Float_t v = (*mvaRes)[ievt][0];
1224 Float_t w = ev->GetWeight();
1225
1226 if (DataInfo().IsSignal(ev)) {
1227 //mvaResTypes->push_back(kTRUE);
1228 mva_s ->Fill( v, w );
1229 if (mvaProb) {
1230 proba_s->Fill( (*mvaProb)[ievt][0], w );
1231 rarity_s->Fill( GetRarity( v ), w );
1232 }
1233
1234 mva_eff_s ->Fill( v, w );
1235 }
1236 else {
1237 //mvaResTypes->push_back(kFALSE);
1238 mva_b ->Fill( v, w );
1239 if (mvaProb) {
1240 proba_b->Fill( (*mvaProb)[ievt][0], w );
1241 rarity_b->Fill( GetRarity( v ), w );
1242 }
1243 mva_eff_b ->Fill( v, w );
1244 }
1245 }
1246
1247 // uncomment those (and several others if you want unnormalized output
1248 gTools().NormHist( mva_s );
1249 gTools().NormHist( mva_b );
1250 gTools().NormHist( proba_s );
1251 gTools().NormHist( proba_b );
1252 gTools().NormHist( rarity_s );
1253 gTools().NormHist( rarity_b );
1254 gTools().NormHist( mva_eff_s );
1255 gTools().NormHist( mva_eff_b );
1256
1257 // create PDFs from histograms, using default splines, and no additional smoothing
1258 if (fSplS) { delete fSplS; fSplS = 0; }
1259 if (fSplB) { delete fSplB; fSplB = 0; }
1260 fSplS = new PDF( TString(GetName()) + " PDF Sig", mva_s, PDF::kSpline2 );
1261 fSplB = new PDF( TString(GetName()) + " PDF Bkg", mva_b, PDF::kSpline2 );
1262}
1263
1264////////////////////////////////////////////////////////////////////////////////
1265/// general method used in writing the header of the weight files where
1266/// the used variables, variable transformation type etc. is specified
1267
1268void TMVA::MethodBase::WriteStateToStream( std::ostream& tf ) const
1269{
1270 TString prefix = "";
1271 UserGroup_t * userInfo = gSystem->GetUserInfo();
1272
1273 tf << prefix << "#GEN -*-*-*-*-*-*-*-*-*-*-*- general info -*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1274 tf << prefix << "Method : " << GetMethodTypeName() << "::" << GetMethodName() << std::endl;
1275 tf.setf(std::ios::left);
1276 tf << prefix << "TMVA Release : " << std::setw(10) << GetTrainingTMVAVersionString() << " ["
1277 << GetTrainingTMVAVersionCode() << "]" << std::endl;
1278 tf << prefix << "ROOT Release : " << std::setw(10) << GetTrainingROOTVersionString() << " ["
1279 << GetTrainingROOTVersionCode() << "]" << std::endl;
1280 tf << prefix << "Creator : " << userInfo->fUser << std::endl;
1281 tf << prefix << "Date : "; TDatime *d = new TDatime; tf << d->AsString() << std::endl; delete d;
1282 tf << prefix << "Host : " << gSystem->GetBuildNode() << std::endl;
1283 tf << prefix << "Dir : " << gSystem->WorkingDirectory() << std::endl;
1284 tf << prefix << "Training events: " << Data()->GetNTrainingEvents() << std::endl;
1285
1286 TString analysisType(((const_cast<TMVA::MethodBase*>(this)->GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification"));
1287
1288 tf << prefix << "Analysis type : " << "[" << ((GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification") << "]" << std::endl;
1289 tf << prefix << std::endl;
1290
1291 delete userInfo;
1292
1293 // First write all options
1294 tf << prefix << std::endl << prefix << "#OPT -*-*-*-*-*-*-*-*-*-*-*-*- options -*-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1295 WriteOptionsToStream( tf, prefix );
1296 tf << prefix << std::endl;
1297
1298 // Second write variable info
1299 tf << prefix << std::endl << prefix << "#VAR -*-*-*-*-*-*-*-*-*-*-*-* variables *-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1300 WriteVarsToStream( tf, prefix );
1301 tf << prefix << std::endl;
1302}
1303
1304////////////////////////////////////////////////////////////////////////////////
1305/// xml writing
1306
1307void TMVA::MethodBase::AddInfoItem( void* gi, const TString& name, const TString& value) const
1308{
1309 void* it = gTools().AddChild(gi,"Info");
1310 gTools().AddAttr(it,"name", name);
1311 gTools().AddAttr(it,"value", value);
1312}
1313
1314////////////////////////////////////////////////////////////////////////////////
1315
1317 if (analysisType == Types::kRegression) {
1318 AddRegressionOutput( type );
1319 } else if (analysisType == Types::kMulticlass) {
1320 AddMulticlassOutput( type );
1321 } else {
1322 AddClassifierOutput( type );
1323 if (HasMVAPdfs())
1324 AddClassifierOutputProb( type );
1325 }
1326}
1327
1328////////////////////////////////////////////////////////////////////////////////
1329/// general method used in writing the header of the weight files where
1330/// the used variables, variable transformation type etc. is specified
1331
1332void TMVA::MethodBase::WriteStateToXML( void* parent ) const
1333{
1334 if (!parent) return;
1335
1336 UserGroup_t* userInfo = gSystem->GetUserInfo();
1337
1338 void* gi = gTools().AddChild(parent, "GeneralInfo");
1339 AddInfoItem( gi, "TMVA Release", GetTrainingTMVAVersionString() + " [" + gTools().StringFromInt(GetTrainingTMVAVersionCode()) + "]" );
1340 AddInfoItem( gi, "ROOT Release", GetTrainingROOTVersionString() + " [" + gTools().StringFromInt(GetTrainingROOTVersionCode()) + "]");
1341 AddInfoItem( gi, "Creator", userInfo->fUser);
1342 TDatime dt; AddInfoItem( gi, "Date", dt.AsString());
1343 AddInfoItem( gi, "Host", gSystem->GetBuildNode() );
1344 AddInfoItem( gi, "Dir", gSystem->WorkingDirectory());
1345 AddInfoItem( gi, "Training events", gTools().StringFromInt(Data()->GetNTrainingEvents()));
1346 AddInfoItem( gi, "TrainingTime", gTools().StringFromDouble(const_cast<TMVA::MethodBase*>(this)->GetTrainTime()));
1347
1348 Types::EAnalysisType aType = const_cast<TMVA::MethodBase*>(this)->GetAnalysisType();
1349 TString analysisType((aType==Types::kRegression) ? "Regression" :
1350 (aType==Types::kMulticlass ? "Multiclass" : "Classification"));
1351 AddInfoItem( gi, "AnalysisType", analysisType );
1352 delete userInfo;
1353
1354 // write options
1355 AddOptionsXMLTo( parent );
1356
1357 // write variable info
1358 AddVarsXMLTo( parent );
1359
1360 // write spectator info
1361 if (fModelPersistence)
1362 AddSpectatorsXMLTo( parent );
1363
1364 // write class info if in multiclass mode
1365 AddClassesXMLTo(parent);
1366
1367 // write target info if in regression mode
1368 if (DoRegression()) AddTargetsXMLTo(parent);
1369
1370 // write transformations
1371 GetTransformationHandler(false).AddXMLTo( parent );
1372
1373 // write MVA variable distributions
1374 void* pdfs = gTools().AddChild(parent, "MVAPdfs");
1375 if (fMVAPdfS) fMVAPdfS->AddXMLTo(pdfs);
1376 if (fMVAPdfB) fMVAPdfB->AddXMLTo(pdfs);
1377
1378 // write weights
1379 AddWeightsXMLTo( parent );
1380}
1381
1382////////////////////////////////////////////////////////////////////////////////
1383/// read reference MVA distributions (and other information)
1384/// from a ROOT type weight file
1385
1387{
1388 Bool_t addDirStatus = TH1::AddDirectoryStatus();
1389 TH1::AddDirectory( 0 ); // this avoids the binding of the hists in PDF to the current ROOT file
1390 fMVAPdfS = (TMVA::PDF*)rf.Get( "MVA_PDF_Signal" );
1391 fMVAPdfB = (TMVA::PDF*)rf.Get( "MVA_PDF_Background" );
1392
1393 TH1::AddDirectory( addDirStatus );
1394
1395 ReadWeightsFromStream( rf );
1396
1397 SetTestvarName();
1398}
1399
1400////////////////////////////////////////////////////////////////////////////////
1401/// write options and weights to file
1402/// note that each one text file for the main configuration information
1403/// and one ROOT file for ROOT objects are created
1404
1406{
1407 // ---- create the text file
1408 TString tfname( GetWeightFileName() );
1409
1410 // writing xml file
1411 TString xmlfname( tfname ); xmlfname.ReplaceAll( ".txt", ".xml" );
1412 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1413 << "Creating xml weight file: "
1414 << gTools().Color("lightblue") << xmlfname << gTools().Color("reset") << Endl;
1415 void* doc = gTools().xmlengine().NewDoc();
1416 void* rootnode = gTools().AddChild(0,"MethodSetup", "", true);
1417 gTools().xmlengine().DocSetRootElement(doc,rootnode);
1418 gTools().AddAttr(rootnode,"Method", GetMethodTypeName() + "::" + GetMethodName());
1419 WriteStateToXML(rootnode);
1420 gTools().xmlengine().SaveDoc(doc,xmlfname);
1421 gTools().xmlengine().FreeDoc(doc);
1422}
1423
1424////////////////////////////////////////////////////////////////////////////////
1425/// Function to read options and weights from file
1426
1428{
1429 // get the filename
1430
1431 TString tfname(GetWeightFileName());
1432
1433 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1434 << "Reading weight file: "
1435 << gTools().Color("lightblue") << tfname << gTools().Color("reset") << Endl;
1436
1437 if (tfname.EndsWith(".xml") ) {
1438 void* doc = gTools().xmlengine().ParseFile(tfname,gTools().xmlenginebuffersize()); // the default buffer size in TXMLEngine::ParseFile is 100k. Starting with ROOT 5.29 one can set the buffer size, see: http://savannah.cern.ch/bugs/?78864. This might be necessary for large XML files
1439 if (!doc) {
1440 Log() << kFATAL << "Error parsing XML file " << tfname << Endl;
1441 }
1442 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1443 ReadStateFromXML(rootnode);
1444 gTools().xmlengine().FreeDoc(doc);
1445 }
1446 else {
1447 std::filebuf fb;
1448 fb.open(tfname.Data(),std::ios::in);
1449 if (!fb.is_open()) { // file not found --> Error
1450 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ReadStateFromFile> "
1451 << "Unable to open input weight file: " << tfname << Endl;
1452 }
1453 std::istream fin(&fb);
1454 ReadStateFromStream(fin);
1455 fb.close();
1456 }
1457 if (!fTxtWeightsOnly) {
1458 // ---- read the ROOT file
1459 TString rfname( tfname ); rfname.ReplaceAll( ".txt", ".root" );
1460 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Reading root weight file: "
1461 << gTools().Color("lightblue") << rfname << gTools().Color("reset") << Endl;
1462 TFile* rfile = TFile::Open( rfname, "READ" );
1463 ReadStateFromStream( *rfile );
1464 rfile->Close();
1465 }
1466}
1467////////////////////////////////////////////////////////////////////////////////
1468/// for reading from memory
1469
1471 void* doc = gTools().xmlengine().ParseString(xmlstr);
1472 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1473 ReadStateFromXML(rootnode);
1474 gTools().xmlengine().FreeDoc(doc);
1475
1476 return;
1477}
1478
1479////////////////////////////////////////////////////////////////////////////////
1480
1482{
1483
1484 TString fullMethodName;
1485 gTools().ReadAttr( methodNode, "Method", fullMethodName );
1486
1487 fMethodName = fullMethodName(fullMethodName.Index("::")+2,fullMethodName.Length());
1488
1489 // update logger
1490 Log().SetSource( GetName() );
1491 Log() << kDEBUG//<<Form("Dataset[%s] : ",DataInfo().GetName())
1492 << "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1493
1494 // after the method name is read, the testvar can be set
1495 SetTestvarName();
1496
1497 TString nodeName("");
1498 void* ch = gTools().GetChild(methodNode);
1499 while (ch!=0) {
1500 nodeName = TString( gTools().GetName(ch) );
1501
1502 if (nodeName=="GeneralInfo") {
1503 // read analysis type
1504
1505 TString name(""),val("");
1506 void* antypeNode = gTools().GetChild(ch);
1507 while (antypeNode) {
1508 gTools().ReadAttr( antypeNode, "name", name );
1509
1510 if (name == "TrainingTime")
1511 gTools().ReadAttr( antypeNode, "value", fTrainTime );
1512
1513 if (name == "AnalysisType") {
1514 gTools().ReadAttr( antypeNode, "value", val );
1515 val.ToLower();
1516 if (val == "regression" ) SetAnalysisType( Types::kRegression );
1517 else if (val == "classification" ) SetAnalysisType( Types::kClassification );
1518 else if (val == "multiclass" ) SetAnalysisType( Types::kMulticlass );
1519 else Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Analysis type " << val << " is not known." << Endl;
1520 }
1521
1522 if (name == "TMVA Release" || name == "TMVA") {
1523 TString s;
1524 gTools().ReadAttr( antypeNode, "value", s);
1525 fTMVATrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1526 Log() << kDEBUG <<Form("[%s] : ",DataInfo().GetName()) << "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
1527 }
1528
1529 if (name == "ROOT Release" || name == "ROOT") {
1530 TString s;
1531 gTools().ReadAttr( antypeNode, "value", s);
1532 fROOTTrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1533 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
1534 << "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
1535 }
1536 antypeNode = gTools().GetNextChild(antypeNode);
1537 }
1538 }
1539 else if (nodeName=="Options") {
1540 ReadOptionsFromXML(ch);
1541 ParseOptions();
1542
1543 }
1544 else if (nodeName=="Variables") {
1545 ReadVariablesFromXML(ch);
1546 }
1547 else if (nodeName=="Spectators") {
1548 ReadSpectatorsFromXML(ch);
1549 }
1550 else if (nodeName=="Classes") {
1551 if (DataInfo().GetNClasses()==0) ReadClassesFromXML(ch);
1552 }
1553 else if (nodeName=="Targets") {
1554 if (DataInfo().GetNTargets()==0 && DoRegression()) ReadTargetsFromXML(ch);
1555 }
1556 else if (nodeName=="Transformations") {
1557 GetTransformationHandler().ReadFromXML(ch);
1558 }
1559 else if (nodeName=="MVAPdfs") {
1560 TString pdfname;
1561 if (fMVAPdfS) { delete fMVAPdfS; fMVAPdfS=0; }
1562 if (fMVAPdfB) { delete fMVAPdfB; fMVAPdfB=0; }
1563 void* pdfnode = gTools().GetChild(ch);
1564 if (pdfnode) {
1565 gTools().ReadAttr(pdfnode, "Name", pdfname);
1566 fMVAPdfS = new PDF(pdfname);
1567 fMVAPdfS->ReadXML(pdfnode);
1568 pdfnode = gTools().GetNextChild(pdfnode);
1569 gTools().ReadAttr(pdfnode, "Name", pdfname);
1570 fMVAPdfB = new PDF(pdfname);
1571 fMVAPdfB->ReadXML(pdfnode);
1572 }
1573 }
1574 else if (nodeName=="Weights") {
1575 ReadWeightsFromXML(ch);
1576 }
1577 else {
1578 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Unparsed XML node: '" << nodeName << "'" << Endl;
1579 }
1580 ch = gTools().GetNextChild(ch);
1581
1582 }
1583
1584 // update transformation handler
1585 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1586}
1587
1588////////////////////////////////////////////////////////////////////////////////
1589/// read the header from the weight files of the different MVA methods
1590
1592{
1593 char buf[512];
1594
1595 // when reading from stream, we assume the files are produced with TMVA<=397
1596 SetAnalysisType(Types::kClassification);
1597
1598
1599 // first read the method name
1600 GetLine(fin,buf);
1601 while (!TString(buf).BeginsWith("Method")) GetLine(fin,buf);
1602 TString namestr(buf);
1603
1604 TString methodType = namestr(0,namestr.Index("::"));
1605 methodType = methodType(methodType.Last(' '),methodType.Length());
1606 methodType = methodType.Strip(TString::kLeading);
1607
1608 TString methodName = namestr(namestr.Index("::")+2,namestr.Length());
1609 methodName = methodName.Strip(TString::kLeading);
1610 if (methodName == "") methodName = methodType;
1611 fMethodName = methodName;
1612
1613 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1614
1615 // update logger
1616 Log().SetSource( GetName() );
1617
1618 // now the question is whether to read the variables first or the options (well, of course the order
1619 // of writing them needs to agree)
1620 //
1621 // the option "Decorrelation" is needed to decide if the variables we
1622 // read are decorrelated or not
1623 //
1624 // the variables are needed by some methods (TMLP) to build the NN
1625 // which is done in ProcessOptions so for the time being we first Read and Parse the options then
1626 // we read the variables, and then we process the options
1627
1628 // now read all options
1629 GetLine(fin,buf);
1630 while (!TString(buf).BeginsWith("#OPT")) GetLine(fin,buf);
1631 ReadOptionsFromStream(fin);
1632 ParseOptions();
1633
1634 // Now read variable info
1635 fin.getline(buf,512);
1636 while (!TString(buf).BeginsWith("#VAR")) fin.getline(buf,512);
1637 ReadVarsFromStream(fin);
1638
1639 // now we process the options (of the derived class)
1640 ProcessOptions();
1641
1642 if (IsNormalised()) {
1644 GetTransformationHandler().AddTransformation( new VariableNormalizeTransform(DataInfo()), -1 );
1645 norm->BuildTransformationFromVarInfo( DataInfo().GetVariableInfos() );
1646 }
1647 VariableTransformBase *varTrafo(0), *varTrafo2(0);
1648 if ( fVarTransformString == "None") {
1649 if (fUseDecorr)
1650 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1651 } else if ( fVarTransformString == "Decorrelate" ) {
1652 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1653 } else if ( fVarTransformString == "PCA" ) {
1654 varTrafo = GetTransformationHandler().AddTransformation( new VariablePCATransform(DataInfo()), -1 );
1655 } else if ( fVarTransformString == "Uniform" ) {
1656 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo(),"Uniform"), -1 );
1657 } else if ( fVarTransformString == "Gauss" ) {
1658 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1659 } else if ( fVarTransformString == "GaussDecorr" ) {
1660 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1661 varTrafo2 = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1662 } else {
1663 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ProcessOptions> Variable transform '"
1664 << fVarTransformString << "' unknown." << Endl;
1665 }
1666 // Now read decorrelation matrix if available
1667 if (GetTransformationHandler().GetTransformationList().GetSize() > 0) {
1668 fin.getline(buf,512);
1669 while (!TString(buf).BeginsWith("#MAT")) fin.getline(buf,512);
1670 if (varTrafo) {
1671 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1672 varTrafo->ReadTransformationFromStream(fin, trafo );
1673 }
1674 if (varTrafo2) {
1675 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1676 varTrafo2->ReadTransformationFromStream(fin, trafo );
1677 }
1678 }
1679
1680
1681 if (HasMVAPdfs()) {
1682 // Now read the MVA PDFs
1683 fin.getline(buf,512);
1684 while (!TString(buf).BeginsWith("#MVAPDFS")) fin.getline(buf,512);
1685 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
1686 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
1687 fMVAPdfS = new PDF(TString(GetName()) + " MVA PDF Sig");
1688 fMVAPdfB = new PDF(TString(GetName()) + " MVA PDF Bkg");
1689 fMVAPdfS->SetReadingVersion( GetTrainingTMVAVersionCode() );
1690 fMVAPdfB->SetReadingVersion( GetTrainingTMVAVersionCode() );
1691
1692 fin >> *fMVAPdfS;
1693 fin >> *fMVAPdfB;
1694 }
1695
1696 // Now read weights
1697 fin.getline(buf,512);
1698 while (!TString(buf).BeginsWith("#WGT")) fin.getline(buf,512);
1699 fin.getline(buf,512);
1700 ReadWeightsFromStream( fin );;
1701
1702 // update transformation handler
1703 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1704
1705}
1706
1707////////////////////////////////////////////////////////////////////////////////
1708/// write the list of variables (name, min, max) for a given data
1709/// transformation method to the stream
1710
1711void TMVA::MethodBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const
1712{
1713 o << prefix << "NVar " << DataInfo().GetNVariables() << std::endl;
1714 std::vector<VariableInfo>::const_iterator varIt = DataInfo().GetVariableInfos().begin();
1715 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1716 o << prefix << "NSpec " << DataInfo().GetNSpectators() << std::endl;
1717 varIt = DataInfo().GetSpectatorInfos().begin();
1718 for (; varIt!=DataInfo().GetSpectatorInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1719}
1720
1721////////////////////////////////////////////////////////////////////////////////
1722/// Read the variables (name, min, max) for a given data
1723/// transformation method from the stream and check them against the
1724/// variables declared in DataInfo().
///
/// The stream is expected to provide a label token followed by the number of
/// variables, then one VariableInfo record per declared variable. A count
/// mismatch, or a record whose expression differs from the variable declared
/// at the same position, is reported with kFATAL. On a match the declared
/// entry is overwritten by the record read from the stream, after the external
/// link of the declared entry has been copied into it.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name ReadVarsFromStream is taken from the log message below.
1725
1727{
1728 TString dummy;
1729 UInt_t readNVar;
// first token is a label that is discarded, second is the variable count
1730 istr >> dummy >> readNVar;
1731
1732 if (readNVar!=DataInfo().GetNVariables()) {
1733 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1734 << " while there are " << readNVar << " variables declared in the file"
1735 << Endl;
1736 }
1737
1738 // we want to make sure all variables are read in the order they are defined
1739 VariableInfo varInfo;
1740 std::vector<VariableInfo>::iterator varIt = DataInfo().GetVariableInfos().begin();
1741 int varIdx = 0;
1742 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt, ++varIdx) {
1743 varInfo.ReadFromStream(istr);
1744 if (varIt->GetExpression() == varInfo.GetExpression()) {
// keep the declared variable's external link, then replace the declared entry
1745 varInfo.SetExternalLink((*varIt).GetExternalLink());
1746 (*varIt) = varInfo;
1747 }
1748 else {
// details are logged at kINFO level; the closing kFATAL line reports the failure
1749 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVarsFromStream>" << Endl;
1750 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1751 Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
1752 Log() << kINFO << "the correct working of the method):" << Endl;
1753 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
1754 Log() << kINFO << " var #" << varIdx <<" declared in file : " << varInfo.GetExpression() << Endl;
1755 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1756 }
1757 }
1758}
1759
1760////////////////////////////////////////////////////////////////////////////////
1761/// write variable info to XML
1762
1763void TMVA::MethodBase::AddVarsXMLTo( void* parent ) const
1764{
1765 void* vars = gTools().AddChild(parent, "Variables");
1766 gTools().AddAttr( vars, "NVar", gTools().StringFromInt(DataInfo().GetNVariables()) );
1767
1768 for (UInt_t idx=0; idx<DataInfo().GetVariableInfos().size(); idx++) {
1769 VariableInfo& vi = DataInfo().GetVariableInfos()[idx];
1770 void* var = gTools().AddChild( vars, "Variable" );
1771 gTools().AddAttr( var, "VarIndex", idx );
1772 vi.AddToXML( var );
1773 }
1774}
1775
1776////////////////////////////////////////////////////////////////////////////////
1777/// write spectator info to XML
///
/// Adds a "Spectators" child to `parent` and, below it, one "Spectator" child
/// for every spectator whose variable type is not 'C'. The written children
/// are numbered consecutively through "SpecIndex", and the number of written
/// children is stored afterwards as attribute "NSpec" of the "Spectators" node.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name AddSpectatorsXMLTo is inferred from the sibling
/// Add*XMLTo functions -- confirm against the header.
1778
1780{
1781 void* specs = gTools().AddChild(parent, "Spectators");
1782
// running index of the spectators that are actually written
1783 UInt_t writeIdx=0;
1784 for (UInt_t idx=0; idx<DataInfo().GetSpectatorInfos().size(); idx++) {
1785
1786 VariableInfo& vi = DataInfo().GetSpectatorInfos()[idx];
1787
1788 // we do not want to write spectators that are category-cuts,
1789 // except if the method is the category method and the spectators belong to it
// NOTE(review): no such exception is implemented in this block; every
// spectator of type 'C' is skipped.
1790 if (vi.GetVarType()=='C') continue;
1791
1792 void* spec = gTools().AddChild( specs, "Spectator" );
1793 gTools().AddAttr( spec, "SpecIndex", writeIdx++ );
1794 vi.AddToXML( spec );
1795 }
1796 gTools().AddAttr( specs, "NSpec", gTools().StringFromInt(writeIdx) );
1797}
1798
1799////////////////////////////////////////////////////////////////////////////////
1800/// write class info to XML
1801
1802void TMVA::MethodBase::AddClassesXMLTo( void* parent ) const
1803{
1804 UInt_t nClasses=DataInfo().GetNClasses();
1805
1806 void* classes = gTools().AddChild(parent, "Classes");
1807 gTools().AddAttr( classes, "NClass", nClasses );
1808
1809 for (UInt_t iCls=0; iCls<nClasses; ++iCls) {
1810 ClassInfo *classInfo=DataInfo().GetClassInfo (iCls);
1811 TString className =classInfo->GetName();
1812 UInt_t classNumber=classInfo->GetNumber();
1813
1814 void* classNode=gTools().AddChild(classes, "Class");
1815 gTools().AddAttr( classNode, "Name", className );
1816 gTools().AddAttr( classNode, "Index", classNumber );
1817 }
1818}
1819////////////////////////////////////////////////////////////////////////////////
1820/// write target info to XML
1821
1822void TMVA::MethodBase::AddTargetsXMLTo( void* parent ) const
1823{
1824 void* targets = gTools().AddChild(parent, "Targets");
1825 gTools().AddAttr( targets, "NTrgt", gTools().StringFromInt(DataInfo().GetNTargets()) );
1826
1827 for (UInt_t idx=0; idx<DataInfo().GetTargetInfos().size(); idx++) {
1828 VariableInfo& vi = DataInfo().GetTargetInfos()[idx];
1829 void* tar = gTools().AddChild( targets, "Target" );
1830 gTools().AddAttr( tar, "TargetIndex", idx );
1831 vi.AddToXML( tar );
1832 }
1833}
1834
1835////////////////////////////////////////////////////////////////////////////////
1836/// read variable info from XML
///
/// Reads attribute "NVar" of `varnode` and reports kFATAL if it differs from
/// DataInfo().GetNVariables(). Then walks over the children of `varnode`,
/// reads each child's "VarIndex" and VariableInfo, and compares the expression
/// with the variable declared at that index; a mismatch is reported with
/// kFATAL.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name ReadVariablesFromXML is taken from the log message below.
1837
1839{
1840 UInt_t readNVar;
1841 gTools().ReadAttr( varnode, "NVar", readNVar);
1842
1843 if (readNVar!=DataInfo().GetNVariables()) {
1844 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1845 << " while there are " << readNVar << " variables declared in the file"
1846 << Endl;
1847 }
1848
1849 // we want to make sure all variables are read in the order they are defined
1850 VariableInfo readVarInfo, existingVarInfo;
1851 int varIdx = 0;
1852 void* ch = gTools().GetChild(varnode);
1853 while (ch) {
1854 gTools().ReadAttr( ch, "VarIndex", varIdx);
// NOTE(review): varIdx comes from the file and is used as a vector index
// without a range check -- confirm that callers guarantee a valid index.
1855 existingVarInfo = DataInfo().GetVariableInfos()[varIdx];
1856 readVarInfo.ReadFromXML(ch);
1857
1858 if (existingVarInfo.GetExpression() == readVarInfo.GetExpression()) {
1859 readVarInfo.SetExternalLink(existingVarInfo.GetExternalLink());
// NOTE(review): existingVarInfo is a local copy, so this assignment does not
// modify the entry stored in DataInfo() -- confirm whether that is intended.
1860 existingVarInfo = readVarInfo;
1861 }
1862 else {
// details are logged at kINFO level; the closing kFATAL line reports the failure
1863 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVariablesFromXML>" << Endl;
1864 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1865 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1866 Log() << kINFO << "correct working of the method):" << Endl;
1867 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << existingVarInfo.GetExpression() << Endl;
1868 Log() << kINFO << " var #" << varIdx <<" declared in file : " << readVarInfo.GetExpression() << Endl;
1869 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1870 }
1871 ch = gTools().GetNextChild(ch);
1872 }
1873}
1874
1875////////////////////////////////////////////////////////////////////////////////
1876/// read spectator info from XML
///
/// Reads attribute "NSpec" of `specnode` and reports kFATAL if it differs from
/// DataInfo().GetNSpectators(kFALSE). Then walks over the children of
/// `specnode`, reads each child's "SpecIndex" and VariableInfo, and compares
/// the expression with the spectator declared at that index; a mismatch is
/// reported with kFATAL.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name ReadSpectatorsFromXML is taken from the log message below.
1877
1879{
1880 UInt_t readNSpec;
1881 gTools().ReadAttr( specnode, "NSpec", readNSpec);
1882
1883 if (readNSpec!=DataInfo().GetNSpectators(kFALSE)) {
1884 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "You declared "<< DataInfo().GetNSpectators(kFALSE) << " spectators in the Reader"
1885 << " while there are " << readNSpec << " spectators declared in the file"
1886 << Endl;
1887 }
1888
1889 // we want to make sure all spectators are read in the order they are defined
1890 VariableInfo readSpecInfo, existingSpecInfo;
1891 int specIdx = 0;
1892 void* ch = gTools().GetChild(specnode);
1893 while (ch) {
1894 gTools().ReadAttr( ch, "SpecIndex", specIdx);
// NOTE(review): specIdx comes from the file and is used as a vector index
// without a range check -- confirm that callers guarantee a valid index.
1895 existingSpecInfo = DataInfo().GetSpectatorInfos()[specIdx];
1896 readSpecInfo.ReadFromXML(ch);
1897
1898 if (existingSpecInfo.GetExpression() == readSpecInfo.GetExpression()) {
1899 readSpecInfo.SetExternalLink(existingSpecInfo.GetExternalLink());
// NOTE(review): existingSpecInfo is a local copy, so this assignment does not
// modify the entry stored in DataInfo() -- confirm whether that is intended.
1900 existingSpecInfo = readSpecInfo;
1901 }
1902 else {
// details are logged at kINFO level; the closing kFATAL line reports the failure
1903 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadSpectatorsFromXML>" << Endl;
1904 Log() << kINFO << "The definition (or the order) of the spectators found in the input file is" << Endl;
1905 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1906 Log() << kINFO << "correct working of the method):" << Endl;
1907 Log() << kINFO << " spec #" << specIdx <<" declared in Reader: " << existingSpecInfo.GetExpression() << Endl;
1908 Log() << kINFO << " spec #" << specIdx <<" declared in file : " << readSpecInfo.GetExpression() << Endl;
1909 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1910 }
1911 ch = gTools().GetNextChild(ch);
1912 }
1913}
1914
1915////////////////////////////////////////////////////////////////////////////////
1916/// read the class definitions from XML
///
/// Reads attribute "NClass" of `clsnode`. If the node has no children, classes
/// named "class0", "class1", ... are added to DataInfo(), one per counted
/// class; otherwise one class is added per child, using the child's "Name"
/// attribute. Afterwards fSignalClass and fBackgroundClass are set to the
/// numbers of the classes called "Signal" and "Background", falling back to
/// 0 and 1 respectively when such a class does not exist.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name ReadClassesFromXML is inferred from the sibling
/// Read*FromXML functions -- confirm against the header.
1917
1919{
1920 UInt_t readNCls;
1921 // coverity[tainted_data_argument]
1922 gTools().ReadAttr( clsnode, "NClass", readNCls);
1923
1924 TString className="";
1925 UInt_t classIndex=0;
1926 void* ch = gTools().GetChild(clsnode);
1927 if (!ch) {
// no per-class children: generate default class names
1928 for (UInt_t icls = 0; icls<readNCls;++icls) {
1929 TString classname = Form("class%i",icls);
1930 DataInfo().AddClass(classname);
1931
1932 }
1933 }
1934 else{
1935 while (ch) {
// "Index" is read but not used any further in this block
1936 gTools().ReadAttr( ch, "Index", classIndex);
1937 gTools().ReadAttr( ch, "Name", className );
1938 DataInfo().AddClass(className);
1939
1940 ch = gTools().GetNextChild(ch);
1941 }
1942 }
1943
1944 // retrieve signal and background class index
1945 if (DataInfo().GetClassInfo("Signal") != 0) {
1946 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
1947 }
1948 else
1949 fSignalClass=0;
1950 if (DataInfo().GetClassInfo("Background") != 0) {
1951 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
1952 }
1953 else
1954 fBackgroundClass=1;
1955}
1956
1957////////////////////////////////////////////////////////////////////////////////
1958/// read target info from XML
///
/// Walks over the children of `tarnode` and adds one target to DataInfo() per
/// child, using the child's "Expression" attribute. The attributes "NTrgt" and
/// "TargetIndex" are read but not used any further in this block.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name ReadTargetsFromXML is inferred from the sibling
/// Read*FromXML functions -- confirm against the header.
1959
1961{
1962 UInt_t readNTar;
1963 gTools().ReadAttr( tarnode, "NTrgt", readNTar);
1964
1965 int tarIdx = 0;
1966 TString expression;
1967 void* ch = gTools().GetChild(tarnode);
1968 while (ch) {
1969 gTools().ReadAttr( ch, "TargetIndex", tarIdx);
1970 gTools().ReadAttr( ch, "Expression", expression);
1971 DataInfo().AddTarget(expression,"","",0,0);
1972
1973 ch = gTools().GetNextChild(ch);
1974 }
1975}
1976
1977////////////////////////////////////////////////////////////////////////////////
1978/// returns the ROOT directory where info/histograms etc of the
1979/// corresponding MVA method instance are stored
///
/// If fBaseDir is set it is returned directly. Otherwise the sub-directory
/// named GetMethodName() is looked up below MethodBaseDir(); if it does not
/// exist it is created and made the current directory and, when
/// fModelPersistence is set, the working directory and the weight file name
/// are written into it as "TrainingPath" and "WeightFileName".
/// kFATAL is reported when IsSilentFile() is true or when MethodBaseDir()
/// returns a null pointer.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name BaseDir is taken from the log messages below.
1980
1982{
1983 if (fBaseDir != 0) return fBaseDir;
1984 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodName() << " not set yet --> check if already there.." <<Endl;
1985
1986 if (IsSilentFile()) {
1987 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
1988 << "MethodBase::BaseDir() - No directory exists when running a Method without output file. Enable the "
1989 "output when creating the factory"
1990 << Endl;
1991 }
1992
1993 TDirectory* methodDir = MethodBaseDir();
1994 if (methodDir==0)
1995 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MethodBase::BaseDir() - MethodBaseDir() return a NULL pointer!" << Endl;
1996
1997 TString defaultDir = GetMethodName();
1998 TDirectory *sdir = methodDir->GetDirectory(defaultDir.Data());
1999 if(!sdir)
2000 {
2001 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " does not exist yet--> created it" <<Endl;
// NOTE(review): the result of mkdir() is used without a null check.
2002 sdir = methodDir->mkdir(defaultDir);
2003 sdir->cd();
2004 // write weight file name into target file
2005 if (fModelPersistence) {
2006 TObjString wfilePath( gSystem->WorkingDirectory() );
2007 TObjString wfileName( GetWeightFileName() );
2008 wfilePath.Write( "TrainingPath" );
2009 wfileName.Write( "WeightFileName" );
2010 }
2011 }
2012
// this message is emitted on both paths, also when the directory was just created;
// the directory is returned without being stored in fBaseDir in this block
2013 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " existed, return it.." <<Endl;
2014 return sdir;
2015}
2016
2017////////////////////////////////////////////////////////////////////////////////
2018/// returns the ROOT directory where all instances of the
2019/// corresponding MVA method are stored
2020
2022{
2023 if (fMethodBaseDir != 0) {
2024 return fMethodBaseDir;
2025 }
2026
2027 const char *datasetName = DataInfo().GetName();
2028
2029 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodTypeName()
2030 << " not set yet --> check if already there.." << Endl;
2031
2032 TDirectory *factoryBaseDir = GetFile();
2033 if (!factoryBaseDir) return nullptr;
2034 fMethodBaseDir = factoryBaseDir->GetDirectory(datasetName);
2035 if (!fMethodBaseDir) {
2036 fMethodBaseDir = factoryBaseDir->mkdir(datasetName, Form("Base directory for dataset %s", datasetName));
2037 if (!fMethodBaseDir) {
2038 Log() << kFATAL << "Can not create dir " << datasetName;
2039 }
2040 }
2041 TString methodTypeDir = Form("Method_%s", GetMethodTypeName().Data());
2042 fMethodBaseDir = fMethodBaseDir->GetDirectory(methodTypeDir.Data());
2043
2044 if (!fMethodBaseDir) {
2045 TDirectory *datasetDir = factoryBaseDir->GetDirectory(datasetName);
2046 TString methodTypeDirHelpStr = Form("Directory for all %s methods", GetMethodTypeName().Data());
2047 fMethodBaseDir = datasetDir->mkdir(methodTypeDir.Data(), methodTypeDirHelpStr);
2048 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodName()
2049 << " does not exist yet--> created it" << Endl;
2050 }
2051
2052 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName)
2053 << "Return from MethodBaseDir() after creating base directory " << Endl;
2054 return fMethodBaseDir;
2055}
2056
2057////////////////////////////////////////////////////////////////////////////////
2058/// set directory of weight file
///
/// Stores `fileDir` in fFileDir and creates the directory through
/// gSystem->mkdir with kTRUE as second argument.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name SetWeightFileDir is inferred from the doc comment and
/// from GetWeightFileDir() used elsewhere in this file -- confirm.
2059
2061{
2062 fFileDir = fileDir;
2063 gSystem->mkdir( fFileDir, kTRUE );
2064}
2065
2066////////////////////////////////////////////////////////////////////////////////
2067/// set the weight file name (deprecated)
///
/// Stores `theWeightFile` in fWeightFile.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name SetWeightFileName is inferred from the doc comment --
/// confirm against the header.
2068
2070{
2071 fWeightFile = theWeightFile;
2072}
2073
2074////////////////////////////////////////////////////////////////////////////////
2075/// retrieve weight file name
///
/// Returns fWeightFile when it is not empty. Otherwise the default name
/// "<job name>_<method name>.<weight file extension>.xml" is built and, when
/// the weight file directory is not empty, prefixed with that directory (a
/// '/' is inserted unless the directory already ends with one).
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name GetWeightFileName is taken from its use elsewhere in
/// this file.
2076
2078{
2079 if (fWeightFile!="") return fWeightFile;
2080
2081 // the default consists of
2082 // directory/jobname_methodname_suffix.extension.xml
// (the suffix is always empty in this block)
2083 TString suffix = "";
2084 TString wFileDir(GetWeightFileDir());
2085 TString wFileName = GetJobName() + "_" + GetMethodName() +
2086 suffix + "." + gConfig().GetIONames().fWeightFileExtension + ".xml";
2087 if (wFileDir.IsNull() ) return wFileName;
2088 // add weight file directory if it is not null
2089 return ( wFileDir + (wFileDir[wFileDir.Length()-1]=='/' ? "" : "/")
2090 + wFileName );
2091}
2092////////////////////////////////////////////////////////////////////////////////
2093/// writes all MVA evaluation histograms to file
///
/// Makes BaseDir() the current directory, writes the original, smoothed and
/// PDF histograms of fMVAPdfS and fMVAPdfB when these exist, and writes the
/// storage of the results booked for this method and `treetype`. For
/// Types::kTesting the variable plots are produced as well, unless the number
/// of variables reaches the configured maximum.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name WriteEvaluationHistosToFile is taken from the log
/// message below.
2094
2096{
2097 BaseDir()->cd();
2098
2099
2100 // write MVA PDFs to file - if exist
2101 if (0 != fMVAPdfS) {
2102 fMVAPdfS->GetOriginalHist()->Write();
2103 fMVAPdfS->GetSmoothedHist()->Write();
2104 fMVAPdfS->GetPDFHist()->Write();
2105 }
2106 if (0 != fMVAPdfB) {
2107 fMVAPdfB->GetOriginalHist()->Write();
2108 fMVAPdfB->GetSmoothedHist()->Write();
2109 fMVAPdfB->GetPDFHist()->Write();
2110 }
2111
2112 // write result-histograms
2113 Results* results = Data()->GetResults( GetMethodName(), treetype, Types::kMaxAnalysisType );
2114 if (!results)
2115 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<WriteEvaluationHistosToFile> Unknown result: "
2116 << GetMethodName() << (treetype==Types::kTraining?"/kTraining":"/kTesting")
2117 << "/kMaxAnalysisType" << Endl;
2118 results->GetStorage()->Write();
2119 if (treetype==Types::kTesting) {
2120 // skip the plotting of variables if there are too many of them
// (the limit is gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables)
2121 if ((int) DataInfo().GetNVariables()< gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables)
2122 GetTransformationHandler().PlotVariables (GetEventCollection( Types::kTesting ), BaseDir() );
2123 else
2124 Log() << kINFO << TString::Format("Dataset[%s] : ",DataInfo().GetName())
2125 << " variable plots are not produces ! The number of variables is " << DataInfo().GetNVariables()
2126 << " , it is larger than " << gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables << Endl;
2127 }
2128}
2129
2130////////////////////////////////////////////////////////////////////////////////
2131/// write special monitoring histograms to file
2132/// dummy implementation here -----------------
///
/// The body is intentionally empty.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name WriteMonitoringHistosToFile is inferred from the doc
/// comment -- confirm against the header.
2133
2135{
2136}
2137
2138////////////////////////////////////////////////////////////////////////////////
2139/// reads one line from the input stream
2140/// checks for certain keywords and interprets
2141/// the line if keywords are found
2142
2143Bool_t TMVA::MethodBase::GetLine(std::istream& fin, char* buf )
2144{
2145 fin.getline(buf,512);
2146 TString line(buf);
2147 if (line.BeginsWith("TMVA Release")) {
2148 Ssiz_t start = line.First('[')+1;
2149 Ssiz_t length = line.Index("]",start)-start;
2150 TString code = line(start,length);
2151 std::stringstream s(code.Data());
2152 s >> fTMVATrainingVersion;
2153 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
2154 }
2155 if (line.BeginsWith("ROOT Release")) {
2156 Ssiz_t start = line.First('[')+1;
2157 Ssiz_t length = line.Index("]",start)-start;
2158 TString code = line(start,length);
2159 std::stringstream s(code.Data());
2160 s >> fROOTTrainingVersion;
2161 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
2162 }
2163 if (line.BeginsWith("Analysis type")) {
2164 Ssiz_t start = line.First('[')+1;
2165 Ssiz_t length = line.Index("]",start)-start;
2166 TString code = line(start,length);
2167 std::stringstream s(code.Data());
2168 std::string analysisType;
2169 s >> analysisType;
2170 if (analysisType == "regression" || analysisType == "Regression") SetAnalysisType( Types::kRegression );
2171 else if (analysisType == "classification" || analysisType == "Classification") SetAnalysisType( Types::kClassification );
2172 else if (analysisType == "multiclass" || analysisType == "Multiclass") SetAnalysisType( Types::kMulticlass );
2173 else Log() << kFATAL << "Analysis type " << analysisType << " from weight-file not known!" << std::endl;
2174
2175 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Method was trained for "
2176 << (GetAnalysisType() == Types::kRegression ? "Regression" :
2177 (GetAnalysisType() == Types::kMulticlass ? "Multiclass" : "Classification")) << Endl;
2178 }
2179
2180 return true;
2181}
2182
2183////////////////////////////////////////////////////////////////////////////////
2184/// Create PDFs of the MVA output variables
2185
2187{
2188 Data()->SetCurrentType(Types::kTraining);
2189
2190 // the PDF's are stored as results ONLY if the corresponding "results" are booked,
2191 // otherwise they will be only used 'online'
2192 ResultsClassification * mvaRes = dynamic_cast<ResultsClassification*>
2193 ( Data()->GetResults(GetMethodName(), Types::kTraining, Types::kClassification) );
2194
2195 if (mvaRes==0 || mvaRes->GetSize()==0) {
2196 Log() << kERROR<<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CreateMVAPdfs> No result of classifier testing available" << Endl;
2197 }
2198
2199 Double_t minVal = *std::min_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2200 Double_t maxVal = *std::max_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2201
2202 // create histograms that serve as basis to create the MVA Pdfs
2203 TH1* histMVAPdfS = new TH1D( GetMethodTypeName() + "_tr_S", GetMethodTypeName() + "_tr_S",
2204 fMVAPdfS->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2205 TH1* histMVAPdfB = new TH1D( GetMethodTypeName() + "_tr_B", GetMethodTypeName() + "_tr_B",
2206 fMVAPdfB->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2207
2208
2209 // compute sum of weights properly
2210 histMVAPdfS->Sumw2();
2211 histMVAPdfB->Sumw2();
2212
2213 // fill histograms
2214 for (UInt_t ievt=0; ievt<mvaRes->GetSize(); ievt++) {
2215 Double_t theVal = mvaRes->GetValueVector()->at(ievt);
2216 Double_t theWeight = Data()->GetEvent(ievt)->GetWeight();
2217
2218 if (DataInfo().IsSignal(Data()->GetEvent(ievt))) histMVAPdfS->Fill( theVal, theWeight );
2219 else histMVAPdfB->Fill( theVal, theWeight );
2220 }
2221
2222 gTools().NormHist( histMVAPdfS );
2223 gTools().NormHist( histMVAPdfB );
2224
2225 // momentary hack for ROOT problem
2226 if(!IsSilentFile())
2227 {
2228 histMVAPdfS->Write();
2229 histMVAPdfB->Write();
2230 }
2231 // create PDFs
2232 fMVAPdfS->BuildPDF ( histMVAPdfS );
2233 fMVAPdfB->BuildPDF ( histMVAPdfB );
2234 fMVAPdfS->ValidatePDF( histMVAPdfS );
2235 fMVAPdfB->ValidatePDF( histMVAPdfB );
2236
2237 if (DataInfo().GetNClasses() == 2) { // TODO: this is an ugly hack.. adapt this to new framework
2238 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())
2239 << Form( "<CreateMVAPdfs> Separation from histogram (PDF): %1.3f (%1.3f)",
2240 GetSeparation( histMVAPdfS, histMVAPdfB ), GetSeparation( fMVAPdfS, fMVAPdfB ) )
2241 << Endl;
2242 }
2243
2244 delete histMVAPdfS;
2245 delete histMVAPdfB;
2246}
2247
// Computes the signal probability for the event `ev`: the MVA value of the
// event is combined with the signal fraction of the training sample (sum of
// signal weights divided by the sum of signal and background weights).
// NOTE(review): the declarator line and the opening brace of this function
// are absent from this listing; the name GetProba is taken from the log
// message and the final call below.
2249 // the simple one, automatically calculates the mvaVal and uses the
2250 // SAME sig/bkg ratio as given in the training sample (typically 50/50
2251 // .. (NormMode=EqualNumEvents) but can be different)
2252 if (!fMVAPdfS || !fMVAPdfB) {
2253 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<GetProba> MVA PDFs for Signal and Background don't exist yet, we'll create them on demand" << Endl;
2254 CreateMVAPdfs();
2255 }
2256 Double_t sigFraction = DataInfo().GetTrainingSumSignalWeights() / (DataInfo().GetTrainingSumSignalWeights() + DataInfo().GetTrainingSumBackgrWeights() );
2257 Double_t mvaVal = GetMvaValue(ev);
2258
2259 return GetProba(mvaVal,sigFraction);
2260
2261}
2262////////////////////////////////////////////////////////////////////////////////
2263/// compute the signal probability from the MVA PDFs:
///
/// \f[
///   P = \frac{p_s \, ap\_sig}{p_s \, ap\_sig + p_b \, (1 - ap\_sig)}
/// \f]
///
/// where p_s and p_b are the values of the signal and background MVA PDFs at
/// `mvaVal`. Returns -1 when one of the PDFs does not exist or when the
/// denominator is not positive.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name GetProba is taken from the log message below.
2264
2266{
2267 if (!fMVAPdfS || !fMVAPdfB) {
2268 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetProba> MVA PDFs for Signal and Background don't exist" << Endl;
2269 return -1.0;
2270 }
2271 Double_t p_s = fMVAPdfS->GetVal( mvaVal );
2272 Double_t p_b = fMVAPdfB->GetVal( mvaVal );
2273
2274 Double_t denom = p_s*ap_sig + p_b*(1 - ap_sig);
2275
2276 return (denom > 0) ? (p_s*ap_sig) / denom : -1;
2277}
2278
2279////////////////////////////////////////////////////////////////////////////////
2280/// compute rarity:
2281/// \f[
2282/// R(x) = \int_{[-\infty..x]} { PDF(x') dx' }
2283/// \f]
2284/// where PDF(x) is the PDF of the classifier's signal or background distribution
///
/// The integral is evaluated from the lower edge of the selected PDF
/// (GetXmin()) up to `mvaVal`. The signal PDF is used when `reftype` is
/// Types::kSignal, the background PDF otherwise. Returns 0 (after a warning)
/// when the PDF required for kSignal or kBackground does not exist.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing; the name GetRarity is taken from the log message below.
2285
2287{
2288 if ((reftype == Types::kSignal && !fMVAPdfS) || (reftype == Types::kBackground && !fMVAPdfB)) {
2289 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetRarity> Required MVA PDF for Signal or Background does not exist: "
2290 << "select option \"CreateMVAPdfs\"" << Endl;
2291 return 0.0;
2292 }
2293
2294 PDF* thePdf = ((reftype == Types::kSignal) ? fMVAPdfS : fMVAPdfB);
2295
2296 return thePdf->GetIntegral( thePdf->GetXmin(), mvaVal );
2297}
2298
2299////////////////////////////////////////////////////////////////////////////////
2300/// fill background efficiency (resp. rejection) versus signal efficiency plots
2301/// returns signal efficiency at background efficiency indicated in theString
///
/// `theString` is split by gTools().ParseFormatLine. With fewer than two
/// fields (or no list at all) the area under the background-rejection versus
/// signal-efficiency curve is returned. With exactly two fields the second
/// one is taken as the reference background efficiency; the signal efficiency
/// at that point is returned and its binomial error is stored in `effSerr`.
/// More than two fields are reported with kFATAL.
///
/// On the first call for a given result set (no "MVA_EFF_S" histogram stored
/// yet) the efficiency histograms and the derived background-vs-signal
/// histograms are created and stored in the results, and the signal reference
/// cut is updated.
///
/// NOTE(review): the declarator line of this function is absent from this
/// listing (the name GetEfficiency is taken from the log messages below), and
/// the listing numbering has further gaps inside the body which are marked
/// where they occur.
2302
2304{
2305 Data()->SetCurrentType(type);
2306 Results* results = Data()->GetResults( GetMethodName(), type, Types::kClassification );
// NOTE(review): the result of the dynamic_cast is dereferenced without a null check
2307 std::vector<Float_t>* mvaRes = dynamic_cast<ResultsClassification*>(results)->GetValueVector();
2308
2309 // parse input string for required background efficiency
2310 TList* list = gTools().ParseFormatLine( theString );
2311
2312 // sanity check
2313 Bool_t computeArea = kFALSE;
2314 if (!list || list->GetSize() < 2) computeArea = kTRUE; // the area is computed
2315 else if (list->GetSize() > 2) {
2316 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Wrong number of arguments"
2317 << " in string: " << theString
2318 << " | required format, e.g., Efficiency:0.05, or empty string" << Endl;
2319 delete list;
2320 return -1;
2321 }
2322
2323 // sanity check
2324 if ( results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2325 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2326 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Binning mismatch between signal and background histos" << Endl;
2327 delete list;
2328 return -1.0;
2329 }
2330
2331 // create histograms
2332
2333 // first, get efficiency histograms for signal and background
2334 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2335 Double_t xmin = effhist->GetXaxis()->GetXmin();
2336 Double_t xmax = effhist->GetXaxis()->GetXmax();
2337
// sum of signal weights; it is only assigned inside the first-round branch
// below but read at the end of every call.
// NOTE(review): presumably TTHREAD_TLS gives the variable thread-local static
// storage so that the value survives between calls -- confirm against the
// macro definition.
2338 TTHREAD_TLS(Double_t) nevtS;
2339
2340 // first round ? --> create histograms
2341 if (results->DoesExist("MVA_EFF_S")==0) {
2342
2343 // for efficiency plot
2344 TH1* eff_s = new TH1D( GetTestvarName() + "_effS", GetTestvarName() + " (signal)", fNbinsH, xmin, xmax );
2345 TH1* eff_b = new TH1D( GetTestvarName() + "_effB", GetTestvarName() + " (background)", fNbinsH, xmin, xmax );
2346 results->Store(eff_s, "MVA_EFF_S");
2347 results->Store(eff_b, "MVA_EFF_B");
2348
2349 // sign of cut
2350 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2351
2352 // this method is unbinned
2353 nevtS = 0;
2354 for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2355
2356 // read the tree
2357 Bool_t isSignal = DataInfo().IsSignal(GetEvent(ievt));
2358 Float_t theWeight = GetEvent(ievt)->GetWeight();
2359 Float_t theVal = (*mvaRes)[ievt];
2360
2361 // select histogram depending on if sig or bgd
2362 TH1* theHist = isSignal ? eff_s : eff_b;
2363
2364 // sum up the weights of the signal events
2365 if (isSignal) nevtS+=theWeight;
2366
// bin that contains the MVA value of this event (1-based; may lie outside
// the histogram range, which is handled right below)
2367 TAxis* axis = theHist->GetXaxis();
2368 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2369 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2370 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2371 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2372 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2373
// positive orientation: add the weight to all bins up to the event's bin;
// negative orientation: add it to all bins above the event's bin
2374 if (sign > 0)
2375 for (Int_t ibin=1; ibin<=maxbin; ibin++) theHist->AddBinContent( ibin , theWeight);
2376 else if (sign < 0)
2377 for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theHist->AddBinContent( ibin , theWeight );
2378 else
2379 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Mismatch in sign" << Endl;
2380 }
2381
2382 // renormalise maximum to <=1
2383 // eff_s->Scale( 1.0/TMath::Max(1.,eff_s->GetMaximum()) );
2384 // eff_b->Scale( 1.0/TMath::Max(1.,eff_b->GetMaximum()) );
2385
// NOTE(review): the listing numbering jumps from 2385 to 2388 here; source
// lines appear to be missing (presumably the actual renormalisation) -- confirm.
2388
2389 // background efficiency versus signal efficiency
2390 TH1* eff_BvsS = new TH1D( GetTestvarName() + "_effBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2391 results->Store(eff_BvsS, "MVA_EFF_BvsS");
2392 eff_BvsS->SetXTitle( "Signal eff" );
2393 eff_BvsS->SetYTitle( "Backgr eff" );
2394
2395 // background rejection (=1-eff.) versus signal efficiency
2396 TH1* rej_BvsS = new TH1D( GetTestvarName() + "_rejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2397 results->Store(rej_BvsS);
2398 rej_BvsS->SetXTitle( "Signal eff" );
2399 rej_BvsS->SetYTitle( "Backgr rejection (1-eff)" );
2400
2401 // inverse background eff (1/eff.) versus signal efficiency
2402 TH1* inveff_BvsS = new TH1D( GetTestvarName() + "_invBeffvsSeff",
2403 GetTestvarName(), fNbins, 0, 1 );
2404 results->Store(inveff_BvsS);
2405 inveff_BvsS->SetXTitle( "Signal eff" );
2406 inveff_BvsS->SetYTitle( "Inverse backgr. eff (1/eff)" );
2407
2408 // use root finder
2409 // spline background efficiency plot
2410 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
// NOTE(review): the listing numbering jumps from 2410 to 2412 here; a source
// line appears to be missing (presumably the opening of the block that is
// closed after the spline checks below) -- confirm.
2412 fSplRefS = new TSpline1( "spline2_signal", new TGraph( eff_s ) );
2413 fSplRefB = new TSpline1( "spline2_background", new TGraph( eff_b ) );
2414
2415 // verify spline sanity
2416 gTools().CheckSplines( eff_s, fSplRefS );
2417 gTools().CheckSplines( eff_b, fSplRefB );
2418 }
2419
2420 // make the background-vs-signal efficiency plot
2421
2422 // create root finder
2423 RootFinder rootFinder( this, fXmin, fXmax );
2424
2425 Double_t effB = 0;
2426 fEffS = eff_s; // to be set for the root finder
2427 for (Int_t bini=1; bini<=fNbins; bini++) {
2428
2429 // find cut value corresponding to a given signal efficiency
2430 Double_t effS = eff_BvsS->GetBinCenter( bini );
2431 Double_t cut = rootFinder.Root( effS );
2432
2433 // retrieve background efficiency for given cut
2434 if (Use_Splines_for_Eff_) effB = fSplRefB->Eval( cut );
2435 else effB = eff_b->GetBinContent( eff_b->FindBin( cut ) );
2436
2437 // and fill histograms
2438 eff_BvsS->SetBinContent( bini, effB );
2439 rej_BvsS->SetBinContent( bini, 1.0-effB );
// NOTE(review): the listing numbering jumps from 2439 to 2441 here; a source
// line appears to be missing (presumably a guard against effB being zero
// before the division below) -- confirm.
2441 inveff_BvsS->SetBinContent( bini, 1.0/effB );
2442 }
2443
2444 // create splines for histogram
2445 fSpleffBvsS = new TSpline1( "effBvsS", new TGraph( eff_BvsS ) );
2446
2447 // search for overlap point where, when cutting on it,
2448 // one would obtain: eff_S = rej_B = 1 - eff_B
2449 Double_t effS = 0., rejB, effS_ = 0., rejB_ = 0.;
2450 Int_t nbins_ = 5000;
2451 for (Int_t bini=1; bini<=nbins_; bini++) {
2452
2453 // get corresponding signal and background efficiencies
2454 effS = (bini - 0.5)/Float_t(nbins_);
2455 rejB = 1.0 - fSpleffBvsS->Eval( effS );
2456
2457 // stop at the first sign change of (effS - rejB) between two consecutive steps
2458 if ((effS - rejB)*(effS_ - rejB_) < 0) break;
2459 effS_ = effS;
2460 rejB_ = rejB;
2461 }
2462
2463 // find cut that corresponds to signal efficiency and update signal-like criterion
2464 Double_t cut = rootFinder.Root( 0.5*(effS + effS_) );
2465 SetSignalReferenceCut( cut );
2466 fEffS = 0;
2467 }
2468
2469 // must exist...
2470 if (0 == fSpleffBvsS) {
2471 delete list;
2472 return 0.0;
2473 }
2474
2475 // now find signal efficiency that corresponds to required background efficiency
2476 Double_t effS = 0, effB = 0, effS_ = 0, effB_ = 0;
2477 Int_t nbins_ = 1000;
2478
2479 if (computeArea) {
2480
2481 // compute area of rej-vs-eff plot
// (mean of the background rejection over nbins_ equidistant signal efficiencies)
2482 Double_t integral = 0;
2483 for (Int_t bini=1; bini<=nbins_; bini++) {
2484
2485 // get corresponding signal and background efficiencies
2486 effS = (bini - 0.5)/Float_t(nbins_);
2487 effB = fSpleffBvsS->Eval( effS );
2488 integral += (1.0 - effB);
2489 }
2490 integral /= nbins_;
2491
2492 delete list;
2493 return integral;
2494 }
2495 else {
2496
2497 // that will be the value of the efficiency returned (does not affect
2498 // the efficiency-vs-bkg plot which is done anyway).
2499 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2500
2501 // find precise efficiency value
2502 for (Int_t bini=1; bini<=nbins_; bini++) {
2503
2504 // get corresponding signal and background efficiencies
2505 effS = (bini - 0.5)/Float_t(nbins_);
2506 effB = fSpleffBvsS->Eval( effS );
2507
2508 // stop as soon as the background efficiency crosses (or hits) the reference value
2509 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2510 effS_ = effS;
2511 effB_ = effB;
2512 }
2513
2514 // take mean between bin above and bin below
2515 effS = 0.5*(effS + effS_);
2516
// binomial error on the signal efficiency, based on the sum of signal weights
2517 effSerr = 0;
2518 if (nevtS > 0) effSerr = TMath::Sqrt( effS*(1.0 - effS)/nevtS );
2519
2520 delete list;
2521 return effS;
2522 }
2523
2524 return -1;
2525}
2526
2527////////////////////////////////////////////////////////////////////////////////
2528
2530{
2531 Data()->SetCurrentType(Types::kTraining);
2532
2533 Results* results = Data()->GetResults(GetMethodName(), Types::kTesting, Types::kNoAnalysisType);
2534
2535 // fill background efficiency (resp. rejection) versus signal efficiency plots
2536 // returns signal efficiency at background efficiency indicated in theString
2537
2538 // parse input string for required background efficiency
2539 TList* list = gTools().ParseFormatLine( theString );
2540 // sanity check
2541
2542 if (list->GetSize() != 2) {
2543 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Wrong number of arguments"
2544 << " in string: " << theString
2545 << " | required format, e.g., Efficiency:0.05" << Endl;
2546 delete list;
2547 return -1;
2548 }
2549 // that will be the value of the efficiency retured (does not affect
2550 // the efficiency-vs-bkg plot which is done anyway.
2551 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2552
2553 delete list;
2554
2555 // sanity check
2556 if (results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2557 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2558 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Binning mismatch between signal and background histos"
2559 << Endl;
2560 return -1.0;
2561 }
2562
2563 // create histogram
2564
2565 // first, get efficiency histograms for signal and background
2566 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2567 Double_t xmin = effhist->GetXaxis()->GetXmin();
2568 Double_t xmax = effhist->GetXaxis()->GetXmax();
2569
2570 // first round ? --> create and fill histograms
2571 if (results->DoesExist("MVA_TRAIN_S")==0) {
2572
2573 // classifier response distributions for test sample
2574 Double_t sxmax = fXmax+0.00001;
2575
2576 // MVA plots on the training sample (check for overtraining)
2577 TH1* mva_s_tr = new TH1D( GetTestvarName() + "_Train_S",GetTestvarName() + "_Train_S", fNbinsMVAoutput, fXmin, sxmax );
2578 TH1* mva_b_tr = new TH1D( GetTestvarName() + "_Train_B",GetTestvarName() + "_Train_B", fNbinsMVAoutput, fXmin, sxmax );
2579 results->Store(mva_s_tr, "MVA_TRAIN_S");
2580 results->Store(mva_b_tr, "MVA_TRAIN_B");
2581 mva_s_tr->Sumw2();
2582 mva_b_tr->Sumw2();
2583
2584 // Training efficiency plots
2585 TH1* mva_eff_tr_s = new TH1D( GetTestvarName() + "_trainingEffS", GetTestvarName() + " (signal)",
2586 fNbinsH, xmin, xmax );
2587 TH1* mva_eff_tr_b = new TH1D( GetTestvarName() + "_trainingEffB", GetTestvarName() + " (background)",
2588 fNbinsH, xmin, xmax );
2589 results->Store(mva_eff_tr_s, "MVA_TRAINEFF_S");
2590 results->Store(mva_eff_tr_b, "MVA_TRAINEFF_B");
2591
2592 // sign if cut
2593 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2594
2595 std::vector<Double_t> mvaValues = GetMvaValues(0,Data()->GetNEvents());
2596 assert( (Long64_t) mvaValues.size() == Data()->GetNEvents());
2597
2598 // this method is unbinned
2599 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2600
2601 Data()->SetCurrentEvent(ievt);
2602 const Event* ev = GetEvent();
2603
2604 Double_t theVal = mvaValues[ievt];
2605 Double_t theWeight = ev->GetWeight();
2606
2607 TH1* theEffHist = DataInfo().IsSignal(ev) ? mva_eff_tr_s : mva_eff_tr_b;
2608 TH1* theClsHist = DataInfo().IsSignal(ev) ? mva_s_tr : mva_b_tr;
2609
2610 theClsHist->Fill( theVal, theWeight );
2611
2612 TAxis* axis = theEffHist->GetXaxis();
2613 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2614 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2615 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2616 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2617 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2618
2619 if (sign > 0) for (Int_t ibin=1; ibin<=maxbin; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2620 else for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2621 }
2622
2623 // normalise output distributions
2624 // uncomment those (and several others if you want unnormalized output
2625 gTools().NormHist( mva_s_tr );
2626 gTools().NormHist( mva_b_tr );
2627
2628 // renormalise to maximum
2629 mva_eff_tr_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_s->GetMaximum()) );
2630 mva_eff_tr_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_b->GetMaximum()) );
2631
2632 // Training background efficiency versus signal efficiency
2633 TH1* eff_bvss = new TH1D( GetTestvarName() + "_trainingEffBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2634 // Training background rejection (=1-eff.) versus signal efficiency
2635 TH1* rej_bvss = new TH1D( GetTestvarName() + "_trainingRejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2636 results->Store(eff_bvss, "EFF_BVSS_TR");
2637 results->Store(rej_bvss, "REJ_BVSS_TR");
2638
2639 // use root finder
2640 // spline background efficiency plot
2641 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2643 if (fSplTrainRefS) delete fSplTrainRefS;
2644 if (fSplTrainRefB) delete fSplTrainRefB;
2645 fSplTrainRefS = new TSpline1( "spline2_signal", new TGraph( mva_eff_tr_s ) );
2646 fSplTrainRefB = new TSpline1( "spline2_background", new TGraph( mva_eff_tr_b ) );
2647
2648 // verify spline sanity
2649 gTools().CheckSplines( mva_eff_tr_s, fSplTrainRefS );
2650 gTools().CheckSplines( mva_eff_tr_b, fSplTrainRefB );
2651 }
2652
2653 // make the background-vs-signal efficiency plot
2654
2655 // create root finder
2656 RootFinder rootFinder(this, fXmin, fXmax );
2657
2658 Double_t effB = 0;
2659 fEffS = results->GetHist("MVA_TRAINEFF_S");
2660 for (Int_t bini=1; bini<=fNbins; bini++) {
2661
2662 // find cut value corresponding to a given signal efficiency
2663 Double_t effS = eff_bvss->GetBinCenter( bini );
2664
2665 Double_t cut = rootFinder.Root( effS );
2666
2667 // retrieve background efficiency for given cut
2668 if (Use_Splines_for_Eff_) effB = fSplTrainRefB->Eval( cut );
2669 else effB = mva_eff_tr_b->GetBinContent( mva_eff_tr_b->FindBin( cut ) );
2670
2671 // and fill histograms
2672 eff_bvss->SetBinContent( bini, effB );
2673 rej_bvss->SetBinContent( bini, 1.0-effB );
2674 }
2675 fEffS = 0;
2676
2677 // create splines for histogram
2678 fSplTrainEffBvsS = new TSpline1( "effBvsS", new TGraph( eff_bvss ) );
2679 }
2680
2681 // must exist...
2682 if (0 == fSplTrainEffBvsS) return 0.0;
2683
2684 // now find signal efficiency that corresponds to required background efficiency
2685 Double_t effS = 0., effB, effS_ = 0., effB_ = 0.;
2686 Int_t nbins_ = 1000;
2687 for (Int_t bini=1; bini<=nbins_; bini++) {
2688
2689 // get corresponding signal and background efficiencies
2690 effS = (bini - 0.5)/Float_t(nbins_);
2691 effB = fSplTrainEffBvsS->Eval( effS );
2692
2693 // find signal efficiency that corresponds to required background efficiency
2694 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2695 effS_ = effS;
2696 effB_ = effB;
2697 }
2698
2699 return 0.5*(effS + effS_); // the mean between bin above and bin below
2700}
2701
2702////////////////////////////////////////////////////////////////////////////////
2703
2704std::vector<Float_t> TMVA::MethodBase::GetMulticlassEfficiency(std::vector<std::vector<Float_t> >& purity)
2705{
2706 Data()->SetCurrentType(Types::kTesting);
2707 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
2708 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in GetMulticlassEfficiency, exiting."<<Endl;
2709
2710 purity.push_back(resMulticlass->GetAchievablePur());
2711 return resMulticlass->GetAchievableEff();
2712}
2713
2714////////////////////////////////////////////////////////////////////////////////
2715
2716std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity)
2717{
2718 Data()->SetCurrentType(Types::kTraining);
2719 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTraining, Types::kMulticlass));
2720 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in GetMulticlassTrainingEfficiency, exiting."<<Endl;
2721
2722 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for training data..." << Endl;
2723 for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
2724 resMulticlass->GetBestMultiClassCuts(icls);
2725 }
2726
2727 purity.push_back(resMulticlass->GetAchievablePur());
2728 return resMulticlass->GetAchievableEff();
2729}
2730
2731////////////////////////////////////////////////////////////////////////////////
2732/// Construct a confusion matrix for a multiclass classifier. The confusion
2733/// matrix compares, in turn, each class against all other classes in a pair-wise
2734/// fashion. In rows with index \f$ k_r = 0 ... K \f$, \f$ k_r \f$ is
2735/// considered signal for the sake of comparison and for each column
2736/// \f$ k_c = 0 ... K \f$ the corresponding class is considered background.
2737///
2738/// Note that the diagonal elements will be returned as NaN since this will
2739/// compare a class against itself.
2740///
2741/// \see TMVA::ResultsMulticlass::GetConfusionMatrix
2742///
2743/// \param[in] effB The background efficiency for which to evaluate.
2744/// \param[in] type The data set on which to evaluate (training, testing ...).
2745///
2746/// \return A matrix containing signal efficiencies for the given background
2747/// efficiency. The diagonal elements are NaN since this measure is
2748/// meaningless (comparing a class against itself).
2749///
2750
2752{
2753 if (GetAnalysisType() != Types::kMulticlass) {
2754 Log() << kFATAL << "Cannot get confusion matrix for non-multiclass analysis." << std::endl;
2755 return TMatrixD(0, 0);
2756 }
2757
2758 Data()->SetCurrentType(type);
2759 ResultsMulticlass *resMulticlass =
2760 dynamic_cast<ResultsMulticlass *>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
2761
2762 if (resMulticlass == nullptr) {
2763 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
2764 << "unable to create pointer in GetMulticlassEfficiency, exiting." << Endl;
2765 return TMatrixD(0, 0);
2766 }
2767
2768 return resMulticlass->GetConfusionMatrix(effB);
2769}
2770
2771////////////////////////////////////////////////////////////////////////////////
2772/// compute significance of mean difference
2773/// \f[
2774/// significance = \frac{|<S> - <B>|}{\sqrt{RMS_{S}^{2} + RMS_{B}^{2}}}
2775/// \f]
2776
2778{
2779 Double_t rms = sqrt( fRmsS*fRmsS + fRmsB*fRmsB );
2780
2781 return (rms > 0) ? TMath::Abs(fMeanS - fMeanB)/rms : 0;
2782}
2783
2784////////////////////////////////////////////////////////////////////////////////
2785/// compute "separation" defined as
2786/// \f[
2787/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2788/// \f]
2789
2791{
2792 return gTools().GetSeparation( histoS, histoB );
2793}
2794
2795////////////////////////////////////////////////////////////////////////////////
2796/// compute "separation" defined as
2797/// \f[
2798/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2799/// \f]
2800
2802{
2803 // note, if zero pointers given, use internal pdf
2804 // sanity check first
2805 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2806 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2807 if (!pdfS) pdfS = fSplS;
2808 if (!pdfB) pdfB = fSplB;
2809
2810 if (!fSplS || !fSplB) {
2811 Log()<<kDEBUG<<Form("[%s] : ",DataInfo().GetName())<< "could not calculate the separation, distributions"
2812 << " fSplS or fSplB are not yet filled" << Endl;
2813 return 0;
2814 }else{
2815 return gTools().GetSeparation( *pdfS, *pdfB );
2816 }
2817}
2818
2819////////////////////////////////////////////////////////////////////////////////
2820/// calculate the area (integral) under the ROC curve as a
2821/// overall quality measure of the classification
2822
2824{
2825 // note, if zero pointers given, use internal pdf
2826 // sanity check first
2827 if ((!histS && histB) || (histS && !histB))
2828 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetROCIntegral(TH1D*, TH1D*)> Mismatch in hists" << Endl;
2829
2830 if (histS==0 || histB==0) return 0.;
2831
2832 TMVA::PDF *pdfS = new TMVA::PDF( " PDF Sig", histS, TMVA::PDF::kSpline3 );
2833 TMVA::PDF *pdfB = new TMVA::PDF( " PDF Bkg", histB, TMVA::PDF::kSpline3 );
2834
2835
2836 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2837 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2838
2839 Double_t integral = 0;
2840 UInt_t nsteps = 1000;
2841 Double_t step = (xmax-xmin)/Double_t(nsteps);
2842 Double_t cut = xmin;
2843 for (UInt_t i=0; i<nsteps; i++) {
2844 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2845 cut+=step;
2846 }
2847 delete pdfS;
2848 delete pdfB;
2849 return integral*step;
2850}
2851
2852
2853////////////////////////////////////////////////////////////////////////////////
2854/// calculate the area (integral) under the ROC curve as a
2855/// overall quality measure of the classification
2856
2858{
2859 // note, if zero pointers given, use internal pdf
2860 // sanity check first
2861 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2862 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2863 if (!pdfS) pdfS = fSplS;
2864 if (!pdfB) pdfB = fSplB;
2865
2866 if (pdfS==0 || pdfB==0) return 0.;
2867
2868 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2869 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2870
2871 Double_t integral = 0;
2872 UInt_t nsteps = 1000;
2873 Double_t step = (xmax-xmin)/Double_t(nsteps);
2874 Double_t cut = xmin;
2875 for (UInt_t i=0; i<nsteps; i++) {
2876 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2877 cut+=step;
2878 }
2879 return integral*step;
2880}
2881
2882////////////////////////////////////////////////////////////////////////////////
2883/// plot significance, \f$ \frac{S}{\sqrt{S + B}} \f$, curve for given number
2884/// of signal and background events; returns cut for maximum significance
2885/// also returned via reference is the maximum significance
2886
2888 Double_t BackgroundEvents,
2889 Double_t& max_significance_value ) const
2890{
2891 Results* results = Data()->GetResults( GetMethodName(), Types::kTesting, Types::kMaxAnalysisType );
2892
2893 Double_t max_significance(0);
2894 Double_t effS(0),effB(0),significance(0);
2895 TH1D *temp_histogram = new TH1D("temp", "temp", fNbinsH, fXmin, fXmax );
2896
2897 if (SignalEvents <= 0 || BackgroundEvents <= 0) {
2898 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetMaximumSignificance> "
2899 << "Number of signal or background events is <= 0 ==> abort"
2900 << Endl;
2901 }
2902
2903 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Using ratio SignalEvents/BackgroundEvents = "
2904 << SignalEvents/BackgroundEvents << Endl;
2905
2906 TH1* eff_s = results->GetHist("MVA_EFF_S");
2907 TH1* eff_b = results->GetHist("MVA_EFF_B");
2908
2909 if ( (eff_s==0) || (eff_b==0) ) {
2910 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Efficiency histograms empty !" << Endl;
2911 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "no maximum cut found, return 0" << Endl;
2912 return 0;
2913 }
2914
2915 for (Int_t bin=1; bin<=fNbinsH; bin++) {
2916 effS = eff_s->GetBinContent( bin );
2917 effB = eff_b->GetBinContent( bin );
2918
2919 // put significance into a histogram
2920 significance = sqrt(SignalEvents)*( effS )/sqrt( effS + ( BackgroundEvents / SignalEvents) * effB );
2921
2922 temp_histogram->SetBinContent(bin,significance);
2923 }
2924
2925 // find maximum in histogram
2926 max_significance = temp_histogram->GetBinCenter( temp_histogram->GetMaximumBin() );
2927 max_significance_value = temp_histogram->GetBinContent( temp_histogram->GetMaximumBin() );
2928
2929 // delete
2930 delete temp_histogram;
2931
2932 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Optimal cut at : " << max_significance << Endl;
2933 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "Maximum significance: " << max_significance_value << Endl;
2934
2935 return max_significance;
2936}
2937
2938////////////////////////////////////////////////////////////////////////////////
2939/// calculates rms,mean, xmin, xmax of the event variable
2940/// this can be either done for the variables as they are or for
2941/// normalised variables (in the range of 0-1) if "norm" is set to kTRUE
2942
2944 Double_t& meanS, Double_t& meanB,
2945 Double_t& rmsS, Double_t& rmsB,
2947{
2948 Types::ETreeType previousTreeType = Data()->GetCurrentType();
2949 Data()->SetCurrentType(treeType);
2950
2951 Long64_t entries = Data()->GetNEvents();
2952
2953 // sanity check
2954 if (entries <=0)
2955 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CalculateEstimator> Wrong tree type: " << treeType << Endl;
2956
2957 // index of the wanted variable
2958 UInt_t varIndex = DataInfo().FindVarIndex( theVarName );
2959
2960 // first fill signal and background in arrays before analysis
2961 xmin = +DBL_MAX;
2962 xmax = -DBL_MAX;
2963
2964 // take into account event weights
2965 meanS = 0;
2966 meanB = 0;
2967 rmsS = 0;
2968 rmsB = 0;
2969 Double_t sumwS = 0, sumwB = 0;
2970
2971 // loop over all training events
2972 for (Int_t ievt = 0; ievt < entries; ievt++) {
2973
2974 const Event* ev = GetEvent(ievt);
2975
2976 Double_t theVar = ev->GetValue(varIndex);
2977 Double_t weight = ev->GetWeight();
2978
2979 if (DataInfo().IsSignal(ev)) {
2980 sumwS += weight;
2981 meanS += weight*theVar;
2982 rmsS += weight*theVar*theVar;
2983 }
2984 else {
2985 sumwB += weight;
2986 meanB += weight*theVar;
2987 rmsB += weight*theVar*theVar;
2988 }
2989 xmin = TMath::Min( xmin, theVar );
2990 xmax = TMath::Max( xmax, theVar );
2991 }
2992
2993 meanS = meanS/sumwS;
2994 meanB = meanB/sumwB;
2995 rmsS = TMath::Sqrt( rmsS/sumwS - meanS*meanS );
2996 rmsB = TMath::Sqrt( rmsB/sumwB - meanB*meanB );
2997
2998 Data()->SetCurrentType(previousTreeType);
2999}
3000
3001////////////////////////////////////////////////////////////////////////////////
3002/// create reader class for method (classification only at present)
3003
3004void TMVA::MethodBase::MakeClass( const TString& theClassFileName ) const
3005{
3006 // the default consists of
3007 TString classFileName = "";
3008 if (theClassFileName == "")
3009 classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class.C";
3010 else
3011 classFileName = theClassFileName;
3012
3013 TString className = TString("Read") + GetMethodName();
3014
3015 TString tfname( classFileName );
3016 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
3017 << "Creating standalone class: "
3018 << gTools().Color("lightblue") << classFileName << gTools().Color("reset") << Endl;
3019
3020 std::ofstream fout( classFileName );
3021 if (!fout.good()) { // file could not be opened --> Error
3022 Log() << kFATAL << "<MakeClass> Unable to open file: " << classFileName << Endl;
3023 }
3024
3025 // now create the class
3026 // preamble
3027 fout << "// Class: " << className << std::endl;
3028 fout << "// Automatically generated by MethodBase::MakeClass" << std::endl << "//" << std::endl;
3029
3030 // print general information and configuration state
3031 fout << std::endl;
3032 fout << "/* configuration options =====================================================" << std::endl << std::endl;
3033 WriteStateToStream( fout );
3034 fout << std::endl;
3035 fout << "============================================================================ */" << std::endl;
3036
3037 // generate the class
3038 fout << "" << std::endl;
3039 fout << "#include <array>" << std::endl;
3040 fout << "#include <vector>" << std::endl;
3041 fout << "#include <cmath>" << std::endl;
3042 fout << "#include <string>" << std::endl;
3043 fout << "#include <iostream>" << std::endl;
3044 fout << "" << std::endl;
3045 // now if the classifier needs to write some additional classes for its response implementation
3046 // this code goes here: (at least the header declarations need to come before the main class
3047 this->MakeClassSpecificHeader( fout, className );
3048
3049 fout << "#ifndef IClassifierReader__def" << std::endl;
3050 fout << "#define IClassifierReader__def" << std::endl;
3051 fout << std::endl;
3052 fout << "class IClassifierReader {" << std::endl;
3053 fout << std::endl;
3054 fout << " public:" << std::endl;
3055 fout << std::endl;
3056 fout << " // constructor" << std::endl;
3057 fout << " IClassifierReader() : fStatusIsClean( true ) {}" << std::endl;
3058 fout << " virtual ~IClassifierReader() {}" << std::endl;
3059 fout << std::endl;
3060 fout << " // return classifier response" << std::endl;
3061 if(GetAnalysisType() == Types::kMulticlass) {
3062 fout << " virtual std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3063 } else {
3064 fout << " virtual double GetMvaValue( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3065 }
3066 fout << std::endl;
3067 fout << " // returns classifier status" << std::endl;
3068 fout << " bool IsStatusClean() const { return fStatusIsClean; }" << std::endl;
3069 fout << std::endl;
3070 fout << " protected:" << std::endl;
3071 fout << std::endl;
3072 fout << " bool fStatusIsClean;" << std::endl;
3073 fout << "};" << std::endl;
3074 fout << std::endl;
3075 fout << "#endif" << std::endl;
3076 fout << std::endl;
3077 fout << "class " << className << " : public IClassifierReader {" << std::endl;
3078 fout << std::endl;
3079 fout << " public:" << std::endl;
3080 fout << std::endl;
3081 fout << " // constructor" << std::endl;
3082 fout << " " << className << "( std::vector<std::string>& theInputVars )" << std::endl;
3083 fout << " : IClassifierReader()," << std::endl;
3084 fout << " fClassName( \"" << className << "\" )," << std::endl;
3085 fout << " fNvars( " << GetNvar() << " )" << std::endl;
3086 fout << " {" << std::endl;
3087 fout << " // the training input variables" << std::endl;
3088 fout << " const char* inputVars[] = { ";
3089 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3090 fout << "\"" << GetOriginalVarName(ivar) << "\"";
3091 if (ivar<GetNvar()-1) fout << ", ";
3092 }
3093 fout << " };" << std::endl;
3094 fout << std::endl;
3095 fout << " // sanity checks" << std::endl;
3096 fout << " if (theInputVars.size() <= 0) {" << std::endl;
3097 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": empty input vector\" << std::endl;" << std::endl;
3098 fout << " fStatusIsClean = false;" << std::endl;
3099 fout << " }" << std::endl;
3100 fout << std::endl;
3101 fout << " if (theInputVars.size() != fNvars) {" << std::endl;
3102 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in number of input values: \"" << std::endl;
3103 fout << " << theInputVars.size() << \" != \" << fNvars << std::endl;" << std::endl;
3104 fout << " fStatusIsClean = false;" << std::endl;
3105 fout << " }" << std::endl;
3106 fout << std::endl;
3107 fout << " // validate input variables" << std::endl;
3108 fout << " for (size_t ivar = 0; ivar < theInputVars.size(); ivar++) {" << std::endl;
3109 fout << " if (theInputVars[ivar] != inputVars[ivar]) {" << std::endl;
3110 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in input variable names\" << std::endl" << std::endl;
3111 fout << " << \" for variable [\" << ivar << \"]: \" << theInputVars[ivar].c_str() << \" != \" << inputVars[ivar] << std::endl;" << std::endl;
3112 fout << " fStatusIsClean = false;" << std::endl;
3113 fout << " }" << std::endl;
3114 fout << " }" << std::endl;
3115 fout << std::endl;
3116 fout << " // initialize min and max vectors (for normalisation)" << std::endl;
3117 for (UInt_t ivar = 0; ivar < GetNvar(); ivar++) {
3118 fout << " fVmin[" << ivar << "] = " << std::setprecision(15) << GetXmin( ivar ) << ";" << std::endl;
3119 fout << " fVmax[" << ivar << "] = " << std::setprecision(15) << GetXmax( ivar ) << ";" << std::endl;
3120 }
3121 fout << std::endl;
3122 fout << " // initialize input variable types" << std::endl;
3123 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3124 fout << " fType[" << ivar << "] = \'" << DataInfo().GetVariableInfo(ivar).GetVarType() << "\';" << std::endl;
3125 }
3126 fout << std::endl;
3127 fout << " // initialize constants" << std::endl;
3128 fout << " Initialize();" << std::endl;
3129 fout << std::endl;
3130 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
3131 fout << " // initialize transformation" << std::endl;
3132 fout << " InitTransform();" << std::endl;
3133 }
3134 fout << " }" << std::endl;
3135 fout << std::endl;
3136 fout << " // destructor" << std::endl;
3137 fout << " virtual ~" << className << "() {" << std::endl;
3138 fout << " Clear(); // method-specific" << std::endl;
3139 fout << " }" << std::endl;
3140 fout << std::endl;
3141 fout << " // the classifier response" << std::endl;
3142 fout << " // \"inputValues\" is a vector of input values in the same order as the" << std::endl;
3143 fout << " // variables given to the constructor" << std::endl;
3144 if(GetAnalysisType() == Types::kMulticlass) {
3145 fout << " std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const override;" << std::endl;
3146 } else {
3147 fout << " double GetMvaValue( const std::vector<double>& inputValues ) const override;" << std::endl;
3148 }
3149 fout << std::endl;
3150 fout << " private:" << std::endl;
3151 fout << std::endl;
3152 fout << " // method-specific destructor" << std::endl;
3153 fout << " void Clear();" << std::endl;
3154 fout << std::endl;
3155 if (GetTransformationHandler().GetTransformationList().GetSize()!=0) {
3156 fout << " // input variable transformation" << std::endl;
3157 GetTransformationHandler().MakeFunction(fout, className,1);
3158 fout << " void InitTransform();" << std::endl;
3159 fout << " void Transform( std::vector<double> & iv, int sigOrBgd ) const;" << std::endl;
3160 fout << std::endl;
3161 }
3162 fout << " // common member variables" << std::endl;
3163 fout << " const char* fClassName;" << std::endl;
3164 fout << std::endl;
3165 fout << " const size_t fNvars;" << std::endl;
3166 fout << " size_t GetNvar() const { return fNvars; }" << std::endl;
3167 fout << " char GetType( int ivar ) const { return fType[ivar]; }" << std::endl;
3168 fout << std::endl;
3169 fout << " // normalisation of input variables" << std::endl;
3170 fout << " double fVmin[" << GetNvar() << "];" << std::endl;
3171 fout << " double fVmax[" << GetNvar() << "];" << std::endl;
3172 fout << " double NormVariable( double x, double xmin, double xmax ) const {" << std::endl;
3173 fout << " // normalise to output range: [-1, 1]" << std::endl;
3174 fout << " return 2*(x - xmin)/(xmax - xmin) - 1.0;" << std::endl;
3175 fout << " }" << std::endl;
3176 fout << std::endl;
3177 fout << " // type of input variable: 'F' or 'I'" << std::endl;
3178 fout << " char fType[" << GetNvar() << "];" << std::endl;
3179 fout << std::endl;
3180 fout << " // initialize internal variables" << std::endl;
3181 fout << " void Initialize();" << std::endl;
3182 if(GetAnalysisType() == Types::kMulticlass) {
3183 fout << " std::vector<double> GetMulticlassValues__( const std::vector<double>& inputValues ) const;" << std::endl;
3184 } else {
3185 fout << " double GetMvaValue__( const std::vector<double>& inputValues ) const;" << std::endl;
3186 }
3187 fout << "" << std::endl;
3188 fout << " // private members (method specific)" << std::endl;
3189
3190 // call the classifier specific output (the classifier must close the class !)
3191 MakeClassSpecific( fout, className );
3192
3193 if(GetAnalysisType() == Types::kMulticlass) {
3194 fout << "inline std::vector<double> " << className << "::GetMulticlassValues( const std::vector<double>& inputValues ) const" << std::endl;
3195 } else {
3196 fout << "inline double " << className << "::GetMvaValue( const std::vector<double>& inputValues ) const" << std::endl;
3197 }
3198 fout << "{" << std::endl;
3199 fout << " // classifier response value" << std::endl;
3200 if(GetAnalysisType() == Types::kMulticlass) {
3201 fout << " std::vector<double> retval;" << std::endl;
3202 } else {
3203 fout << " double retval = 0;" << std::endl;
3204 }
3205 fout << std::endl;
3206 fout << " // classifier response, sanity check first" << std::endl;
3207 fout << " if (!IsStatusClean()) {" << std::endl;
3208 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": cannot return classifier response\"" << std::endl;
3209 fout << " << \" because status is dirty\" << std::endl;" << std::endl;
3210 fout << " }" << std::endl;
3211 fout << " else {" << std::endl;
3212 if (IsNormalised()) {
3213 fout << " // normalise variables" << std::endl;
3214 fout << " std::vector<double> iV;" << std::endl;
3215 fout << " iV.reserve(inputValues.size());" << std::endl;
3216 fout << " int ivar = 0;" << std::endl;
3217 fout << " for (std::vector<double>::const_iterator varIt = inputValues.begin();" << std::endl;
3218 fout << " varIt != inputValues.end(); varIt++, ivar++) {" << std::endl;
3219 fout << " iV.push_back(NormVariable( *varIt, fVmin[ivar], fVmax[ivar] ));" << std::endl;
3220 fout << " }" << std::endl;
3221 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3222 GetMethodType() != Types::kHMatrix) {
3223 fout << " Transform( iV, -1 );" << std::endl;
3224 }
3225
3226 if(GetAnalysisType() == Types::kMulticlass) {
3227 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3228 } else {
3229 fout << " retval = GetMvaValue__( iV );" << std::endl;
3230 }
3231 } else {
3232 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3233 GetMethodType() != Types::kHMatrix) {
3234 fout << " std::vector<double> iV(inputValues);" << std::endl;
3235 fout << " Transform( iV, -1 );" << std::endl;
3236 if(GetAnalysisType() == Types::kMulticlass) {
3237 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3238 } else {
3239 fout << " retval = GetMvaValue__( iV );" << std::endl;
3240 }
3241 } else {
3242 if(GetAnalysisType() == Types::kMulticlass) {
3243 fout << " retval = GetMulticlassValues__( inputValues );" << std::endl;
3244 } else {
3245 fout << " retval = GetMvaValue__( inputValues );" << std::endl;
3246 }
3247 }
3248 }
3249 fout << " }" << std::endl;
3250 fout << std::endl;
3251 fout << " return retval;" << std::endl;
3252 fout << "}" << std::endl;
3253
3254 // create output for transformation - if any
3255 if (GetTransformationHandler().GetTransformationList().GetSize()!=0)
3256 GetTransformationHandler().MakeFunction(fout, className,2);
3257
3258 // close the file
3259 fout.close();
3260}
3261
3262////////////////////////////////////////////////////////////////////////////////
3263/// prints out method-specific help method
3264
3266{
3267 // if options are written to reference file, also append help info
3268 std::streambuf* cout_sbuf = std::cout.rdbuf(); // save original sbuf
3269 std::ofstream* o = 0;
3270 if (gConfig().WriteOptionsReference()) {
3271 Log() << kINFO << "Print Help message for class " << GetName() << " into file: " << GetReferenceFile() << Endl;
3272 o = new std::ofstream( GetReferenceFile(), std::ios::app );
3273 if (!o->good()) { // file could not be opened --> Error
3274 Log() << kFATAL << "<PrintHelpMessage> Unable to append to output file: " << GetReferenceFile() << Endl;
3275 }
3276 std::cout.rdbuf( o->rdbuf() ); // redirect 'std::cout' to file
3277 }
3278
3279 // "|--------------------------------------------------------------|"
3280 if (!o) {
3281 Log() << kINFO << Endl;
3282 Log() << gTools().Color("bold")
3283 << "================================================================"
3284 << gTools().Color( "reset" )
3285 << Endl;
3286 Log() << gTools().Color("bold")
3287 << "H e l p f o r M V A m e t h o d [ " << GetName() << " ] :"
3288 << gTools().Color( "reset" )
3289 << Endl;
3290 }
3291 else {
3292 Log() << "Help for MVA method [ " << GetName() << " ] :" << Endl;
3293 }
3294
3295 // print method-specific help message
3296 GetHelpMessage();
3297
3298 if (!o) {
3299 Log() << Endl;
3300 Log() << "<Suppress this message by specifying \"!H\" in the booking option>" << Endl;
3301 Log() << gTools().Color("bold")
3302 << "================================================================"
3303 << gTools().Color( "reset" )
3304 << Endl;
3305 Log() << Endl;
3306 }
3307 else {
3308 // indicate END
3309 Log() << "# End of Message___" << Endl;
3310 }
3311
3312 std::cout.rdbuf( cout_sbuf ); // restore the original stream buffer
3313 if (o) o->close();
3314}
3315
3316// ----------------------- r o o t f i n d i n g ----------------------------
3317
3318////////////////////////////////////////////////////////////////////////////////
3319/// returns efficiency as function of cut
3320
3322{
3323 Double_t retval=0;
3324
3325 // retrieve the class object
3327 retval = fSplRefS->Eval( theCut );
3328 }
3329 else retval = fEffS->GetBinContent( fEffS->FindBin( theCut ) );
3330
3331 // caution: here we take some "forbidden" action to hide a problem:
3332 // in some cases, in particular for likelihood, the binned efficiency distributions
3333 // do not equal 1, at xmin, and 0 at xmax; of course, in principle we have the
3334 // unbinned information available in the trees, but the unbinned minimization is
3335 // too slow, and we don't need to do a precision measurement here. Hence, we force
3336 // this property.
3337 Double_t eps = 1.0e-5;
3338 if (theCut-fXmin < eps) retval = (GetCutOrientation() == kPositive) ? 1.0 : 0.0;
3339 else if (fXmax-theCut < eps) retval = (GetCutOrientation() == kPositive) ? 0.0 : 1.0;
3340
3341 return retval;
3342}
3343
3344////////////////////////////////////////////////////////////////////////////////
3345/// returns the event collection (i.e. the dataset) TRANSFORMED using the
3346/// classifiers specific Variable Transformation (e.g. Decorr or Decorr:Gauss:Decorr)
3347
3349{
3350 // if there's no variable transformation for this classifier, just hand back the
3351 // event collection of the data set
3352 if (GetTransformationHandler().GetTransformationList().GetEntries() <= 0) {
3353 return (Data()->GetEventCollection(type));
3354 }
3355
3356 // otherwise, transform ALL the events and hand back the vector of the pointers to the
3357 // transformed events. If the pointer is already != 0, i.e. the whole thing has been
3358 // done before, I don't need to do it again, but just "hand over" the pointer to those events.
3359 Int_t idx = Data()->TreeIndex(type); //index indicating Training,Testing,... events/datasets
3360 if (fEventCollections.at(idx) == 0) {
3361 fEventCollections.at(idx) = &(Data()->GetEventCollection(type));
3362 fEventCollections.at(idx) = GetTransformationHandler().CalcTransformations(*(fEventCollections.at(idx)),kTRUE);
3363 }
3364 return *(fEventCollections.at(idx));
3365}
3366
3367////////////////////////////////////////////////////////////////////////////////
3368/// calculates the TMVA version string from the training version code on the fly
3369
3371{
3372 UInt_t a = GetTrainingTMVAVersionCode() & 0xff0000; a>>=16;
3373 UInt_t b = GetTrainingTMVAVersionCode() & 0x00ff00; b>>=8;
3374 UInt_t c = GetTrainingTMVAVersionCode() & 0x0000ff;
3375
3376 return TString(Form("%i.%i.%i",a,b,c));
3377}
3378
3379////////////////////////////////////////////////////////////////////////////////
3380/// calculates the ROOT version string from the training version code on the fly
3381
3383{
3384 UInt_t a = GetTrainingROOTVersionCode() & 0xff0000; a>>=16;
3385 UInt_t b = GetTrainingROOTVersionCode() & 0x00ff00; b>>=8;
3386 UInt_t c = GetTrainingROOTVersionCode() & 0x0000ff;
3387
3388 return TString(Form("%i.%02i/%02i",a,b,c));
3389}
3390
3391////////////////////////////////////////////////////////////////////////////////
3392
3394 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
3395 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
3396
3397 if (mvaRes != NULL) {
3398 TH1D *mva_s = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_S"));
3399 TH1D *mva_b = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_B"));
3400 TH1D *mva_s_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_S"));
3401 TH1D *mva_b_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_B"));
3402
3403 if ( !mva_s || !mva_b || !mva_s_tr || !mva_b_tr) return -1;
3404
3405 if (SorB == 's' || SorB == 'S')
3406 return mva_s->KolmogorovTest( mva_s_tr, opt.Data() );
3407 else
3408 return mva_b->KolmogorovTest( mva_b_tr, opt.Data() );
3409 }
3410 return -1;
3411}
const Bool_t Use_Splines_for_Eff_
Definition: MethodBase.cxx:130
const Int_t NBIN_HIST_HIGH
Definition: MethodBase.cxx:133
#define d(i)
Definition: RSha256.hxx:102
#define c(i)
Definition: RSha256.hxx:101
#define s1(x)
Definition: RSha256.hxx:91
#define ROOT_VERSION_CODE
Definition: RVersion.h:21
bool Bool_t
Definition: RtypesCore.h:63
int Int_t
Definition: RtypesCore.h:45
char Char_t
Definition: RtypesCore.h:37
const Bool_t kFALSE
Definition: RtypesCore.h:101
float Float_t
Definition: RtypesCore.h:57
double Double_t
Definition: RtypesCore.h:59
long long Long64_t
Definition: RtypesCore.h:80
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassImp(name)
Definition: Rtypes.h:375
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t b
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char y2
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Option_t Option_t TPoint TPoint const char y1
char name[80]
Definition: TGX11.cxx:110
float xmin
Definition: THbookFile.cxx:95
float xmax
Definition: THbookFile.cxx:95
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:23
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition: TString.cxx:2456
R__EXTERN TSystem * gSystem
Definition: TSystem.h:560
#define TMVA_VERSION_CODE
Definition: Version.h:47
Class to manage histogram axis.
Definition: TAxis.h:30
Double_t GetXmax() const
Definition: TAxis.h:135
Double_t GetXmin() const
Definition: TAxis.h:134
Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0) override
Write all objects in this collection.
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
Definition: TCollection.h:184
This class stores the date and time with a precision of one second in an unsigned 32 bit word (950130...
Definition: TDatime.h:37
const char * AsString() const
Return the date & time as a string (ctime() format).
Definition: TDatime.cxx:102
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
Describe directory structure in memory.
Definition: TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:407
virtual TDirectory * mkdir(const char *name, const char *title="", Bool_t returnExistingDirectory=kFALSE)
Create a sub-directory "a" or a hierarchy of sub-directories "a/b/c/...".
virtual Bool_t cd(const char *path=nullptr)
Change current directory to "this" directory.
Definition: TDirectory.cxx:504
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:4053
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:916
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
1-D histogram with a double per channel (see TH1 documentation)
Definition: TH1.h:620
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:577
TH1 is the base class of all histogram classes in ROOT.
Definition: TH1.h:58
virtual Double_t GetBinCenter(Int_t bin) const
Return bin center for 1D histogram.
Definition: TH1.cxx:9008
virtual void AddBinContent(Int_t bin)
Increment bin content by 1.
Definition: TH1.cxx:1242
virtual Double_t GetMean(Int_t axis=1) const
For axis = 1,2 or 3 returns the mean value of the histogram along X,Y or Z axis.
Definition: TH1.cxx:7446
virtual void SetXTitle(const char *title)
Definition: TH1.h:415
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1267
TAxis * GetXaxis()
Definition: TH1.h:322
virtual Double_t GetMaximum(Double_t maxval=FLT_MAX) const
Return maximum value smaller than maxval of bins in the range, unless the value has been overridden b...
Definition: TH1.cxx:8412
virtual Int_t GetNbinsX() const
Definition: TH1.h:295
virtual Int_t Fill(Double_t x)
Increment bin with abscissa X by 1.
Definition: TH1.cxx:3338
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition: TH1.cxx:9089
virtual Int_t GetMaximumBin() const
Return location of bin with maximum value in the range.
Definition: TH1.cxx:8444
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition: TH1.cxx:5025
virtual Int_t GetQuantiles(Int_t nprobSum, Double_t *q, const Double_t *probSum=nullptr)
Compute Quantiles for this histogram Quantile x_q of a probability distribution Function F is defined...
Definition: TH1.cxx:4575
virtual void SetYTitle(const char *title)
Definition: TH1.h:416
virtual void Scale(Double_t c1=1, Option_t *option="")
Multiply this histogram by a constant c1.
Definition: TH1.cxx:6586
virtual Int_t FindBin(Double_t x, Double_t y=0, Double_t z=0)
Return Global bin number corresponding to x,y,z.
Definition: TH1.cxx:3668
virtual Double_t KolmogorovTest(const TH1 *h2, Option_t *option="") const
Statistical test of compatibility in shape between this histogram and h2, using Kolmogorov test.
Definition: TH1.cxx:8087
virtual void Sumw2(Bool_t flag=kTRUE)
Create structure to store sum of squares of weights.
Definition: TH1.cxx:8887
static Bool_t AddDirectoryStatus()
Static function: cannot be inlined on Windows/NT.
Definition: TH1.cxx:735
2-D histogram with a float per channel (see TH1 documentation)
Definition: TH2.h:257
Int_t Fill(Double_t) override
Invalid Fill method.
Definition: TH2.cxx:347
A doubly linked list.
Definition: TList.h:38
TObject * At(Int_t idx) const override
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:357
Class that contains all the information of a class.
Definition: ClassInfo.h:49
UInt_t GetNumber() const
Definition: ClassInfo.h:65
TString fWeightFileExtension
Definition: Config.h:125
VariablePlotting & GetVariablePlotting()
Definition: Config.h:97
class TMVA::Config::VariablePlotting fVariablePlotting
IONames & GetIONames()
Definition: Config.h:98
MsgLogger * fLogger
! message logger
Definition: Configurable.h:128
Class that contains all the data information.
Definition: DataSetInfo.h:62
Class that contains all the data information.
Definition: DataSet.h:58
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:389
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:399
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
static void SetIgnoreNegWeightsInTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:408
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
void Init(std::vector< TString > &graphTitles)
This function gets some title and it creates a TGraph for every title.
Definition: MethodBase.cxx:169
IPythonInteractive()
standard constructor
Definition: MethodBase.cxx:146
~IPythonInteractive()
standard destructor
Definition: MethodBase.cxx:154
void ClearGraphs()
This function sets the point number to 0 for all graphs.
Definition: MethodBase.cxx:193
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in 2 TGraph case, and it will add new data points to graphs.
Definition: MethodBase.cxx:207
Virtual base class for all MVA methods.
Definition: MethodBase.h:111
TDirectory * MethodBaseDir() const
returns the ROOT directory where all instances of the corresponding MVA method are stored
virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X")
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
Definition: MethodBase.cxx:237
virtual Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr)=0
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void ReadClassesFromXML(void *clsnode)
read number of classes from XML
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
void DeclareBaseOptions()
define the options (their key words) that can be set in the option string here the options valid for ...
Definition: MethodBase.cxx:509
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:991
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:596
virtual Double_t GetSignificance() const
compute significance of mean difference
virtual Double_t GetProba(const Event *ev)
const char * GetName() const
Definition: MethodBase.h:334
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out the method-specific help message
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
const std::vector< TMVA::Event * > & GetEventCollection(Types::ETreeType type)
returns the event collection (i.e.
virtual std::vector< Double_t > GetDataMvaValues(DataSet *data=nullptr, Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the given Data type
Definition: MethodBase.cxx:939
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:406
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
void AddInfoItem(void *gi, const TString &name, const TString &value) const
xml writing
virtual void AddClassifierOutputProb(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:950
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
TString GetTrainingTMVAVersionString() const
calculates the TMVA version string from the training version code on the fly
void Statistics(Types::ETreeType treeType, const TString &theVarName, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &)
calculates rms,mean, xmin, xmax of the event variable this can be either done for the variables as th...
Bool_t GetLine(std::istream &fin, char *buf)
reads one line from the input stream checks for certain keywords and interprets the line if keywords ...
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:423
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodBase.cxx:897
virtual Bool_t IsSignalLike()
uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation) for...
Definition: MethodBase.cxx:854
virtual ~MethodBase()
destructor
Definition: MethodBase.cxx:364
virtual Double_t GetMaximumSignificance(Double_t SignalEvents, Double_t BackgroundEvents, Double_t &optimal_significance_value) const
plot significance, $\frac{S}{\sqrt{S^2 + B^2}}$, curve for given number of signal and background events; returns cut for maximum ...
virtual Double_t GetTrainingEfficiency(const TString &)
void SetWeightFileName(TString)
set the weight file name (deprecated)
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
TString GetWeightFileName() const
retrieve weight file name
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
virtual void WriteMonitoringHistosToFile() const
write special monitoring histograms to file; dummy implementation here
virtual void AddRegressionOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:744
void InitBase()
default initialization called by all constructors
Definition: MethodBase.cxx:441
virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t &stddev, Double_t &stddev90Percent) const
Definition: MethodBase.cxx:724
void ReadStateFromXMLString(const char *xmlstr)
for reading from memory
void CreateMVAPdfs()
Create PDFs of the MVA output variables.
TString GetTrainingROOTVersionString() const
calculates the ROOT version string from the training version code on the fly
virtual Double_t GetValueForRoot(Double_t)
returns efficiency as function of cut
void ReadStateFromFile()
Function to read options and weights from file.
void WriteVarsToStream(std::ostream &tf, const TString &prefix="") const
write the list of variables (name, min, max) for a given data transformation method to the stream
void ReadVarsFromStream(std::istream &istr)
Read the variables (name, min, max) for a given data transformation method from the stream.
void ReadSpectatorsFromXML(void *specnode)
read spectator info from XML
void SetTestvarName(const TString &v="")
Definition: MethodBase.h:341
void ReadVariablesFromXML(void *varnode)
read variable info from XML
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:623
virtual std::vector< Float_t > GetMulticlassTrainingEfficiency(std::vector< std::vector< Float_t > > &purity)
void WriteStateToStream(std::ostream &tf) const
general method used in writing the header of the weight files where the used variables,...
virtual Double_t GetRarity(Double_t mvaVal, Types::ESBType reftype=Types::kBackground) const
compute rarity:
virtual void SetTuneParameters(std::map< TString, Double_t > tuneParameters)
set the tuning parameters according to the argument; this is just a dummy.
Definition: MethodBase.cxx:644
void ReadStateFromStream(std::istream &tf)
read the header from the weight files of the different MVA methods
void AddVarsXMLTo(void *parent) const
write variable info to XML
void AddTargetsXMLTo(void *parent) const
write target info to XML
void ReadTargetsFromXML(void *tarnode)
read target info from XML
void ProcessBaseOptions()
the option string is decoded, for available options see "DeclareOptions"
Definition: MethodBase.cxx:540
void ReadStateFromXML(void *parent)
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:836
void WriteStateToFile() const
write options and weights to file note that each one text file for the main configuration information...
void AddClassesXMLTo(void *parent) const
write class info to XML
virtual void AddClassifierOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:868
void AddSpectatorsXMLTo(void *parent) const
write spectator info to XML
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void AddMulticlassOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:794
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:433
void SetSource(const std::string &source)
Definition: MsgLogger.h:68
PDF wrapper for histograms; uses user-defined spline interpolation.
Definition: PDF.h:63
Double_t GetXmin() const
Definition: PDF.h:104
Double_t GetXmax() const
Definition: PDF.h:105
Double_t GetVal(Double_t x) const
returns value PDF(x)
Definition: PDF.cxx:701
@ kSpline3
Definition: PDF.h:70
@ kSpline2
Definition: PDF.h:70
Double_t GetIntegral(Double_t xmin, Double_t xmax)
computes PDF integral within given ranges
Definition: PDF.cxx:654
Class that is the base-class for a vector of result.
std::vector< Float_t > * GetValueVector()
void SetValue(Float_t value, Int_t ievt, Bool_t type)
set MVA response
Class which takes the results of a multiclass classification.
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
Float_t GetAchievablePur(UInt_t cls)
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
Float_t GetAchievableEff(UInt_t cls)
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier a multiclass setting.
Class that is the base-class for a vector of result.
Class that is the base-class for a vector of result.
Definition: Results.h:57
Bool_t DoesExist(const TString &alias) const
Returns true if there is an object stored in the result for a given alias, false otherwise.
Definition: Results.cxx:127
void Store(TObject *obj, const char *alias=nullptr)
Definition: Results.cxx:86
TH1 * GetHist(const TString &alias) const
Definition: Results.cxx:136
TList * GetStorage() const
Definition: Results.h:73
Root finding using Brents algorithm (translated from CERNLIB function RZERO)
Definition: RootFinder.h:48
Double_t Root(Double_t refValue)
Root finding using Brents algorithm; taken from CERNLIB function RZERO.
Definition: RootFinder.cxx:71
Linear interpolation of TGraph.
Definition: TSpline1.h:43
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
Double_t ElapsedSeconds(void)
computes elapsed time in seconds
Definition: Timer.cxx:137
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:146
void DrawProgressBar(Int_t, const TString &comment="")
draws progress bar in color or B&W caution:
Definition: Timer.cxx:202
void ComputeStat(const std::vector< TMVA::Event * > &, std::vector< Float_t > *, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Int_t signalClass, Bool_t norm=kFALSE)
sanity check
Definition: Tools.cxx:202
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition: Tools.cxx:401
Double_t GetSeparation(TH1 *S, TH1 *B) const
compute "separation" defined as
Definition: Tools.cxx:121
Double_t GetMutualInformation(const TH2F &)
Mutual Information method for non-linear correlations estimates in 2D histogram Author: Moritz Backes...
Definition: Tools.cxx:589
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:828
TXMLEngine & xmlengine()
Definition: Tools.h:262
Bool_t CheckSplines(const TH1 *, const TSpline *)
check quality of splining by comparing splines and histograms in each bin
Definition: Tools.cxx:479
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition: Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:347
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:383
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition: Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition: Tools.cxx:1162
Singleton class for Global types used by TMVA.
Definition: Types.h:71
@ kSignal
Never change this number - it is elsewhere assumed to be zero !
Definition: Types.h:135
@ kBackground
Definition: Types.h:136
@ kLikelihood
Definition: Types.h:79
@ kHMatrix
Definition: Types.h:81
EAnalysisType
Definition: Types.h:126
@ kMulticlass
Definition: Types.h:129
@ kNoAnalysisType
Definition: Types.h:130
@ kClassification
Definition: Types.h:127
@ kMaxAnalysisType
Definition: Types.h:131
@ kRegression
Definition: Types.h:128
@ kTraining
Definition: Types.h:143
@ kTesting
Definition: Types.h:144
Linear interpolation class.
Gaussian Transformation of input variables.
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
void ReadFromXML(void *varnode)
read VariableInfo from XML
const TString & GetExpression() const
Definition: VariableInfo.h:57
char GetVarType() const
Definition: VariableInfo.h:61
void ReadFromStream(std::istream &istr)
read VariableInfo from stream
void AddToXML(void *varnode)
write class to XML
void SetExternalLink(void *p)
Definition: VariableInfo.h:75
void * GetExternalLink() const
Definition: VariableInfo.h:83
void BuildTransformationFromVarInfo(const std::vector< TMVA::VariableInfo > &var)
this method is only used when building a normalization transformation from old text files in this cas...
Linear interpolation class.
Linear interpolation class.
virtual void ReadTransformationFromStream(std::istream &istr, const TString &classname="")=0
@ kDEBUG
Definition: Types.h:56
@ kVERBOSE
Definition: Types.h:57
@ kHEADER
Definition: Types.h:63
@ kERROR
Definition: Types.h:60
@ kINFO
Definition: Types.h:58
@ kWARNING
Definition: Types.h:59
@ kFATAL
Definition: Types.h:61
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:34
const char * GetName() const override
Returns name of object.
Definition: TNamed.h:47
Collectable string class.
Definition: TObjString.h:28
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TObject.cxx:875
Basic string class.
Definition: TString.h:136
Ssiz_t Length() const
Definition: TString.h:410
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1159
Int_t Atoi() const
Return integer value of string.
Definition: TString.cxx:1955
Bool_t EndsWith(const char *pat, ECaseCompare cmp=kExact) const
Return true if string ends with the specified string.
Definition: TString.cxx:2211
TSubString Strip(EStripType s=kTrailing, char c=' ') const
Return a substring of self stripped at beginning and/or end.
Definition: TString.cxx:1140
const char * Data() const
Definition: TString.h:369
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:692
@ kLeading
Definition: TString.h:267
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:925
Bool_t IsNull() const
Definition: TString.h:407
static TString Format(const char *fmt,...)
Static method which formats a string using a printf-style format descriptor and returns a TString.
Definition: TString.cxx:2345
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:639
virtual const char * GetBuildNode() const
Return the build node name.
Definition: TSystem.cxx:3900
virtual int mkdir(const char *name, Bool_t recursive=kFALSE)
Make a file system directory.
Definition: TSystem.cxx:909
virtual const char * WorkingDirectory()
Return working directory.
Definition: TSystem.cxx:874
virtual UserGroup_t * GetUserInfo(Int_t uid)
Returns all user info in the UserGroup_t structure.
Definition: TSystem.cxx:1602
void SaveDoc(XMLDocPointer_t xmldoc, const char *filename, Int_t layout=1)
store document content to file; if layout<=0, no spaces or newlines will be placed between xmlnode...
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
XMLDocPointer_t NewDoc(const char *version="1.0")
creates new xml document with provided version
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLDocPointer_t ParseString(const char *xmlstring)
parses content of string and tries to produce xml structures
void DocSetRootElement(XMLDocPointer_t xmldoc, XMLNodePointer_t xmlnode)
set main (root) node for document
TLine * line
Double_t x[n]
Definition: legend1.C:17
TH1F * h1
Definition: legend1.C:5
VecExpr< UnaryOp< Sqrt< T >, VecExpr< A, T, D >, T >, T, D > sqrt(const VecExpr< A, T, D > &rhs)
bool BeginsWith(const std::string &theString, const std::string &theSubstring)
void Init(TClassEdit::TInterpreterLookupHelper *helper)
Definition: TClassEdit.cxx:171
static constexpr double s
static constexpr double m2
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:342
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:148
Short_t Max(Short_t a, Short_t b)
Returns the largest of a and b.
Definition: TMathBase.h:250
Double_t Log(Double_t x)
Returns the natural logarithm of x.
Definition: TMath.h:754
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition: TMath.h:660
Short_t Min(Short_t a, Short_t b)
Returns the smallest of a and b.
Definition: TMathBase.h:198
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition: TMathBase.h:123
TString fUser
Definition: TSystem.h:141
TArc a
Definition: textangle.C:12
double epsilon
Definition: triangle.c:618