Logo ROOT  
Reference Guide
MethodBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
16 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * U. of Victoria, Canada *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 * *
32 **********************************************************************************/
33
34/*! \class TMVA::MethodBase
35\ingroup TMVA
36
37 Virtual base class for all MVA methods.
38
39 MethodBase hosts several specific evaluation methods.
40
41 The kind of MVA that provides optimal performance in an analysis strongly
42 depends on the particular application. The evaluation factory provides a
43 number of numerical benchmark results to directly assess the performance
44 of the MVA training on the independent test sample. These are:
45
46 - The _signal efficiency_ at three representative background efficiencies
47 (which is 1 &minus; rejection).
48 - The _significance_ of an MVA estimator, defined by the difference
49 between the MVA mean values for signal and background, divided by the
50 quadratic sum of their root mean squares.
51 - The _separation_ of an MVA _x_, defined by the integral
52 \f[
53 \frac{1}{2} \int \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx
54 \f]
55 where
56 \f$ S(x) \f$ and \f$ B(x) \f$ are the signal and background distributions,
57 respectively. The separation is zero for identical signal and background MVA
58 shapes, and it is one for disjoint (non-overlapping) shapes.
59 - The average, \f$ \int x \mu (S(x)) dx \f$, of the signal \f$ \mu_{transform} \f$.
60 The \f$ \mu_{transform} \f$ of an MVA denotes the transformation that yields
61 a uniform background distribution. In this way, the signal distributions
62 \f$ S(x) \f$ can be directly compared among the various MVAs. The stronger
63 \f$ S(x) \f$ peaks towards one, the better is the discrimination of the MVA.
64 The \f$ \mu_{transform} \f$ is
65 [documented here](http://tel.ccsd.cnrs.fr/documents/archives0/00/00/29/91/index_fr.html).
66
67 The MVA standard output also prints the linear correlation coefficients between
68 signal and background, which can be useful to eliminate variables that exhibit too
69 strong correlations.
70*/
71
72#include "TMVA/MethodBase.h"
73
74#include "TMVA/Config.h"
75#include "TMVA/Configurable.h"
76#include "TMVA/DataSetInfo.h"
77#include "TMVA/DataSet.h"
78#include "TMVA/Factory.h"
79#include "TMVA/IMethod.h"
80#include "TMVA/MsgLogger.h"
81#include "TMVA/PDF.h"
82#include "TMVA/Ranking.h"
83#include "TMVA/DataLoader.h"
84#include "TMVA/Tools.h"
85#include "TMVA/Results.h"
89#include "TMVA/RootFinder.h"
90#include "TMVA/Timer.h"
91#include "TMVA/TSpline1.h"
92#include "TMVA/Types.h"
96#include "TMVA/VariableInfo.h"
100#include "TMVA/Version.h"
101
102#include "TROOT.h"
103#include "TSystem.h"
104#include "TObjString.h"
105#include "TQObject.h"
106#include "TSpline.h"
107#include "TMatrix.h"
108#include "TMath.h"
109#include "TH1F.h"
110#include "TH2F.h"
111#include "TFile.h"
112#include "TGraph.h"
113#include "TXMLEngine.h"
114
115#include <iomanip>
116#include <iostream>
117#include <fstream>
118#include <sstream>
119#include <cstdlib>
120#include <algorithm>
121#include <limits>
122
123
125
126using std::endl;
127using std::atof;
128
129//const Int_t MethodBase_MaxIterations_ = 200;
131
132//const Int_t NBIN_HIST_PLOT = 100;
133const Int_t NBIN_HIST_HIGH = 10000;
134
135#ifdef _WIN32
136/* Disable warning C4355: 'this' : used in base member initializer list */
137#pragma warning ( disable : 4355 )
138#endif
139
140
141#include "TMultiGraph.h"
142
143////////////////////////////////////////////////////////////////////////////////
144/// standard constructor
145
147{
148 fNumGraphs = 0;
149 fIndex = 0;
150}
151
152////////////////////////////////////////////////////////////////////////////////
153/// standard destructor
155{
156 if (fMultiGraph){
157 delete fMultiGraph;
158 fMultiGraph = nullptr;
159 }
160 return;
161}
162
163////////////////////////////////////////////////////////////////////////////////
164/// This function gets some title and it creates a TGraph for every title.
165/// It also sets up the style for every TGraph. All graphs are added to a single TMultiGraph.
166///
167/// \param[in] graphTitles vector of titles
168
169void TMVA::IPythonInteractive::Init(std::vector<TString>& graphTitles)
170{
171 if (fNumGraphs!=0){
172 std::cerr << kERROR << "IPythonInteractive::Init: already initialized..." << std::endl;
173 return;
174 }
175 Int_t color = 2;
176 for(auto& title : graphTitles){
177 fGraphs.push_back( new TGraph() );
178 fGraphs.back()->SetTitle(title);
179 fGraphs.back()->SetName(title);
180 fGraphs.back()->SetFillColor(color);
181 fGraphs.back()->SetLineColor(color);
182 fGraphs.back()->SetMarkerColor(color);
183 fMultiGraph->Add(fGraphs.back());
184 color += 2;
185 fNumGraphs += 1;
186 }
187 return;
188}
189
190////////////////////////////////////////////////////////////////////////////////
191/// This function sets the point number to 0 for all graphs.
192
194{
195 for(Int_t i=0; i<fNumGraphs; i++){
196 fGraphs[i]->Set(0);
197 }
198}
199
200////////////////////////////////////////////////////////////////////////////////
201/// This function is used only in 2 TGraph case, and it will add new data points to graphs.
202///
203/// \param[in] x the x coordinate
204/// \param[in] y1 the y coordinate for the first TGraph
205/// \param[in] y2 the y coordinate for the second TGraph
206
208{
209 fGraphs[0]->Set(fIndex+1);
210 fGraphs[1]->Set(fIndex+1);
211 fGraphs[0]->SetPoint(fIndex, x, y1);
212 fGraphs[1]->SetPoint(fIndex, x, y2);
213 fIndex++;
214 return;
215}
216
217////////////////////////////////////////////////////////////////////////////////
218/// This function can add data points to as many TGraphs as we have.
219///
220/// \param[in] dat vector of data points. The dat[0] contains the x coordinate,
221/// dat[1] contains the y coordinate for first TGraph, dat[2] for second, ...
222
223void TMVA::IPythonInteractive::AddPoint(std::vector<Double_t>& dat)
224{
225 for(Int_t i=0; i<fNumGraphs;i++){
226 fGraphs[i]->Set(fIndex+1);
227 fGraphs[i]->SetPoint(fIndex, dat[0], dat[i+1]);
228 }
229 fIndex++;
230 return;
231}
232
233
234////////////////////////////////////////////////////////////////////////////////
235/// standard constructor
236
238 Types::EMVA methodType,
239 const TString& methodTitle,
240 DataSetInfo& dsi,
241 const TString& theOption) :
242 IMethod(),
243 Configurable ( theOption ),
244 fTmpEvent ( 0 ),
245 fRanking ( 0 ),
246 fInputVars ( 0 ),
247 fAnalysisType ( Types::kNoAnalysisType ),
248 fRegressionReturnVal ( 0 ),
249 fMulticlassReturnVal ( 0 ),
250 fDataSetInfo ( dsi ),
251 fSignalReferenceCut ( 0.5 ),
252 fSignalReferenceCutOrientation( 1. ),
253 fVariableTransformType ( Types::kSignal ),
254 fJobName ( jobName ),
255 fMethodName ( methodTitle ),
256 fMethodType ( methodType ),
257 fTestvar ( "" ),
258 fTMVATrainingVersion ( TMVA_VERSION_CODE ),
259 fROOTTrainingVersion ( ROOT_VERSION_CODE ),
260 fConstructedFromWeightFile ( kFALSE ),
261 fBaseDir ( 0 ),
262 fMethodBaseDir ( 0 ),
263 fFile ( 0 ),
264 fSilentFile (kFALSE),
265 fModelPersistence (kTRUE),
266 fWeightFile ( "" ),
267 fEffS ( 0 ),
268 fDefaultPDF ( 0 ),
269 fMVAPdfS ( 0 ),
270 fMVAPdfB ( 0 ),
271 fSplS ( 0 ),
272 fSplB ( 0 ),
273 fSpleffBvsS ( 0 ),
274 fSplTrainS ( 0 ),
275 fSplTrainB ( 0 ),
276 fSplTrainEffBvsS ( 0 ),
277 fVarTransformString ( "None" ),
278 fTransformationPointer ( 0 ),
279 fTransformation ( dsi, methodTitle ),
280 fVerbose ( kFALSE ),
281 fVerbosityLevelString ( "Default" ),
282 fHelp ( kFALSE ),
283 fHasMVAPdfs ( kFALSE ),
284 fIgnoreNegWeightsInTraining( kFALSE ),
285 fSignalClass ( 0 ),
286 fBackgroundClass ( 0 ),
287 fSplRefS ( 0 ),
288 fSplRefB ( 0 ),
289 fSplTrainRefS ( 0 ),
290 fSplTrainRefB ( 0 ),
291 fSetupCompleted (kFALSE)
292{
295
296// // default extension for weight files
297}
298
299////////////////////////////////////////////////////////////////////////////////
300/// constructor used for Testing + Application of the MVA,
301/// only (no training), using given WeightFiles
302
304 DataSetInfo& dsi,
305 const TString& weightFile ) :
306 IMethod(),
307 Configurable(""),
308 fTmpEvent ( 0 ),
309 fRanking ( 0 ),
310 fInputVars ( 0 ),
311 fAnalysisType ( Types::kNoAnalysisType ),
312 fRegressionReturnVal ( 0 ),
313 fMulticlassReturnVal ( 0 ),
314 fDataSetInfo ( dsi ),
315 fSignalReferenceCut ( 0.5 ),
316 fVariableTransformType ( Types::kSignal ),
317 fJobName ( "" ),
318 fMethodName ( "MethodBase" ),
319 fMethodType ( methodType ),
320 fTestvar ( "" ),
321 fTMVATrainingVersion ( 0 ),
322 fROOTTrainingVersion ( 0 ),
323 fConstructedFromWeightFile ( kTRUE ),
324 fBaseDir ( 0 ),
325 fMethodBaseDir ( 0 ),
326 fFile ( 0 ),
327 fSilentFile (kFALSE),
328 fModelPersistence (kTRUE),
329 fWeightFile ( weightFile ),
330 fEffS ( 0 ),
331 fDefaultPDF ( 0 ),
332 fMVAPdfS ( 0 ),
333 fMVAPdfB ( 0 ),
334 fSplS ( 0 ),
335 fSplB ( 0 ),
336 fSpleffBvsS ( 0 ),
337 fSplTrainS ( 0 ),
338 fSplTrainB ( 0 ),
339 fSplTrainEffBvsS ( 0 ),
340 fVarTransformString ( "None" ),
341 fTransformationPointer ( 0 ),
342 fTransformation ( dsi, "" ),
343 fVerbose ( kFALSE ),
344 fVerbosityLevelString ( "Default" ),
345 fHelp ( kFALSE ),
346 fHasMVAPdfs ( kFALSE ),
347 fIgnoreNegWeightsInTraining( kFALSE ),
348 fSignalClass ( 0 ),
349 fBackgroundClass ( 0 ),
350 fSplRefS ( 0 ),
351 fSplRefB ( 0 ),
352 fSplTrainRefS ( 0 ),
353 fSplTrainRefB ( 0 ),
354 fSetupCompleted (kFALSE)
355{
357// // constructor used for Testing + Application of the MVA,
358// // only (no training), using given WeightFiles
359}
360
361////////////////////////////////////////////////////////////////////////////////
362/// destructor
363
365{
366 // destructor
367 if (!fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling destructor of method which got never setup" << Endl;
368
369 // destructor
370 if (fInputVars != 0) { fInputVars->clear(); delete fInputVars; }
371 if (fRanking != 0) delete fRanking;
372
373 // PDFs
374 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
375 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
376 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
377
378 // Splines
379 if (fSplS) { delete fSplS; fSplS = 0; }
380 if (fSplB) { delete fSplB; fSplB = 0; }
381 if (fSpleffBvsS) { delete fSpleffBvsS; fSpleffBvsS = 0; }
382 if (fSplRefS) { delete fSplRefS; fSplRefS = 0; }
383 if (fSplRefB) { delete fSplRefB; fSplRefB = 0; }
384 if (fSplTrainRefS) { delete fSplTrainRefS; fSplTrainRefS = 0; }
385 if (fSplTrainRefB) { delete fSplTrainRefB; fSplTrainRefB = 0; }
386 if (fSplTrainEffBvsS) { delete fSplTrainEffBvsS; fSplTrainEffBvsS = 0; }
387
388 for (Int_t i = 0; i < 2; i++ ) {
389 if (fEventCollections.at(i)) {
390 for (std::vector<Event*>::const_iterator it = fEventCollections.at(i)->begin();
391 it != fEventCollections.at(i)->end(); ++it) {
392 delete (*it);
393 }
394 delete fEventCollections.at(i);
395 fEventCollections.at(i) = 0;
396 }
397 }
398
399 if (fRegressionReturnVal) delete fRegressionReturnVal;
400 if (fMulticlassReturnVal) delete fMulticlassReturnVal;
401}
402
403////////////////////////////////////////////////////////////////////////////////
404/// setup of methods
405
407{
408 // setup of methods
409
410 if (fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling SetupMethod for the second time" << Endl;
411 InitBase();
412 DeclareBaseOptions();
413 Init();
414 DeclareOptions();
415 fSetupCompleted = kTRUE;
416}
417
418////////////////////////////////////////////////////////////////////////////////
419/// process all options
420/// the "CheckForUnusedOptions" is done in an independent call, since it may be overridden by derived class
421/// (sometimes, eg, fitters are used which can only be implemented during training phase)
422
424{
425 ProcessBaseOptions();
426 ProcessOptions();
427}
428
429////////////////////////////////////////////////////////////////////////////////
430/// check may be overridden by derived class
431/// (sometimes, eg, fitters are used which can only be implemented during training phase)
432
434{
435 CheckForUnusedOptions();
436}
437
438////////////////////////////////////////////////////////////////////////////////
439/// default initialization called by all constructors
440
442{
443 SetConfigDescription( "Configuration options for classifier architecture and tuning" );
444
446 fNbinsMVAoutput = gConfig().fVariablePlotting.fNbinsMVAoutput;
447 fNbinsH = NBIN_HIST_HIGH;
448
449 fSplTrainS = 0;
450 fSplTrainB = 0;
451 fSplTrainEffBvsS = 0;
452 fMeanS = -1;
453 fMeanB = -1;
454 fRmsS = -1;
455 fRmsB = -1;
456 fXmin = DBL_MAX;
457 fXmax = -DBL_MAX;
458 fTxtWeightsOnly = kTRUE;
459 fSplRefS = 0;
460 fSplRefB = 0;
461
462 fTrainTime = -1.;
463 fTestTime = -1.;
464
465 fRanking = 0;
466
467 // temporary until the move to DataSet is complete
468 fInputVars = new std::vector<TString>;
469 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
470 fInputVars->push_back(DataInfo().GetVariableInfo(ivar).GetLabel());
471 }
472 fRegressionReturnVal = 0;
473 fMulticlassReturnVal = 0;
474
475 fEventCollections.resize( 2 );
476 fEventCollections.at(0) = 0;
477 fEventCollections.at(1) = 0;
478
479 // retrieve signal and background class index
480 if (DataInfo().GetClassInfo("Signal") != 0) {
481 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
482 }
483 if (DataInfo().GetClassInfo("Background") != 0) {
484 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
485 }
486
487 SetConfigDescription( "Configuration options for MVA method" );
488 SetConfigName( TString("Method") + GetMethodTypeName() );
489}
490
491////////////////////////////////////////////////////////////////////////////////
492/// define the options (their key words) that can be set in the option string
493/// here the options valid for ALL MVA methods are declared.
494///
495/// known options:
496///
497/// - VarTransform=None (default) or a list such as "D_Background,P_Signal,G,N_AllClasses"
498/// to use transformed variables instead of the original ones
499/// - VarTransformType=Signal,Background which events are used to derive the variable
500/// transformation (compatibility option). Only the Likelihood
501/// Method can make proper use of independent
502/// transformations of signal and background
503/// - NbinsMVAPdf = 60 Number of bins used to create a PDF of MVA
504/// - NsmoothMVAPdf = 2 Number of times a histogram is smoothed before creating the PDF
505/// - CreateMVAPdfs create PDFs for the MVA outputs
506/// - V for Verbose output (!V) for non-verbose output
507/// - H for Help message
508
510{
511 DeclareOptionRef( fVerbose, "V", "Verbose output (short form of \"VerbosityLevel\" below - overrides the latter one)" );
512
513 DeclareOptionRef( fVerbosityLevelString="Default", "VerbosityLevel", "Verbosity level" );
514 AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
515 AddPreDefVal( TString("Debug") );
516 AddPreDefVal( TString("Verbose") );
517 AddPreDefVal( TString("Info") );
518 AddPreDefVal( TString("Warning") );
519 AddPreDefVal( TString("Error") );
520 AddPreDefVal( TString("Fatal") );
521
522 // If True (default): write all training results (weights) as text files only;
523 // if False: write also in ROOT format (not available for all methods - will abort if not
524 fTxtWeightsOnly = kTRUE; // OBSOLETE !!!
525 fNormalise = kFALSE; // OBSOLETE !!!
526
527 DeclareOptionRef( fVarTransformString, "VarTransform", "List of variable transformations performed before training, e.g., \"D_Background,P_Signal,G,N_AllClasses\" for: \"Decorrelation, PCA-transformation, Gaussianisation, Normalisation, each for the given class of events ('AllClasses' denotes all events of all classes, if no class indication is given, 'All' is assumed)\"" );
528
529 DeclareOptionRef( fHelp, "H", "Print method-specific help message" );
530
531 DeclareOptionRef( fHasMVAPdfs, "CreateMVAPdfs", "Create PDFs for classifier outputs (signal and background)" );
532
533 DeclareOptionRef( fIgnoreNegWeightsInTraining, "IgnoreNegWeightsInTraining",
534 "Events with negative weights are ignored in the training (but are included for testing and performance evaluation)" );
535}
536
537////////////////////////////////////////////////////////////////////////////////
538/// the option string is decoded, for available options see "DeclareOptions"
539
541{
542 if (HasMVAPdfs()) {
543 // setting the default bin num... maybe should be static ? ==> Please no static (JS)
544 // You can't use the logger in the constructor!!! Log() << kINFO << "Create PDFs" << Endl;
545 // reading every PDF's definition and passing the option string to the next one to be read and marked
546 fDefaultPDF = new PDF( TString(GetName())+"_PDF", GetOptions(), "MVAPdf" );
547 fDefaultPDF->DeclareOptions();
548 fDefaultPDF->ParseOptions();
549 fDefaultPDF->ProcessOptions();
550 fMVAPdfB = new PDF( TString(GetName())+"_PDFBkg", fDefaultPDF->GetOptions(), "MVAPdfBkg", fDefaultPDF );
551 fMVAPdfB->DeclareOptions();
552 fMVAPdfB->ParseOptions();
553 fMVAPdfB->ProcessOptions();
554 fMVAPdfS = new PDF( TString(GetName())+"_PDFSig", fMVAPdfB->GetOptions(), "MVAPdfSig", fDefaultPDF );
555 fMVAPdfS->DeclareOptions();
556 fMVAPdfS->ParseOptions();
557 fMVAPdfS->ProcessOptions();
558
559 // the final marked option string is written back to the original methodbase
560 SetOptions( fMVAPdfS->GetOptions() );
561 }
562
563 TMVA::CreateVariableTransforms( fVarTransformString,
564 DataInfo(),
565 GetTransformationHandler(),
566 Log() );
567
568 if (!HasMVAPdfs()) {
569 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
570 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
571 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
572 }
573
574 if (fVerbose) { // overwrites other settings
575 fVerbosityLevelString = TString("Verbose");
576 Log().SetMinType( kVERBOSE );
577 }
578 else if (fVerbosityLevelString == "Debug" ) Log().SetMinType( kDEBUG );
579 else if (fVerbosityLevelString == "Verbose" ) Log().SetMinType( kVERBOSE );
580 else if (fVerbosityLevelString == "Info" ) Log().SetMinType( kINFO );
581 else if (fVerbosityLevelString == "Warning" ) Log().SetMinType( kWARNING );
582 else if (fVerbosityLevelString == "Error" ) Log().SetMinType( kERROR );
583 else if (fVerbosityLevelString == "Fatal" ) Log().SetMinType( kFATAL );
584 else if (fVerbosityLevelString != "Default" ) {
585 Log() << kFATAL << "<ProcessOptions> Verbosity level type '"
586 << fVerbosityLevelString << "' unknown." << Endl;
587 }
588 Event::SetIgnoreNegWeightsInTraining(fIgnoreNegWeightsInTraining);
589}
590
591////////////////////////////////////////////////////////////////////////////////
592/// options that are used ONLY for the READER to ensure backward compatibility
593/// they are hence without any effect (the reader is only reading the training
594/// options that HAD been used at the training of the .xml weight file at hand)
595
597{
598 DeclareOptionRef( fNormalise=kFALSE, "Normalise", "Normalise input variables" ); // don't change the default !!!
599 DeclareOptionRef( fUseDecorr=kFALSE, "D", "Use-decorrelated-variables flag" );
600 DeclareOptionRef( fVariableTransformTypeString="Signal", "VarTransformType",
601 "Use signal or background events to derive for variable transformation (the transformation is applied on both types of, course)" );
602 AddPreDefVal( TString("Signal") );
603 AddPreDefVal( TString("Background") );
604 DeclareOptionRef( fTxtWeightsOnly=kTRUE, "TxtWeightFilesOnly", "If True: write all training results (weights) as text files (False: some are written in ROOT format)" );
605 // Why on earth ?? was this here? Was the verbosity level option meant to 'disappear? Not a good idea i think..
606 // DeclareOptionRef( fVerbosityLevelString="Default", "VerboseLevel", "Verbosity level" );
607 // AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
608 // AddPreDefVal( TString("Debug") );
609 // AddPreDefVal( TString("Verbose") );
610 // AddPreDefVal( TString("Info") );
611 // AddPreDefVal( TString("Warning") );
612 // AddPreDefVal( TString("Error") );
613 // AddPreDefVal( TString("Fatal") );
614 DeclareOptionRef( fNbinsMVAPdf = 60, "NbinsMVAPdf", "Number of bins used for the PDFs of classifier outputs" );
615 DeclareOptionRef( fNsmoothMVAPdf = 2, "NsmoothMVAPdf", "Number of smoothing iterations for classifier PDFs" );
616}
617
618
619////////////////////////////////////////////////////////////////////////////////
620/// call the Optimizer with the set of parameters and ranges that
621/// are meant to be tuned.
622
623std::map<TString,Double_t> TMVA::MethodBase::OptimizeTuningParameters(TString /* fomType */ , TString /* fitType */)
624{
625 // this is just a dummy... needs to be implemented for each method
626 // individually (as long as we don't have it automatized via the
627 // configuration string
628
629 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Parameter optimization is not yet implemented for method "
630 << GetName() << Endl;
631 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Currently we need to set hardcoded which parameter is tuned in which ranges"<<Endl;
632
633 std::map<TString,Double_t> tunedParameters;
634 tunedParameters.size(); // just to get rid of "unused" warning
635 return tunedParameters;
636
637}
638
639////////////////////////////////////////////////////////////////////////////////
640/// set the tuning parameters according to the argument
641/// This is just a dummy .. have a look at the MethodBDT how you could
642/// perhaps implement the same thing for the other Classifiers..
643
644void TMVA::MethodBase::SetTuneParameters(std::map<TString,Double_t> /* tuneParameters */)
645{
646}
647
648////////////////////////////////////////////////////////////////////////////////
649
651{
652 Data()->SetCurrentType(Types::kTraining);
653 Event::SetIsTraining(kTRUE); // used to set negative event weights to zero if chosen to do so
654
655 // train the MVA method
656 if (Help()) PrintHelpMessage();
657
658 // all histograms should be created in the method's subdirectory
659 if(!IsSilentFile()) BaseDir()->cd();
660
661 // once calculate all the transformation (e.g. the sequence of Decorr:Gauss:Decorr)
662 // needed for this classifier
663 GetTransformationHandler().CalcTransformations(Data()->GetEventCollection());
664
665 // call training of derived MVA
666 Log() << kDEBUG //<<Form("\tDataset[%s] : ",DataInfo().GetName())
667 << "Begin training" << Endl;
668 Long64_t nEvents = Data()->GetNEvents();
669 Timer traintimer( nEvents, GetName(), kTRUE );
670 Train();
671 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName()
672 << "\tEnd of training " << Endl;
673 SetTrainTime(traintimer.ElapsedSeconds());
674 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
675 << "Elapsed time for training with " << nEvents << " events: "
676 << traintimer.GetElapsedTime() << " " << Endl;
677
678 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
679 << "\tCreate MVA output for ";
680
681 // create PDFs for the signal and background MVA distributions (if required)
682 if (DoMulticlass()) {
683 Log() <<Form("[%s] : ",DataInfo().GetName())<< "Multiclass classification on training sample" << Endl;
684 AddMulticlassOutput(Types::kTraining);
685 }
686 else if (!DoRegression()) {
687
688 Log() <<Form("[%s] : ",DataInfo().GetName())<< "classification on training sample" << Endl;
689 AddClassifierOutput(Types::kTraining);
690 if (HasMVAPdfs()) {
691 CreateMVAPdfs();
692 AddClassifierOutputProb(Types::kTraining);
693 }
694
695 } else {
696
697 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "regression on training sample" << Endl;
698 AddRegressionOutput( Types::kTraining );
699
700 if (HasMVAPdfs() ) {
701 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create PDFs" << Endl;
702 CreateMVAPdfs();
703 }
704 }
705
706 // write the current MVA state into stream
707 // produced are one text file and one ROOT file
708 if (fModelPersistence ) WriteStateToFile();
709
710 // produce standalone make class (presently only supported for classification)
711 if ((!DoRegression()) && (fModelPersistence)) MakeClass();
712
713 // write additional monitoring histograms to main target file (not the weight file)
714 // again, make sure the histograms go into the method's subdirectory
715 if(!IsSilentFile())
716 {
717 BaseDir()->cd();
718 WriteMonitoringHistosToFile();
719 }
720}
721
722////////////////////////////////////////////////////////////////////////////////
723
725{
726 if (!DoRegression()) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Trying to use GetRegressionDeviation() with a classification job" << Endl;
727 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
729 bool truncate = false;
730 TH1F* h1 = regRes->QuadraticDeviation( tgtNum , truncate, 1.);
731 stddev = sqrt(h1->GetMean());
732 truncate = true;
733 Double_t yq[1], xq[]={0.9};
734 h1->GetQuantiles(1,yq,xq);
735 TH1F* h2 = regRes->QuadraticDeviation( tgtNum , truncate, yq[0]);
736 stddev90Percent = sqrt(h2->GetMean());
737 delete h1;
738 delete h2;
739}
740
741////////////////////////////////////////////////////////////////////////////////
742/// prepare tree branch with the method's discriminating variable
743
745{
746 Data()->SetCurrentType(type);
747
748 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
749
750 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), type, Types::kRegression);
751
752 Long64_t nEvents = Data()->GetNEvents();
753
754 // use timer
755 Timer timer( nEvents, GetName(), kTRUE );
756 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
757 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
758
759 regRes->Resize( nEvents );
760
761 // Drawing the progress bar every event was causing a huge slowdown in the evaluation time
762 // So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
763
764 Int_t totalProgressDraws = 100; // total number of times to update the progress bar
765 Int_t drawProgressEvery = 1; // draw every nth event such that we have a total of totalProgressDraws
766 if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
767
768 for (Int_t ievt=0; ievt<nEvents; ievt++) {
769
770 Data()->SetCurrentEvent(ievt);
771 std::vector< Float_t > vals = GetRegressionValues();
772 regRes->SetValue( vals, ievt );
773
774 // Only draw the progress bar once in a while, doing this every event causes the evaluation to be ridiculously slow
775 if(ievt % drawProgressEvery == 0 || ievt==nEvents-1) timer.DrawProgressBar( ievt );
776 }
777
778 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
779 << "Elapsed time for evaluation of " << nEvents << " events: "
780 << timer.GetElapsedTime() << " " << Endl;
781
782 // store time used for testing
784 SetTestTime(timer.ElapsedSeconds());
785
786 TString histNamePrefix(GetTestvarName());
787 histNamePrefix += (type==Types::kTraining?"train":"test");
788 regRes->CreateDeviationHistograms( histNamePrefix );
789}
790
791////////////////////////////////////////////////////////////////////////////////
792/// prepare tree branch with the method's discriminating variable
793
795{
796 Data()->SetCurrentType(type);
797
798 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
799
800 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
801 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in AddMulticlassOutput, exiting."<<Endl;
802
803 Long64_t nEvents = Data()->GetNEvents();
804
805 // use timer
806 Timer timer( nEvents, GetName(), kTRUE );
807
808 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Multiclass evaluation of " << GetMethodName() << " on "
809 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
810
811 resMulticlass->Resize( nEvents );
812 for (Int_t ievt=0; ievt<nEvents; ievt++) {
813 Data()->SetCurrentEvent(ievt);
814 std::vector< Float_t > vals = GetMulticlassValues();
815 resMulticlass->SetValue( vals, ievt );
816 timer.DrawProgressBar( ievt );
817 }
818
819 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
820 << "Elapsed time for evaluation of " << nEvents << " events: "
821 << timer.GetElapsedTime() << " " << Endl;
822
823 // store time used for testing
825 SetTestTime(timer.ElapsedSeconds());
826
827 TString histNamePrefix(GetTestvarName());
828 histNamePrefix += (type==Types::kTraining?"_Train":"_Test");
829
830 resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbinsMVAoutput, fNbinsH );
831 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefix);
832}
833
834////////////////////////////////////////////////////////////////////////////////
835
836void TMVA::MethodBase::NoErrorCalc(Double_t* const err, Double_t* const errUpper) {
837 if (err) *err=-1;
838 if (errUpper) *errUpper=-1;
839}
840
841////////////////////////////////////////////////////////////////////////////////
842
843Double_t TMVA::MethodBase::GetMvaValue( const Event* const ev, Double_t* err, Double_t* errUpper ) {
844 fTmpEvent = ev;
845 Double_t val = GetMvaValue(err, errUpper);
846 fTmpEvent = 0;
847 return val;
848}
849
850////////////////////////////////////////////////////////////////////////////////
851/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
852/// for a quick determination if an event would be selected as signal or background
853
855 return GetMvaValue()*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
856}
857////////////////////////////////////////////////////////////////////////////////
858/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
859/// for a quick determination if an event with this mva output value would be selected as signal or background
860
862 return mvaVal*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
863}
864
865////////////////////////////////////////////////////////////////////////////////
866/// prepare tree branch with the method's discriminating variable
867
869{
870 Data()->SetCurrentType(type);
871
872 ResultsClassification* clRes =
874
875 Long64_t nEvents = Data()->GetNEvents();
876 clRes->Resize( nEvents );
877
878 // use timer
879 Timer timer( nEvents, GetName(), kTRUE );
880 std::vector<Double_t> mvaValues = GetMvaValues(0, nEvents, true);
881
882 // store time used for testing
884 SetTestTime(timer.ElapsedSeconds());
885
886 // load mva values and type to results object
887 for (Int_t ievt = 0; ievt < nEvents; ievt++) {
888 // note we do not need the trasformed event to get the signal/background information
889 // by calling Data()->GetEvent instead of this->GetEvent we access the untransformed one
890 auto ev = Data()->GetEvent(ievt);
891 clRes->SetValue(mvaValues[ievt], ievt, DataInfo().IsSignal(ev));
892 }
893}
894
895////////////////////////////////////////////////////////////////////////////////
896/// get all the MVA values for the events of the current Data type
897std::vector<Double_t> TMVA::MethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
898{
899
900 Long64_t nEvents = Data()->GetNEvents();
901 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
902 if (firstEvt < 0) firstEvt = 0;
903 std::vector<Double_t> values(lastEvt-firstEvt);
904 // log in case of looping on all the events
905 nEvents = values.size();
906
907 // use timer
908 Timer timer( nEvents, GetName(), kTRUE );
909
910 if (logProgress)
911 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
912 << "Evaluation of " << GetMethodName() << " on "
913 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
914 << " sample (" << nEvents << " events)" << Endl;
915
916 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
917 Data()->SetCurrentEvent(ievt);
918 values[ievt] = GetMvaValue();
919
920 // print progress
921 if (logProgress) {
922 Int_t modulo = Int_t(nEvents/100);
923 if (modulo <= 0 ) modulo = 1;
924 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
925 }
926 }
927 if (logProgress) {
928 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
929 << "Elapsed time for evaluation of " << nEvents << " events: "
930 << timer.GetElapsedTime() << " " << Endl;
931 }
932
933 return values;
934}
935
936////////////////////////////////////////////////////////////////////////////////
937/// get all the MVA values for the events of the given Data type
938/// (this is used by MethodCategory and does not need to be re-implemented by derived classes)
939std::vector<Double_t> TMVA::MethodBase::GetDataMvaValues(DataSet * data, Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
940{
941 fTmpData = data;
942 auto result = GetMvaValues(firstEvt, lastEvt, logProgress);
943 fTmpData = nullptr;
944 return result;
945}
946
947////////////////////////////////////////////////////////////////////////////////
948/// prepare tree branch with the method's signal probability (GetProba of the MVA output)
949
951{
952 Data()->SetCurrentType(type);
953
954 ResultsClassification* mvaProb =
955 (ResultsClassification*)Data()->GetResults(TString("prob_")+GetMethodName(), type, Types::kClassification );
956
957 Long64_t nEvents = Data()->GetNEvents();
958
959 // use timer
960 Timer timer( nEvents, GetName(), kTRUE );
961
962 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
963 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
964
965 mvaProb->Resize( nEvents );
966 for (Int_t ievt=0; ievt<nEvents; ievt++) {
967
968 Data()->SetCurrentEvent(ievt);
969 Float_t proba = ((Float_t)GetProba( GetMvaValue(), 0.5 ));
970 if (proba < 0) break;
971 mvaProb->SetValue( proba, ievt, DataInfo().IsSignal( Data()->GetEvent()) );
972
973 // print progress
974 Int_t modulo = Int_t(nEvents/100);
975 if (modulo <= 0 ) modulo = 1;
976 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
977 }
978
979 Log() << kDEBUG <<Form("Dataset[%s] : ",DataInfo().GetName())
980 << "Elapsed time for evaluation of " << nEvents << " events: "
981 << timer.GetElapsedTime() << " " << Endl;
982}
983
984////////////////////////////////////////////////////////////////////////////////
985/// calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
986///
987/// - bias = average deviation
988/// - dev = average absolute deviation
989/// - rms = rms of deviation
990
992 Double_t& dev, Double_t& devT,
993 Double_t& rms, Double_t& rmsT,
994 Double_t& mInf, Double_t& mInfT,
995 Double_t& corr,
997{
998 Types::ETreeType savedType = Data()->GetCurrentType();
999 Data()->SetCurrentType(type);
1000
1001 bias = 0; biasT = 0; dev = 0; devT = 0; rms = 0; rmsT = 0;
1002 Double_t sumw = 0;
1003 Double_t m1 = 0, m2 = 0, s1 = 0, s2 = 0, s12 = 0; // for correlation
1004 const Int_t nevt = GetNEvents();
1005 Float_t* rV = new Float_t[nevt];
1006 Float_t* tV = new Float_t[nevt];
1007 Float_t* wV = new Float_t[nevt];
1008 Float_t xmin = 1e30, xmax = -1e30;
1009 Log() << kINFO << "Calculate regression for all events" << Endl;
1010 Timer timer( nevt, GetName(), kTRUE );
1011 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1012
1013 const Event* ev = Data()->GetEvent(ievt); // NOTE: need untransformed event here !
1014 Float_t t = ev->GetTarget(0);
1015 Float_t w = ev->GetWeight();
1016 Float_t r = GetRegressionValues()[0];
1017 Float_t d = (r-t);
1018
1019 // find min/max
1022
1023 // store for truncated RMS computation
1024 rV[ievt] = r;
1025 tV[ievt] = t;
1026 wV[ievt] = w;
1027
1028 // compute deviation-squared
1029 sumw += w;
1030 bias += w * d;
1031 dev += w * TMath::Abs(d);
1032 rms += w * d * d;
1033
1034 // compute correlation between target and regression estimate
1035 m1 += t*w; s1 += t*t*w;
1036 m2 += r*w; s2 += r*r*w;
1037 s12 += t*r;
1038 // print progress
1039 Long64_t modulo = Long64_t(nevt / 100);
1040 if (ievt % modulo == 0)
1041 timer.DrawProgressBar(ievt);
1042 }
1043 timer.DrawProgressBar(nevt - 1);
1044 Log() << kINFO << "Elapsed time for evaluation of " << nevt << " events: "
1045 << timer.GetElapsedTime() << " " << Endl;
1046
1047 // standard quantities
1048 bias /= sumw;
1049 dev /= sumw;
1050 rms /= sumw;
1051 rms = TMath::Sqrt(rms - bias*bias);
1052
1053 // correlation
1054 m1 /= sumw;
1055 m2 /= sumw;
1056 corr = s12/sumw - m1*m2;
1057 corr /= TMath::Sqrt( (s1/sumw - m1*m1) * (s2/sumw - m2*m2) );
1058
1059 // create histogram required for computation of mutual information
1060 TH2F* hist = new TH2F( "hist", "hist", 150, xmin, xmax, 100, xmin, xmax );
1061 TH2F* histT = new TH2F( "histT", "histT", 150, xmin, xmax, 100, xmin, xmax );
1062
1063 // compute truncated RMS and fill histogram
1064 Double_t devMax = bias + 2*rms;
1065 Double_t devMin = bias - 2*rms;
1066 sumw = 0;
1067 int ic=0;
1068 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1069 Float_t d = (rV[ievt] - tV[ievt]);
1070 hist->Fill( rV[ievt], tV[ievt], wV[ievt] );
1071 if (d >= devMin && d <= devMax) {
1072 sumw += wV[ievt];
1073 biasT += wV[ievt] * d;
1074 devT += wV[ievt] * TMath::Abs(d);
1075 rmsT += wV[ievt] * d * d;
1076 histT->Fill( rV[ievt], tV[ievt], wV[ievt] );
1077 ic++;
1078 }
1079 }
1080 biasT /= sumw;
1081 devT /= sumw;
1082 rmsT /= sumw;
1083 rmsT = TMath::Sqrt(rmsT - biasT*biasT);
1084 mInf = gTools().GetMutualInformation( *hist );
1085 mInfT = gTools().GetMutualInformation( *histT );
1086
1087 delete hist;
1088 delete histT;
1089
1090 delete [] rV;
1091 delete [] tV;
1092 delete [] wV;
1093
1094 Data()->SetCurrentType(savedType);
1095}
1096
1097
1098////////////////////////////////////////////////////////////////////////////////
1099/// test multiclass classification
1100
1102{
1103 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
1104 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in TestMulticlass, exiting."<<Endl;
1105
1106 // GA evaluation of best cut for sig eff * sig pur. Slow, disabled for now.
1107 // Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for test
1108 // data..." << Endl; for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
1109 // resMulticlass->GetBestMultiClassCuts(icls);
1110 // }
1111
1112 // Create histograms for use in TMVA GUI
1113 TString histNamePrefix(GetTestvarName());
1114 TString histNamePrefixTest{histNamePrefix + "_Test"};
1115 TString histNamePrefixTrain{histNamePrefix + "_Train"};
1116
1117 resMulticlass->CreateMulticlassHistos(histNamePrefixTest, fNbinsMVAoutput, fNbinsH);
1118 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTest);
1119
1120 resMulticlass->CreateMulticlassHistos(histNamePrefixTrain, fNbinsMVAoutput, fNbinsH);
1121 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTrain);
1122}
1123
1124
1125////////////////////////////////////////////////////////////////////////////////
1126/// initialization of the classifier test: fill the test-sample MVA histograms and build the signal/background PDFs
1127
1129{
1130 Data()->SetCurrentType(Types::kTesting);
1131
1132 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
1133 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
1134
1135 // sanity checks: tree must exist, and theVar must be in tree
1136 if (0==mvaRes && !(GetMethodTypeName().Contains("Cuts"))) {
1137 Log()<<Form("Dataset[%s] : ",DataInfo().GetName()) << "mvaRes " << mvaRes << " GetMethodTypeName " << GetMethodTypeName()
1138 << " contains " << !(GetMethodTypeName().Contains("Cuts")) << Endl;
1139 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<TestInit> Test variable " << GetTestvarName()
1140 << " not found in tree" << Endl;
1141 }
1142
1143 // basic statistics operations are made in base class
1144 gTools().ComputeStat( GetEventCollection(Types::kTesting), mvaRes->GetValueVector(),
1145 fMeanS, fMeanB, fRmsS, fRmsB, fXmin, fXmax, fSignalClass );
1146
1147 // choose reasonable histogram ranges, by removing outliers
1148 Double_t nrms = 10;
1149 fXmin = TMath::Max( TMath::Min( fMeanS - nrms*fRmsS, fMeanB - nrms*fRmsB ), fXmin );
1150 fXmax = TMath::Min( TMath::Max( fMeanS + nrms*fRmsS, fMeanB + nrms*fRmsB ), fXmax );
1151
1152 // determine cut orientation
1153 fCutOrientation = (fMeanS > fMeanB) ? kPositive : kNegative;
1154
1155 // fill 2 types of histograms for the various analyses
1156 // this one is for actual plotting
1157
1158 Double_t sxmax = fXmax+0.00001;
1159
1160 // classifier response distributions for training sample
1161 // MVA plots used for graphics representation (signal)
1162 TString TestvarName;
1163 if(IsSilentFile())
1164 {
1165 TestvarName=Form("[%s]%s",DataInfo().GetName(),GetTestvarName().Data());
1166 }else
1167 {
1168 TestvarName=GetTestvarName();
1169 }
1170 TH1* mva_s = new TH1D( TestvarName + "_S",TestvarName + "_S", fNbinsMVAoutput, fXmin, sxmax );
1171 TH1* mva_b = new TH1D( TestvarName + "_B",TestvarName + "_B", fNbinsMVAoutput, fXmin, sxmax );
1172 mvaRes->Store(mva_s, "MVA_S");
1173 mvaRes->Store(mva_b, "MVA_B");
1174 mva_s->Sumw2();
1175 mva_b->Sumw2();
1176
1177 TH1* proba_s = 0;
1178 TH1* proba_b = 0;
1179 TH1* rarity_s = 0;
1180 TH1* rarity_b = 0;
1181 if (HasMVAPdfs()) {
1182 // P(MVA) plots used for graphics representation
1183 proba_s = new TH1D( TestvarName + "_Proba_S", TestvarName + "_Proba_S", fNbinsMVAoutput, 0.0, 1.0 );
1184 proba_b = new TH1D( TestvarName + "_Proba_B", TestvarName + "_Proba_B", fNbinsMVAoutput, 0.0, 1.0 );
1185 mvaRes->Store(proba_s, "Prob_S");
1186 mvaRes->Store(proba_b, "Prob_B");
1187 proba_s->Sumw2();
1188 proba_b->Sumw2();
1189
1190 // R(MVA) plots used for graphics representation
1191 rarity_s = new TH1D( TestvarName + "_Rarity_S", TestvarName + "_Rarity_S", fNbinsMVAoutput, 0.0, 1.0 );
1192 rarity_b = new TH1D( TestvarName + "_Rarity_B", TestvarName + "_Rarity_B", fNbinsMVAoutput, 0.0, 1.0 );
1193 mvaRes->Store(rarity_s, "Rar_S");
1194 mvaRes->Store(rarity_b, "Rar_B");
1195 rarity_s->Sumw2();
1196 rarity_b->Sumw2();
1197 }
1198
1199 // MVA plots used for efficiency calculations (large number of bins)
1200 TH1* mva_eff_s = new TH1D( TestvarName + "_S_high", TestvarName + "_S_high", fNbinsH, fXmin, sxmax );
1201 TH1* mva_eff_b = new TH1D( TestvarName + "_B_high", TestvarName + "_B_high", fNbinsH, fXmin, sxmax );
1202 mvaRes->Store(mva_eff_s, "MVA_HIGHBIN_S");
1203 mvaRes->Store(mva_eff_b, "MVA_HIGHBIN_B");
1204 mva_eff_s->Sumw2();
1205 mva_eff_b->Sumw2();
1206
1207 // fill the histograms
1208
1209 ResultsClassification* mvaProb = dynamic_cast<ResultsClassification*>
1210 (Data()->GetResults( TString("prob_")+GetMethodName(), Types::kTesting, Types::kMaxAnalysisType ) );
1211
1212 Log() << kHEADER <<Form("[%s] : ",DataInfo().GetName())<< "Loop over test events and fill histograms with classifier response..." << Endl << Endl;
1213 if (mvaProb) Log() << kINFO << "Also filling probability and rarity histograms (on request)..." << Endl;
1214 //std::vector<Bool_t>* mvaResTypes = mvaRes->GetValueVectorTypes();
1215
1216 //LM: this is needed to avoid crashes in ROOCCURVE
1217 if ( mvaRes->GetSize() != GetNEvents() ) {
1218 Log() << kFATAL << TString::Format("Inconsistent result size %lld with number of events %u ", mvaRes->GetSize() , GetNEvents() ) << Endl;
1219 assert(mvaRes->GetSize() == GetNEvents());
1220 }
1221
1222 for (Long64_t ievt=0; ievt<GetNEvents(); ievt++) {
1223
1224 const Event* ev = GetEvent(ievt);
1225 Float_t v = (*mvaRes)[ievt][0];
1226 Float_t w = ev->GetWeight();
1227
1228 if (DataInfo().IsSignal(ev)) {
1229 //mvaResTypes->push_back(kTRUE);
1230 mva_s ->Fill( v, w );
1231 if (mvaProb) {
1232 proba_s->Fill( (*mvaProb)[ievt][0], w );
1233 rarity_s->Fill( GetRarity( v ), w );
1234 }
1235
1236 mva_eff_s ->Fill( v, w );
1237 }
1238 else {
1239 //mvaResTypes->push_back(kFALSE);
1240 mva_b ->Fill( v, w );
1241 if (mvaProb) {
1242 proba_b->Fill( (*mvaProb)[ievt][0], w );
1243 rarity_b->Fill( GetRarity( v ), w );
1244 }
1245 mva_eff_b ->Fill( v, w );
1246 }
1247 }
1248
1249 // uncomment those (and several others if you want unnormalized output
1250 gTools().NormHist( mva_s );
1251 gTools().NormHist( mva_b );
1252 gTools().NormHist( proba_s );
1253 gTools().NormHist( proba_b );
1254 gTools().NormHist( rarity_s );
1255 gTools().NormHist( rarity_b );
1256 gTools().NormHist( mva_eff_s );
1257 gTools().NormHist( mva_eff_b );
1258
1259 // create PDFs from histograms, using default splines, and no additional smoothing
1260 if (fSplS) { delete fSplS; fSplS = 0; }
1261 if (fSplB) { delete fSplB; fSplB = 0; }
1262 fSplS = new PDF( TString(GetName()) + " PDF Sig", mva_s, PDF::kSpline2 );
1263 fSplB = new PDF( TString(GetName()) + " PDF Bkg", mva_b, PDF::kSpline2 );
1264}
1265
1266////////////////////////////////////////////////////////////////////////////////
1267/// general method used in writing the header of the weight files where
1268/// the used variables, variable transformation type etc. is specified
1269
1270void TMVA::MethodBase::WriteStateToStream( std::ostream& tf ) const
1271{
1272 TString prefix = "";
1273 UserGroup_t * userInfo = gSystem->GetUserInfo();
1274
1275 tf << prefix << "#GEN -*-*-*-*-*-*-*-*-*-*-*- general info -*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1276 tf << prefix << "Method : " << GetMethodTypeName() << "::" << GetMethodName() << std::endl;
1277 tf.setf(std::ios::left);
1278 tf << prefix << "TMVA Release : " << std::setw(10) << GetTrainingTMVAVersionString() << " ["
1279 << GetTrainingTMVAVersionCode() << "]" << std::endl;
1280 tf << prefix << "ROOT Release : " << std::setw(10) << GetTrainingROOTVersionString() << " ["
1281 << GetTrainingROOTVersionCode() << "]" << std::endl;
1282 tf << prefix << "Creator : " << userInfo->fUser << std::endl;
1283 tf << prefix << "Date : "; TDatime *d = new TDatime; tf << d->AsString() << std::endl; delete d;
1284 tf << prefix << "Host : " << gSystem->GetBuildNode() << std::endl;
1285 tf << prefix << "Dir : " << gSystem->WorkingDirectory() << std::endl;
1286 tf << prefix << "Training events: " << Data()->GetNTrainingEvents() << std::endl;
1287
1288 TString analysisType(((const_cast<TMVA::MethodBase*>(this)->GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification"));
1289
1290 tf << prefix << "Analysis type : " << "[" << ((GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification") << "]" << std::endl;
1291 tf << prefix << std::endl;
1292
1293 delete userInfo;
1294
1295 // First write all options
1296 tf << prefix << std::endl << prefix << "#OPT -*-*-*-*-*-*-*-*-*-*-*-*- options -*-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1297 WriteOptionsToStream( tf, prefix );
1298 tf << prefix << std::endl;
1299
1300 // Second write variable info
1301 tf << prefix << std::endl << prefix << "#VAR -*-*-*-*-*-*-*-*-*-*-*-* variables *-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1302 WriteVarsToStream( tf, prefix );
1303 tf << prefix << std::endl;
1304}
1305
1306////////////////////////////////////////////////////////////////////////////////
1307/// xml writing
1308
1309void TMVA::MethodBase::AddInfoItem( void* gi, const TString& name, const TString& value) const
1310{
1311 void* it = gTools().AddChild(gi,"Info");
1312 gTools().AddAttr(it,"name", name);
1313 gTools().AddAttr(it,"value", value);
1314}
1315
1316////////////////////////////////////////////////////////////////////////////////
1317
1319 if (analysisType == Types::kRegression) {
1320 AddRegressionOutput( type );
1321 } else if (analysisType == Types::kMulticlass) {
1322 AddMulticlassOutput( type );
1323 } else {
1324 AddClassifierOutput( type );
1325 if (HasMVAPdfs())
1326 AddClassifierOutputProb( type );
1327 }
1328}
1329
1330////////////////////////////////////////////////////////////////////////////////
1331/// general method used in writing the header of the weight files where
1332/// the used variables, variable transformation type etc. is specified
1333
1334void TMVA::MethodBase::WriteStateToXML( void* parent ) const
1335{
1336 if (!parent) return;
1337
1338 UserGroup_t* userInfo = gSystem->GetUserInfo();
1339
1340 void* gi = gTools().AddChild(parent, "GeneralInfo");
1341 AddInfoItem( gi, "TMVA Release", GetTrainingTMVAVersionString() + " [" + gTools().StringFromInt(GetTrainingTMVAVersionCode()) + "]" );
1342 AddInfoItem( gi, "ROOT Release", GetTrainingROOTVersionString() + " [" + gTools().StringFromInt(GetTrainingROOTVersionCode()) + "]");
1343 AddInfoItem( gi, "Creator", userInfo->fUser);
1344 TDatime dt; AddInfoItem( gi, "Date", dt.AsString());
1345 AddInfoItem( gi, "Host", gSystem->GetBuildNode() );
1346 AddInfoItem( gi, "Dir", gSystem->WorkingDirectory());
1347 AddInfoItem( gi, "Training events", gTools().StringFromInt(Data()->GetNTrainingEvents()));
1348 AddInfoItem( gi, "TrainingTime", gTools().StringFromDouble(const_cast<TMVA::MethodBase*>(this)->GetTrainTime()));
1349
1350 Types::EAnalysisType aType = const_cast<TMVA::MethodBase*>(this)->GetAnalysisType();
1351 TString analysisType((aType==Types::kRegression) ? "Regression" :
1352 (aType==Types::kMulticlass ? "Multiclass" : "Classification"));
1353 AddInfoItem( gi, "AnalysisType", analysisType );
1354 delete userInfo;
1355
1356 // write options
1357 AddOptionsXMLTo( parent );
1358
1359 // write variable info
1360 AddVarsXMLTo( parent );
1361
1362 // write spectator info
1363 if (fModelPersistence)
1364 AddSpectatorsXMLTo( parent );
1365
1366 // write class info if in multiclass mode
1367 AddClassesXMLTo(parent);
1368
1369 // write target info if in regression mode
1370 if (DoRegression()) AddTargetsXMLTo(parent);
1371
1372 // write transformations
1373 GetTransformationHandler(false).AddXMLTo( parent );
1374
1375 // write MVA variable distributions
1376 void* pdfs = gTools().AddChild(parent, "MVAPdfs");
1377 if (fMVAPdfS) fMVAPdfS->AddXMLTo(pdfs);
1378 if (fMVAPdfB) fMVAPdfB->AddXMLTo(pdfs);
1379
1380 // write weights
1381 AddWeightsXMLTo( parent );
1382}
1383
1384////////////////////////////////////////////////////////////////////////////////
1385/// read reference MVA distributions (and other information)
1386/// from a ROOT type weight file
1387
1389{
1390 Bool_t addDirStatus = TH1::AddDirectoryStatus();
1391 TH1::AddDirectory( 0 ); // this avoids the binding of the hists in PDF to the current ROOT file
1392 fMVAPdfS = (TMVA::PDF*)rf.Get( "MVA_PDF_Signal" );
1393 fMVAPdfB = (TMVA::PDF*)rf.Get( "MVA_PDF_Background" );
1394
1395 TH1::AddDirectory( addDirStatus );
1396
1397 ReadWeightsFromStream( rf );
1398
1399 SetTestvarName();
1400}
1401
1402////////////////////////////////////////////////////////////////////////////////
1403/// write options and weights to file
1404/// note that only the XML weight file is created; its name is derived
1405/// from GetWeightFileName() by replacing ".txt" with ".xml"
1406
1408{
1409 // ---- create the text file
1410 TString tfname( GetWeightFileName() );
1411
1412 // writing xml file
1413 TString xmlfname( tfname ); xmlfname.ReplaceAll( ".txt", ".xml" );
1414 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1415 << "Creating xml weight file: "
1416 << gTools().Color("lightblue") << xmlfname << gTools().Color("reset") << Endl;
1417 void* doc = gTools().xmlengine().NewDoc();
1418 void* rootnode = gTools().AddChild(0,"MethodSetup", "", true);
1419 gTools().xmlengine().DocSetRootElement(doc,rootnode);
1420 gTools().AddAttr(rootnode,"Method", GetMethodTypeName() + "::" + GetMethodName());
1421 WriteStateToXML(rootnode);
1422 gTools().xmlengine().SaveDoc(doc,xmlfname);
1423 gTools().xmlengine().FreeDoc(doc);
1424}
1425
1426////////////////////////////////////////////////////////////////////////////////
1427/// Function to read options and weights from file
1428
1430{
1431 // get the filename
1432
1433 TString tfname(GetWeightFileName());
1434
1435 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1436 << "Reading weight file: "
1437 << gTools().Color("lightblue") << tfname << gTools().Color("reset") << Endl;
1438
1439 if (tfname.EndsWith(".xml") ) {
1440 void* doc = gTools().xmlengine().ParseFile(tfname,gTools().xmlenginebuffersize()); // the default buffer size in TXMLEngine::ParseFile is 100k. Starting with ROOT 5.29 one can set the buffer size, see: http://savannah.cern.ch/bugs/?78864. This might be necessary for large XML files
1441 if (!doc) {
1442 Log() << kFATAL << "Error parsing XML file " << tfname << Endl;
1443 }
1444 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1445 ReadStateFromXML(rootnode);
1446 gTools().xmlengine().FreeDoc(doc);
1447 }
1448 else {
1449 std::filebuf fb;
1450 fb.open(tfname.Data(),std::ios::in);
1451 if (!fb.is_open()) { // file not found --> Error
1452 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ReadStateFromFile> "
1453 << "Unable to open input weight file: " << tfname << Endl;
1454 }
1455 std::istream fin(&fb);
1456 ReadStateFromStream(fin);
1457 fb.close();
1458 }
1459 if (!fTxtWeightsOnly) {
1460 // ---- read the ROOT file
1461 TString rfname( tfname ); rfname.ReplaceAll( ".txt", ".root" );
1462 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Reading root weight file: "
1463 << gTools().Color("lightblue") << rfname << gTools().Color("reset") << Endl;
1464 TFile* rfile = TFile::Open( rfname, "READ" );
1465 ReadStateFromStream( *rfile );
1466 rfile->Close();
1467 }
1468}
1469////////////////////////////////////////////////////////////////////////////////
1470/// for reading from memory
1471
1473 void* doc = gTools().xmlengine().ParseString(xmlstr);
1474 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1475 ReadStateFromXML(rootnode);
1476 gTools().xmlengine().FreeDoc(doc);
1477
1478 return;
1479}
1480
1481////////////////////////////////////////////////////////////////////////////////
1482
1484{
1485
1486 TString fullMethodName;
1487 gTools().ReadAttr( methodNode, "Method", fullMethodName );
1488
1489 fMethodName = fullMethodName(fullMethodName.Index("::")+2,fullMethodName.Length());
1490
1491 // update logger
1492 Log().SetSource( GetName() );
1493 Log() << kDEBUG//<<Form("Dataset[%s] : ",DataInfo().GetName())
1494 << "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1495
1496 // after the method name is read, the testvar can be set
1497 SetTestvarName();
1498
1499 TString nodeName("");
1500 void* ch = gTools().GetChild(methodNode);
1501 while (ch!=0) {
1502 nodeName = TString( gTools().GetName(ch) );
1503
1504 if (nodeName=="GeneralInfo") {
1505 // read analysis type
1506
1507 TString name(""),val("");
1508 void* antypeNode = gTools().GetChild(ch);
1509 while (antypeNode) {
1510 gTools().ReadAttr( antypeNode, "name", name );
1511
1512 if (name == "TrainingTime")
1513 gTools().ReadAttr( antypeNode, "value", fTrainTime );
1514
1515 if (name == "AnalysisType") {
1516 gTools().ReadAttr( antypeNode, "value", val );
1517 val.ToLower();
1518 if (val == "regression" ) SetAnalysisType( Types::kRegression );
1519 else if (val == "classification" ) SetAnalysisType( Types::kClassification );
1520 else if (val == "multiclass" ) SetAnalysisType( Types::kMulticlass );
1521 else Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Analysis type " << val << " is not known." << Endl;
1522 }
1523
1524 if (name == "TMVA Release" || name == "TMVA") {
1525 TString s;
1526 gTools().ReadAttr( antypeNode, "value", s);
1527 fTMVATrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1528 Log() << kDEBUG <<Form("[%s] : ",DataInfo().GetName()) << "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
1529 }
1530
1531 if (name == "ROOT Release" || name == "ROOT") {
1532 TString s;
1533 gTools().ReadAttr( antypeNode, "value", s);
1534 fROOTTrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1535 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
1536 << "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
1537 }
1538 antypeNode = gTools().GetNextChild(antypeNode);
1539 }
1540 }
1541 else if (nodeName=="Options") {
1542 ReadOptionsFromXML(ch);
1543 ParseOptions();
1544
1545 }
1546 else if (nodeName=="Variables") {
1547 ReadVariablesFromXML(ch);
1548 }
1549 else if (nodeName=="Spectators") {
1550 ReadSpectatorsFromXML(ch);
1551 }
1552 else if (nodeName=="Classes") {
1553 if (DataInfo().GetNClasses()==0) ReadClassesFromXML(ch);
1554 }
1555 else if (nodeName=="Targets") {
1556 if (DataInfo().GetNTargets()==0 && DoRegression()) ReadTargetsFromXML(ch);
1557 }
1558 else if (nodeName=="Transformations") {
1559 GetTransformationHandler().ReadFromXML(ch);
1560 }
1561 else if (nodeName=="MVAPdfs") {
1562 TString pdfname;
1563 if (fMVAPdfS) { delete fMVAPdfS; fMVAPdfS=0; }
1564 if (fMVAPdfB) { delete fMVAPdfB; fMVAPdfB=0; }
1565 void* pdfnode = gTools().GetChild(ch);
1566 if (pdfnode) {
1567 gTools().ReadAttr(pdfnode, "Name", pdfname);
1568 fMVAPdfS = new PDF(pdfname);
1569 fMVAPdfS->ReadXML(pdfnode);
1570 pdfnode = gTools().GetNextChild(pdfnode);
1571 gTools().ReadAttr(pdfnode, "Name", pdfname);
1572 fMVAPdfB = new PDF(pdfname);
1573 fMVAPdfB->ReadXML(pdfnode);
1574 }
1575 }
1576 else if (nodeName=="Weights") {
1577 ReadWeightsFromXML(ch);
1578 }
1579 else {
1580 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Unparsed XML node: '" << nodeName << "'" << Endl;
1581 }
1582 ch = gTools().GetNextChild(ch);
1583
1584 }
1585
1586 // update transformation handler
1587 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1588}
1589
1590////////////////////////////////////////////////////////////////////////////////
1591/// read the header from the weight files of the different MVA methods
1592
1594{
1595 char buf[512];
1596
1597 // when reading from stream, we assume the files are produced with TMVA<=397
1598 SetAnalysisType(Types::kClassification);
1599
1600
1601 // first read the method name
1602 GetLine(fin,buf);
1603 while (!TString(buf).BeginsWith("Method")) GetLine(fin,buf);
1604 TString namestr(buf);
1605
1606 TString methodType = namestr(0,namestr.Index("::"));
1607 methodType = methodType(methodType.Last(' '),methodType.Length());
1608 methodType = methodType.Strip(TString::kLeading);
1609
1610 TString methodName = namestr(namestr.Index("::")+2,namestr.Length());
1611 methodName = methodName.Strip(TString::kLeading);
1612 if (methodName == "") methodName = methodType;
1613 fMethodName = methodName;
1614
1615 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1616
1617 // update logger
1618 Log().SetSource( GetName() );
1619
1620 // now the question is whether to read the variables first or the options (well, of course the order
1621 // of writing them needs to agree)
1622 //
1623 // the option "Decorrelation" is needed to decide if the variables we
1624 // read are decorrelated or not
1625 //
1626 // the variables are needed by some methods (TMLP) to build the NN
1627 // which is done in ProcessOptions so for the time being we first Read and Parse the options then
1628 // we read the variables, and then we process the options
1629
1630 // now read all options
1631 GetLine(fin,buf);
1632 while (!TString(buf).BeginsWith("#OPT")) GetLine(fin,buf);
1633 ReadOptionsFromStream(fin);
1634 ParseOptions();
1635
1636 // Now read variable info
1637 fin.getline(buf,512);
1638 while (!TString(buf).BeginsWith("#VAR")) fin.getline(buf,512);
1639 ReadVarsFromStream(fin);
1640
1641 // now we process the options (of the derived class)
1642 ProcessOptions();
1643
1644 if (IsNormalised()) {
1646 GetTransformationHandler().AddTransformation( new VariableNormalizeTransform(DataInfo()), -1 );
1647 norm->BuildTransformationFromVarInfo( DataInfo().GetVariableInfos() );
1648 }
1649 VariableTransformBase *varTrafo(0), *varTrafo2(0);
1650 if ( fVarTransformString == "None") {
1651 if (fUseDecorr)
1652 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1653 } else if ( fVarTransformString == "Decorrelate" ) {
1654 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1655 } else if ( fVarTransformString == "PCA" ) {
1656 varTrafo = GetTransformationHandler().AddTransformation( new VariablePCATransform(DataInfo()), -1 );
1657 } else if ( fVarTransformString == "Uniform" ) {
1658 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo(),"Uniform"), -1 );
1659 } else if ( fVarTransformString == "Gauss" ) {
1660 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1661 } else if ( fVarTransformString == "GaussDecorr" ) {
1662 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1663 varTrafo2 = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1664 } else {
1665 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ProcessOptions> Variable transform '"
1666 << fVarTransformString << "' unknown." << Endl;
1667 }
1668 // Now read decorrelation matrix if available
1669 if (GetTransformationHandler().GetTransformationList().GetSize() > 0) {
1670 fin.getline(buf,512);
1671 while (!TString(buf).BeginsWith("#MAT")) fin.getline(buf,512);
1672 if (varTrafo) {
1673 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1674 varTrafo->ReadTransformationFromStream(fin, trafo );
1675 }
1676 if (varTrafo2) {
1677 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1678 varTrafo2->ReadTransformationFromStream(fin, trafo );
1679 }
1680 }
1681
1682
1683 if (HasMVAPdfs()) {
1684 // Now read the MVA PDFs
1685 fin.getline(buf,512);
1686 while (!TString(buf).BeginsWith("#MVAPDFS")) fin.getline(buf,512);
1687 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
1688 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
1689 fMVAPdfS = new PDF(TString(GetName()) + " MVA PDF Sig");
1690 fMVAPdfB = new PDF(TString(GetName()) + " MVA PDF Bkg");
1691 fMVAPdfS->SetReadingVersion( GetTrainingTMVAVersionCode() );
1692 fMVAPdfB->SetReadingVersion( GetTrainingTMVAVersionCode() );
1693
1694 fin >> *fMVAPdfS;
1695 fin >> *fMVAPdfB;
1696 }
1697
1698 // Now read weights
1699 fin.getline(buf,512);
1700 while (!TString(buf).BeginsWith("#WGT")) fin.getline(buf,512);
1701 fin.getline(buf,512);
1702 ReadWeightsFromStream( fin );;
1703
1704 // update transformation handler
1705 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1706
1707}
1708
1709////////////////////////////////////////////////////////////////////////////////
1710/// write the list of variables (name, min, max) for a given data
1711/// transformation method to the stream
1712
1713void TMVA::MethodBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const
1714{
1715 o << prefix << "NVar " << DataInfo().GetNVariables() << std::endl;
1716 std::vector<VariableInfo>::const_iterator varIt = DataInfo().GetVariableInfos().begin();
1717 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1718 o << prefix << "NSpec " << DataInfo().GetNSpectators() << std::endl;
1719 varIt = DataInfo().GetSpectatorInfos().begin();
1720 for (; varIt!=DataInfo().GetSpectatorInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1721}
1722
1723////////////////////////////////////////////////////////////////////////////////
1724/// Read the variables (name, min, max) for a given data
1725/// transformation method from the stream. In the stream we only
1726/// expect the limits which will be set
1727
1729{
1730 TString dummy;
1731 UInt_t readNVar;
1732 istr >> dummy >> readNVar;
1733
1734 if (readNVar!=DataInfo().GetNVariables()) {
1735 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1736 << " while there are " << readNVar << " variables declared in the file"
1737 << Endl;
1738 }
1739
1740 // we want to make sure all variables are read in the order they are defined
1741 VariableInfo varInfo;
1742 std::vector<VariableInfo>::iterator varIt = DataInfo().GetVariableInfos().begin();
1743 int varIdx = 0;
1744 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt, ++varIdx) {
1745 varInfo.ReadFromStream(istr);
1746 if (varIt->GetExpression() == varInfo.GetExpression()) {
1747 varInfo.SetExternalLink((*varIt).GetExternalLink());
1748 (*varIt) = varInfo;
1749 }
1750 else {
1751 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVarsFromStream>" << Endl;
1752 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1753 Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
1754 Log() << kINFO << "the correct working of the method):" << Endl;
1755 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
1756 Log() << kINFO << " var #" << varIdx <<" declared in file : " << varInfo.GetExpression() << Endl;
1757 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1758 }
1759 }
1760}
1761
1762////////////////////////////////////////////////////////////////////////////////
1763/// write variable info to XML
1764
1765void TMVA::MethodBase::AddVarsXMLTo( void* parent ) const
1766{
1767 void* vars = gTools().AddChild(parent, "Variables");
1768 gTools().AddAttr( vars, "NVar", gTools().StringFromInt(DataInfo().GetNVariables()) );
1769
1770 for (UInt_t idx=0; idx<DataInfo().GetVariableInfos().size(); idx++) {
1771 VariableInfo& vi = DataInfo().GetVariableInfos()[idx];
1772 void* var = gTools().AddChild( vars, "Variable" );
1773 gTools().AddAttr( var, "VarIndex", idx );
1774 vi.AddToXML( var );
1775 }
1776}
1777
1778////////////////////////////////////////////////////////////////////////////////
1779/// write spectator info to XML
1780
1782{
1783 void* specs = gTools().AddChild(parent, "Spectators");
1784
1785 UInt_t writeIdx=0;
1786 for (UInt_t idx=0; idx<DataInfo().GetSpectatorInfos().size(); idx++) {
1787
1788 VariableInfo& vi = DataInfo().GetSpectatorInfos()[idx];
1789
1790 // we do not want to write spectators that are category-cuts,
1791 // except if the method is the category method and the spectators belong to it
1792 if (vi.GetVarType()=='C') continue;
1793
1794 void* spec = gTools().AddChild( specs, "Spectator" );
1795 gTools().AddAttr( spec, "SpecIndex", writeIdx++ );
1796 vi.AddToXML( spec );
1797 }
1798 gTools().AddAttr( specs, "NSpec", gTools().StringFromInt(writeIdx) );
1799}
1800
1801////////////////////////////////////////////////////////////////////////////////
1802/// write class info to XML
1803
1804void TMVA::MethodBase::AddClassesXMLTo( void* parent ) const
1805{
1806 UInt_t nClasses=DataInfo().GetNClasses();
1807
1808 void* classes = gTools().AddChild(parent, "Classes");
1809 gTools().AddAttr( classes, "NClass", nClasses );
1810
1811 for (UInt_t iCls=0; iCls<nClasses; ++iCls) {
1812 ClassInfo *classInfo=DataInfo().GetClassInfo (iCls);
1813 TString className =classInfo->GetName();
1814 UInt_t classNumber=classInfo->GetNumber();
1815
1816 void* classNode=gTools().AddChild(classes, "Class");
1817 gTools().AddAttr( classNode, "Name", className );
1818 gTools().AddAttr( classNode, "Index", classNumber );
1819 }
1820}
1821////////////////////////////////////////////////////////////////////////////////
1822/// write target info to XML
1823
1824void TMVA::MethodBase::AddTargetsXMLTo( void* parent ) const
1825{
1826 void* targets = gTools().AddChild(parent, "Targets");
1827 gTools().AddAttr( targets, "NTrgt", gTools().StringFromInt(DataInfo().GetNTargets()) );
1828
1829 for (UInt_t idx=0; idx<DataInfo().GetTargetInfos().size(); idx++) {
1830 VariableInfo& vi = DataInfo().GetTargetInfos()[idx];
1831 void* tar = gTools().AddChild( targets, "Target" );
1832 gTools().AddAttr( tar, "TargetIndex", idx );
1833 vi.AddToXML( tar );
1834 }
1835}
1836
1837////////////////////////////////////////////////////////////////////////////////
1838/// read variable info from XML
1839
1841{
1842 UInt_t readNVar;
1843 gTools().ReadAttr( varnode, "NVar", readNVar);
1844
1845 if (readNVar!=DataInfo().GetNVariables()) {
1846 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1847 << " while there are " << readNVar << " variables declared in the file"
1848 << Endl;
1849 }
1850
1851 // we want to make sure all variables are read in the order they are defined
1852 VariableInfo readVarInfo, existingVarInfo;
1853 int varIdx = 0;
1854 void* ch = gTools().GetChild(varnode);
1855 while (ch) {
1856 gTools().ReadAttr( ch, "VarIndex", varIdx);
1857 existingVarInfo = DataInfo().GetVariableInfos()[varIdx];
1858 readVarInfo.ReadFromXML(ch);
1859
1860 if (existingVarInfo.GetExpression() == readVarInfo.GetExpression()) {
1861 readVarInfo.SetExternalLink(existingVarInfo.GetExternalLink());
1862 existingVarInfo = readVarInfo;
1863 }
1864 else {
1865 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVariablesFromXML>" << Endl;
1866 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1867 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1868 Log() << kINFO << "correct working of the method):" << Endl;
1869 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << existingVarInfo.GetExpression() << Endl;
1870 Log() << kINFO << " var #" << varIdx <<" declared in file : " << readVarInfo.GetExpression() << Endl;
1871 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1872 }
1873 ch = gTools().GetNextChild(ch);
1874 }
1875}
1876
1877////////////////////////////////////////////////////////////////////////////////
1878/// read spectator info from XML
1879
1881{
1882 UInt_t readNSpec;
1883 gTools().ReadAttr( specnode, "NSpec", readNSpec);
1884
1885 if (readNSpec!=DataInfo().GetNSpectators(kFALSE)) {
1886 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "You declared "<< DataInfo().GetNSpectators(kFALSE) << " spectators in the Reader"
1887 << " while there are " << readNSpec << " spectators declared in the file"
1888 << Endl;
1889 }
1890
1891 // we want to make sure all variables are read in the order they are defined
1892 VariableInfo readSpecInfo, existingSpecInfo;
1893 int specIdx = 0;
1894 void* ch = gTools().GetChild(specnode);
1895 while (ch) {
1896 gTools().ReadAttr( ch, "SpecIndex", specIdx);
1897 existingSpecInfo = DataInfo().GetSpectatorInfos()[specIdx];
1898 readSpecInfo.ReadFromXML(ch);
1899
1900 if (existingSpecInfo.GetExpression() == readSpecInfo.GetExpression()) {
1901 readSpecInfo.SetExternalLink(existingSpecInfo.GetExternalLink());
1902 existingSpecInfo = readSpecInfo;
1903 }
1904 else {
1905 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadSpectatorsFromXML>" << Endl;
1906 Log() << kINFO << "The definition (or the order) of the spectators found in the input file is" << Endl;
1907 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1908 Log() << kINFO << "correct working of the method):" << Endl;
1909 Log() << kINFO << " spec #" << specIdx <<" declared in Reader: " << existingSpecInfo.GetExpression() << Endl;
1910 Log() << kINFO << " spec #" << specIdx <<" declared in file : " << readSpecInfo.GetExpression() << Endl;
1911 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1912 }
1913 ch = gTools().GetNextChild(ch);
1914 }
1915}
1916
1917////////////////////////////////////////////////////////////////////////////////
1918/// read number of classes from XML
1919
1921{
1922 UInt_t readNCls;
1923 // coverity[tainted_data_argument]
1924 gTools().ReadAttr( clsnode, "NClass", readNCls);
1925
1926 TString className="";
1927 UInt_t classIndex=0;
1928 void* ch = gTools().GetChild(clsnode);
1929 if (!ch) {
1930 for (UInt_t icls = 0; icls<readNCls;++icls) {
1931 TString classname = Form("class%i",icls);
1932 DataInfo().AddClass(classname);
1933
1934 }
1935 }
1936 else{
1937 while (ch) {
1938 gTools().ReadAttr( ch, "Index", classIndex);
1939 gTools().ReadAttr( ch, "Name", className );
1940 DataInfo().AddClass(className);
1941
1942 ch = gTools().GetNextChild(ch);
1943 }
1944 }
1945
1946 // retrieve signal and background class index
1947 if (DataInfo().GetClassInfo("Signal") != 0) {
1948 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
1949 }
1950 else
1951 fSignalClass=0;
1952 if (DataInfo().GetClassInfo("Background") != 0) {
1953 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
1954 }
1955 else
1956 fBackgroundClass=1;
1957}
1958
1959////////////////////////////////////////////////////////////////////////////////
1960/// read target info from XML
1961
1963{
1964 UInt_t readNTar;
1965 gTools().ReadAttr( tarnode, "NTrgt", readNTar);
1966
1967 int tarIdx = 0;
1968 TString expression;
1969 void* ch = gTools().GetChild(tarnode);
1970 while (ch) {
1971 gTools().ReadAttr( ch, "TargetIndex", tarIdx);
1972 gTools().ReadAttr( ch, "Expression", expression);
1973 DataInfo().AddTarget(expression,"","",0,0);
1974
1975 ch = gTools().GetNextChild(ch);
1976 }
1977}
1978
1979////////////////////////////////////////////////////////////////////////////////
1980/// returns the ROOT directory where info/histograms etc of the
1981/// corresponding MVA method instance are stored
1982
1984{
1985 if (fBaseDir != 0) return fBaseDir;
1986 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodName() << " not set yet --> check if already there.." <<Endl;
1987
1988 if (IsSilentFile()) {
1989 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
1990 << "MethodBase::BaseDir() - No directory exists when running a Method without output file. Enable the "
1991 "output when creating the factory"
1992 << Endl;
1993 }
1994
1995 TDirectory* methodDir = MethodBaseDir();
1996 if (methodDir==0)
1997 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MethodBase::BaseDir() - MethodBaseDir() return a NULL pointer!" << Endl;
1998
1999 TString defaultDir = GetMethodName();
2000 TDirectory *sdir = methodDir->GetDirectory(defaultDir.Data());
2001 if(!sdir)
2002 {
2003 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " does not exist yet--> created it" <<Endl;
2004 sdir = methodDir->mkdir(defaultDir);
2005 sdir->cd();
2006 // write weight file name into target file
2007 if (fModelPersistence) {
2008 TObjString wfilePath( gSystem->WorkingDirectory() );
2009 TObjString wfileName( GetWeightFileName() );
2010 wfilePath.Write( "TrainingPath" );
2011 wfileName.Write( "WeightFileName" );
2012 }
2013 }
2014
2015 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " existed, return it.." <<Endl;
2016 return sdir;
2017}
2018
2019////////////////////////////////////////////////////////////////////////////////
2020/// returns the ROOT directory where all instances of the
2021/// corresponding MVA method are stored
2022
2024{
2025 if (fMethodBaseDir != 0) {
2026 return fMethodBaseDir;
2027 }
2028
2029 const char *datasetName = DataInfo().GetName();
2030
2031 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodTypeName()
2032 << " not set yet --> check if already there.." << Endl;
2033
2034 TDirectory *factoryBaseDir = GetFile();
2035 if (!factoryBaseDir) return nullptr;
2036 fMethodBaseDir = factoryBaseDir->GetDirectory(datasetName);
2037 if (!fMethodBaseDir) {
2038 fMethodBaseDir = factoryBaseDir->mkdir(datasetName, Form("Base directory for dataset %s", datasetName));
2039 if (!fMethodBaseDir) {
2040 Log() << kFATAL << "Can not create dir " << datasetName;
2041 }
2042 }
2043 TString methodTypeDir = Form("Method_%s", GetMethodTypeName().Data());
2044 fMethodBaseDir = fMethodBaseDir->GetDirectory(methodTypeDir.Data());
2045
2046 if (!fMethodBaseDir) {
2047 TDirectory *datasetDir = factoryBaseDir->GetDirectory(datasetName);
2048 TString methodTypeDirHelpStr = Form("Directory for all %s methods", GetMethodTypeName().Data());
2049 fMethodBaseDir = datasetDir->mkdir(methodTypeDir.Data(), methodTypeDirHelpStr);
2050 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodName()
2051 << " does not exist yet--> created it" << Endl;
2052 }
2053
2054 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName)
2055 << "Return from MethodBaseDir() after creating base directory " << Endl;
2056 return fMethodBaseDir;
2057}
2058
2059////////////////////////////////////////////////////////////////////////////////
2060/// set directory of weight file
2061
2063{
2064 fFileDir = fileDir;
2065 gSystem->mkdir( fFileDir, kTRUE );
2066}
2067
2068////////////////////////////////////////////////////////////////////////////////
2069/// Set the weight file name explicitly (deprecated).
/// A non-empty name stored here is returned unchanged by GetWeightFileName(),
/// instead of the default name built from the job and method names.
2070
2072{
2073 fWeightFile = theWeightFile;
2074}
2075
2076////////////////////////////////////////////////////////////////////////////////
2077/// retrieve weight file name
2078
2080{
2081 if (fWeightFile!="") return fWeightFile;
2082
2083 // the default consists of
2084 // directory/jobname_methodname_suffix.extension.{root/txt}
2085 TString suffix = "";
2086 TString wFileDir(GetWeightFileDir());
2087 TString wFileName = GetJobName() + "_" + GetMethodName() +
2088 suffix + "." + gConfig().GetIONames().fWeightFileExtension + ".xml";
2089 if (wFileDir.IsNull() ) return wFileName;
2090 // add weight file directory of it is not null
2091 return ( wFileDir + (wFileDir[wFileDir.Length()-1]=='/' ? "" : "/")
2092 + wFileName );
2093}
2094////////////////////////////////////////////////////////////////////////////////
2095/// writes all MVA evaluation histograms to file
2096
2098{
2099 BaseDir()->cd();
2100
2101
2102 // write MVA PDFs to file - if exist
2103 if (0 != fMVAPdfS) {
2104 fMVAPdfS->GetOriginalHist()->Write();
2105 fMVAPdfS->GetSmoothedHist()->Write();
2106 fMVAPdfS->GetPDFHist()->Write();
2107 }
2108 if (0 != fMVAPdfB) {
2109 fMVAPdfB->GetOriginalHist()->Write();
2110 fMVAPdfB->GetSmoothedHist()->Write();
2111 fMVAPdfB->GetPDFHist()->Write();
2112 }
2113
2114 // write result-histograms
2115 Results* results = Data()->GetResults( GetMethodName(), treetype, Types::kMaxAnalysisType );
2116 if (!results)
2117 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<WriteEvaluationHistosToFile> Unknown result: "
2118 << GetMethodName() << (treetype==Types::kTraining?"/kTraining":"/kTesting")
2119 << "/kMaxAnalysisType" << Endl;
2120 results->GetStorage()->Write();
2121 if (treetype==Types::kTesting) {
2122 // skipping plotting of variables if too many (default is 200)
2123 if ((int) DataInfo().GetNVariables()< gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables)
2124 GetTransformationHandler().PlotVariables (GetEventCollection( Types::kTesting ), BaseDir() );
2125 else
2126 Log() << kINFO << TString::Format("Dataset[%s] : ",DataInfo().GetName())
2127 << " variable plots are not produces ! The number of variables is " << DataInfo().GetNVariables()
2128 << " , it is larger than " << gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables << Endl;
2129 }
2130}
2131
2132////////////////////////////////////////////////////////////////////////////////
2133/// Write method-specific monitoring histograms to file.
2134/// Dummy implementation: the base class writes nothing.
/// NOTE(review): methods with monitoring output presumably provide their own
/// implementation of this function — confirm in MethodBase.h.
2135
2137{
2138}
2139
2140////////////////////////////////////////////////////////////////////////////////
2141/// reads one line from the input stream
2142/// checks for certain keywords and interprets
2143/// the line if keywords are found
2144
2145Bool_t TMVA::MethodBase::GetLine(std::istream& fin, char* buf )
2146{
2147 fin.getline(buf,512);
2148 TString line(buf);
2149 if (line.BeginsWith("TMVA Release")) {
2150 Ssiz_t start = line.First('[')+1;
2151 Ssiz_t length = line.Index("]",start)-start;
2152 TString code = line(start,length);
2153 std::stringstream s(code.Data());
2154 s >> fTMVATrainingVersion;
2155 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
2156 }
2157 if (line.BeginsWith("ROOT Release")) {
2158 Ssiz_t start = line.First('[')+1;
2159 Ssiz_t length = line.Index("]",start)-start;
2160 TString code = line(start,length);
2161 std::stringstream s(code.Data());
2162 s >> fROOTTrainingVersion;
2163 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
2164 }
2165 if (line.BeginsWith("Analysis type")) {
2166 Ssiz_t start = line.First('[')+1;
2167 Ssiz_t length = line.Index("]",start)-start;
2168 TString code = line(start,length);
2169 std::stringstream s(code.Data());
2170 std::string analysisType;
2171 s >> analysisType;
2172 if (analysisType == "regression" || analysisType == "Regression") SetAnalysisType( Types::kRegression );
2173 else if (analysisType == "classification" || analysisType == "Classification") SetAnalysisType( Types::kClassification );
2174 else if (analysisType == "multiclass" || analysisType == "Multiclass") SetAnalysisType( Types::kMulticlass );
2175 else Log() << kFATAL << "Analysis type " << analysisType << " from weight-file not known!" << std::endl;
2176
2177 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Method was trained for "
2178 << (GetAnalysisType() == Types::kRegression ? "Regression" :
2179 (GetAnalysisType() == Types::kMulticlass ? "Multiclass" : "Classification")) << Endl;
2180 }
2181
2182 return true;
2183}
2184
2185////////////////////////////////////////////////////////////////////////////////
2186/// Create PDFs of the MVA output variables
2187
2189{
2190 Data()->SetCurrentType(Types::kTraining);
2191
2192 // the PDF's are stored as results ONLY if the corresponding "results" are booked,
2193 // otherwise they will be only used 'online'
2194 ResultsClassification * mvaRes = dynamic_cast<ResultsClassification*>
2195 ( Data()->GetResults(GetMethodName(), Types::kTraining, Types::kClassification) );
2196
2197 if (mvaRes==0 || mvaRes->GetSize()==0) {
2198 Log() << kERROR<<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CreateMVAPdfs> No result of classifier testing available" << Endl;
2199 }
2200
2201 Double_t minVal = *std::min_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2202 Double_t maxVal = *std::max_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2203
2204 // create histograms that serve as basis to create the MVA Pdfs
2205 TH1* histMVAPdfS = new TH1D( GetMethodTypeName() + "_tr_S", GetMethodTypeName() + "_tr_S",
2206 fMVAPdfS->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2207 TH1* histMVAPdfB = new TH1D( GetMethodTypeName() + "_tr_B", GetMethodTypeName() + "_tr_B",
2208 fMVAPdfB->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2209
2210
2211 // compute sum of weights properly
2212 histMVAPdfS->Sumw2();
2213 histMVAPdfB->Sumw2();
2214
2215 // fill histograms
2216 for (UInt_t ievt=0; ievt<mvaRes->GetSize(); ievt++) {
2217 Double_t theVal = mvaRes->GetValueVector()->at(ievt);
2218 Double_t theWeight = Data()->GetEvent(ievt)->GetWeight();
2219
2220 if (DataInfo().IsSignal(Data()->GetEvent(ievt))) histMVAPdfS->Fill( theVal, theWeight );
2221 else histMVAPdfB->Fill( theVal, theWeight );
2222 }
2223
2224 gTools().NormHist( histMVAPdfS );
2225 gTools().NormHist( histMVAPdfB );
2226
2227 // momentary hack for ROOT problem
2228 if(!IsSilentFile())
2229 {
2230 histMVAPdfS->Write();
2231 histMVAPdfB->Write();
2232 }
2233 // create PDFs
2234 fMVAPdfS->BuildPDF ( histMVAPdfS );
2235 fMVAPdfB->BuildPDF ( histMVAPdfB );
2236 fMVAPdfS->ValidatePDF( histMVAPdfS );
2237 fMVAPdfB->ValidatePDF( histMVAPdfB );
2238
2239 if (DataInfo().GetNClasses() == 2) { // TODO: this is an ugly hack.. adapt this to new framework
2240 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())
2241 << Form( "<CreateMVAPdfs> Separation from histogram (PDF): %1.3f (%1.3f)",
2242 GetSeparation( histMVAPdfS, histMVAPdfB ), GetSeparation( fMVAPdfS, fMVAPdfB ) )
2243 << Endl;
2244 }
2245
2246 delete histMVAPdfS;
2247 delete histMVAPdfB;
2248}
2249
// Signal probability of the event `ev`. Its MVA value (GetMvaValue(ev)) is
// converted by the two-argument GetProba, using as prior signal fraction
// sumW(signal) / (sumW(signal) + sumW(background)) of the training sample.
2251 // The prior therefore follows the sig/bkg ratio of the training sample
2252 // (typically 50/50 with NormMode=EqualNumEvents, but it can differ).
2253 // Missing MVA PDFs are created on demand via CreateMVAPdfs().
2254 if (!fMVAPdfS || !fMVAPdfB) {
2255 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<GetProba> MVA PDFs for Signal and Background don't exist yet, we'll create them on demand" << Endl;
// NOTE(review): this call is reached exactly when fMVAPdfS or fMVAPdfB is null,
// so CreateMVAPdfs() must cope with null PDF pointers — confirm.
2256 CreateMVAPdfs();
2257 }
// If both training weight sums are 0, this is 0/0 (NaN). The two-argument
// GetProba then fails its `denom > 0` test and returns -1.
2258 Double_t sigFraction = DataInfo().GetTrainingSumSignalWeights() / (DataInfo().GetTrainingSumSignalWeights() + DataInfo().GetTrainingSumBackgrWeights() );
2259 Double_t mvaVal = GetMvaValue(ev);
2260
2261 return GetProba(mvaVal,sigFraction);
2262
2263}
2264////////////////////////////////////////////////////////////////////////////////
2265/// compute likelihood ratio
2266
2268{
2269 if (!fMVAPdfS || !fMVAPdfB) {
2270 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetProba> MVA PDFs for Signal and Background don't exist" << Endl;
2271 return -1.0;
2272 }
2273 Double_t p_s = fMVAPdfS->GetVal( mvaVal );
2274 Double_t p_b = fMVAPdfB->GetVal( mvaVal );
2275
2276 Double_t denom = p_s*ap_sig + p_b*(1 - ap_sig);
2277
2278 return (denom > 0) ? (p_s*ap_sig) / denom : -1;
2279}
2280
2281////////////////////////////////////////////////////////////////////////////////
2282/// compute rarity:
2283/// \f[
2284/// R(x) = \int_{[-\infty..x]} { PDF(x') dx' }
2285/// \f]
2286/// where PDF(x) is the PDF of the classifier's signal or background distribution
2287
2289{
2290 if ((reftype == Types::kSignal && !fMVAPdfS) || (reftype == Types::kBackground && !fMVAPdfB)) {
2291 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetRarity> Required MVA PDF for Signal or Background does not exist: "
2292 << "select option \"CreateMVAPdfs\"" << Endl;
2293 return 0.0;
2294 }
2295
2296 PDF* thePdf = ((reftype == Types::kSignal) ? fMVAPdfS : fMVAPdfB);
2297
2298 return thePdf->GetIntegral( thePdf->GetXmin(), mvaVal );
2299}
2300
2301////////////////////////////////////////////////////////////////////////////////
2302/// Fill the background efficiency (resp. rejection and inverse efficiency)
2303/// versus signal efficiency plots, and return the figure of merit requested
/// by theString.
///
/// \param theString  parsed with gTools().ParseFormatLine:
///                   - no list, or fewer than 2 entries: return the area under
///                     the background rejection (1 - effB) versus signal
///                     efficiency curve;
///                   - exactly 2 entries (e.g. "Efficiency:0.05"): the second
///                     entry is the reference background efficiency, and the
///                     signal efficiency at that point is returned;
///                   - more than 2 entries: reported with kFATAL.
/// \param type       tree type whose classification results are used
/// \param effSerr    on the reference-efficiency path, set to
///                   sqrt(effS*(1-effS)/nevtS), or 0 if nevtS <= 0
/// \return -1 on an argument or binning error (reported with kFATAL);
///         0 if fSpleffBvsS does not exist; otherwise the area or the
///         signal efficiency described above.
///
/// The efficiency histograms, the splines and the signal reference cut are
/// only built when the Results do not yet contain "MVA_EFF_S". Later calls
/// reuse fSpleffBvsS.
2304
2306{
2307 Data()->SetCurrentType(type);
2308 Results* results = Data()->GetResults( GetMethodName(), type, Types::kClassification );
// NOTE(review): `results` and the dynamic_cast result are dereferenced without a
// null check; this assumes the results exist and are a ResultsClassification.
2309 std::vector<Float_t>* mvaRes = dynamic_cast<ResultsClassification*>(results)->GetValueVector();
2310
2311 // parse input string for required background efficiency
2312 TList* list = gTools().ParseFormatLine( theString );
2313
2314 // sanity check: fewer than 2 entries -> compute the area, more than 2 -> error
2315 Bool_t computeArea = kFALSE;
2316 if (!list || list->GetSize() < 2) computeArea = kTRUE; // the area is computed
2317 else if (list->GetSize() > 2) {
2318 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Wrong number of arguments"
2319 << " in string: " << theString
2320 << " | required format, e.g., Efficiency:0.05, or empty string" << Endl;
2321 delete list;
2322 return -1;
2323 }
2324
2325 // sanity check: signal and background histograms must share the same binning
2326 if ( results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2327 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2328 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Binning mismatch between signal and background histos" << Endl;
2329 delete list;
2330 return -1.0;
2331 }
2332
2333 // create histograms
2334
2335 // first, take the x range of the efficiency histograms from the MVA_HIGHBIN_S histogram
2336 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2337 Double_t xmin = effhist->GetXaxis()->GetXmin();
2338 Double_t xmax = effhist->GetXaxis()->GetXmax();
2339
// NOTE(review): nevtS (declared via TTHREAD_TLS, presumably a thread-local static)
// is only filled in the first-round block below and is shared by all calls. A
// later call may therefore read a value filled by an earlier first round
// (possibly of another method instance) — confirm this is acceptable.
2340 TTHREAD_TLS(Double_t) nevtS;
2341
2342 // first round ? --> create histograms
2343 if (results->DoesExist("MVA_EFF_S")==0) {
2344
2345 // for efficiency plot
2346 TH1* eff_s = new TH1D( GetTestvarName() + "_effS", GetTestvarName() + " (signal)", fNbinsH, xmin, xmax );
2347 TH1* eff_b = new TH1D( GetTestvarName() + "_effB", GetTestvarName() + " (background)", fNbinsH, xmin, xmax );
2348 results->Store(eff_s, "MVA_EFF_S");
2349 results->Store(eff_b, "MVA_EFF_B");
2350
2351 // sign of the cut: +1 if fCutOrientation is kPositive, -1 otherwise
2352 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2353
2354 // this method is unbinned
2355 nevtS = 0;
// NOTE(review): assumes mvaRes holds at least Data()->GetNEvents() values.
2356 for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2357
2358 // read the tree
2359 Bool_t isSignal = DataInfo().IsSignal(GetEvent(ievt));
2360 Float_t theWeight = GetEvent(ievt)->GetWeight();
2361 Float_t theVal = (*mvaRes)[ievt];
2362
2363 // select histogram depending on if sig or bgd
2364 TH1* theHist = isSignal ? eff_s : eff_b;
2365
2366 // sum of signal event weights (used for the efficiency error below)
2367 if (isSignal) nevtS+=theWeight;
2368
2369 TAxis* axis = theHist->GetXaxis();
// 1-based bin of theVal among the fNbinsH bins. The weight is then added to bins
// 1..maxbin (sign > 0) or maxbin+1..fNbinsH (sign < 0), which makes the
// histograms cumulative in the cut value.
2370 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2371 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2372 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2373 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2374 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2375
2376 if (sign > 0)
2377 for (Int_t ibin=1; ibin<=maxbin; ibin++) theHist->AddBinContent( ibin , theWeight);
2378 else if (sign < 0)
2379 for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theHist->AddBinContent( ibin , theWeight );
2380 else
2381 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Mismatch in sign" << Endl;
2382 }
2383
2384 // renormalise maximum to <=1
2385 // eff_s->Scale( 1.0/TMath::Max(1.,eff_s->GetMaximum()) );
2386 // eff_b->Scale( 1.0/TMath::Max(1.,eff_b->GetMaximum()) );
2387
// NOTE(review): original lines 2388-2389 are missing from this listing; any
// active normalisation of eff_s / eff_b is not visible here.
2390
2391 // background efficiency versus signal efficiency
2392 TH1* eff_BvsS = new TH1D( GetTestvarName() + "_effBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2393 results->Store(eff_BvsS, "MVA_EFF_BvsS");
2394 eff_BvsS->SetXTitle( "Signal eff" );
2395 eff_BvsS->SetYTitle( "Backgr eff" );
2396
2397 // background rejection (=1-eff.) versus signal efficiency
2398 TH1* rej_BvsS = new TH1D( GetTestvarName() + "_rejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2399 results->Store(rej_BvsS);
2400 rej_BvsS->SetXTitle( "Signal eff" );
2401 rej_BvsS->SetYTitle( "Backgr rejection (1-eff)" );
2402
2403 // inverse background eff (1/eff.) versus signal efficiency
2404 TH1* inveff_BvsS = new TH1D( GetTestvarName() + "_invBeffvsSeff",
2405 GetTestvarName(), fNbins, 0, 1 );
2406 results->Store(inveff_BvsS);
2407 inveff_BvsS->SetXTitle( "Signal eff" );
2408 inveff_BvsS->SetYTitle( "Inverse backgr. eff (1/eff)" );
2409
2410 // use root finder
2411 // spline background efficiency plot
2412 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
// NOTE(review): original line 2413 is missing from this listing; the closing
// brace at 2420 presumably ends a block that was opened on that line.
2414 fSplRefS = new TSpline1( "spline2_signal", new TGraph( eff_s ) );
2415 fSplRefB = new TSpline1( "spline2_background", new TGraph( eff_b ) );
2416
2417 // verify spline sanity
2418 gTools().CheckSplines( eff_s, fSplRefS );
2419 gTools().CheckSplines( eff_b, fSplRefB );
2420 }
2421
2422 // make the background-vs-signal efficiency plot
2423
2424 // create root finder over [fXmin, fXmax]; it presumably inverts the signal efficiency held in fEffS
2425 RootFinder rootFinder( this, fXmin, fXmax );
2426
2427 Double_t effB = 0;
2428 fEffS = eff_s; // to be set for the root finder
2429 for (Int_t bini=1; bini<=fNbins; bini++) {
2430
2431 // find cut value corresponding to a given signal efficiency
2432 Double_t effS = eff_BvsS->GetBinCenter( bini );
2433 Double_t cut = rootFinder.Root( effS );
2434
2435 // retrieve background efficiency for given cut
2436 if (Use_Splines_for_Eff_) effB = fSplRefB->Eval( cut );
2437 else effB = eff_b->GetBinContent( eff_b->FindBin( cut ) );
2438
2439 // and fill histograms
2440 eff_BvsS->SetBinContent( bini, effB );
2441 rej_BvsS->SetBinContent( bini, 1.0-effB );
// NOTE(review): original line 2442 is missing from this listing; as visible,
// 1.0/effB below is not protected against effB == 0.
2443 inveff_BvsS->SetBinContent( bini, 1.0/effB );
2444 }
2445
2446 // create splines for histogram
2447 fSpleffBvsS = new TSpline1( "effBvsS", new TGraph( eff_BvsS ) );
2448
2449 // search for overlap point where, when cutting on it,
2450 // one would obtain: eff_S = rej_B = 1 - eff_B
2451 Double_t effS = 0., rejB, effS_ = 0., rejB_ = 0.;
2452 Int_t nbins_ = 5000;
2453 for (Int_t bini=1; bini<=nbins_; bini++) {
2454
2455 // get corresponding signal and background efficiencies
2456 effS = (bini - 0.5)/Float_t(nbins_);
2457 rejB = 1.0 - fSpleffBvsS->Eval( effS );
2458
2459 // stop at the first sign change of (effS - rejB), i.e. where effS crosses rejB
2460 if ((effS - rejB)*(effS_ - rejB_) < 0) break;
2461 effS_ = effS;
2462 rejB_ = rejB;
2463 }
2464
2465 // find cut that corresponds to signal efficiency and update signal-like criterion
2466 Double_t cut = rootFinder.Root( 0.5*(effS + effS_) );
2467 SetSignalReferenceCut( cut );
2468 fEffS = 0;
2469 }
2470
2471 // must exist...
2472 if (0 == fSpleffBvsS) {
2473 delete list;
2474 return 0.0;
2475 }
2476
2477 // now find signal efficiency that corresponds to required background efficiency
2478 Double_t effS = 0, effB = 0, effS_ = 0, effB_ = 0;
2479 Int_t nbins_ = 1000;
2480
2481 if (computeArea) {
2482
2483 // compute area of rej-vs-eff plot (midpoint sum of 1 - effB over nbins_ steps)
2484 Double_t integral = 0;
2485 for (Int_t bini=1; bini<=nbins_; bini++) {
2486
2487 // get corresponding signal and background efficiencies
2488 effS = (bini - 0.5)/Float_t(nbins_);
2489 effB = fSpleffBvsS->Eval( effS );
2490 integral += (1.0 - effB);
2491 }
2492 integral /= nbins_;
2493
2494 delete list;
2495 return integral;
2496 }
2497 else {
2498
2499 // reference background efficiency requested in theString (second list entry);
2500 // it does not affect the efficiency-vs-bkg plot, which is done anyway
2501 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2502
2503 // find precise efficiency value
2504 for (Int_t bini=1; bini<=nbins_; bini++) {
2505
2506 // get corresponding signal and background efficiencies
2507 effS = (bini - 0.5)/Float_t(nbins_);
2508 effB = fSpleffBvsS->Eval( effS );
2509
2510 // stop at the first step where effB crosses (or touches) effBref
2511 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2512 effS_ = effS;
2513 effB_ = effB;
2514 }
2515
2516 // take mean between bin above and bin below
2517 effS = 0.5*(effS + effS_);
2518
// error of effS from sqrt(effS*(1-effS)/nevtS), nevtS being the signal weight sum of the first round
2519 effSerr = 0;
2520 if (nevtS > 0) effSerr = TMath::Sqrt( effS*(1.0 - effS)/nevtS );
2521
2522 delete list;
2523 return effS;
2524 }
2525
// not reached: both branches above return
2526 return -1;
2527}
2528
2529////////////////////////////////////////////////////////////////////////////////
2530
2532{
2533 Data()->SetCurrentType(Types::kTraining);
2534
2535 Results* results = Data()->GetResults(GetMethodName(), Types::kTesting, Types::kNoAnalysisType);
2536
2537 // fill background efficiency (resp. rejection) versus signal efficiency plots
2538 // returns signal efficiency at background efficiency indicated in theString
2539
2540 // parse input string for required background efficiency
2541 TList* list = gTools().ParseFormatLine( theString );
2542 // sanity check
2543
2544 if (list->GetSize() != 2) {
2545 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Wrong number of arguments"
2546 << " in string: " << theString
2547 << " | required format, e.g., Efficiency:0.05" << Endl;
2548 delete list;
2549 return -1;
2550 }
2551 // that will be the value of the efficiency retured (does not affect
2552 // the efficiency-vs-bkg plot which is done anyway.
2553 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2554
2555 delete list;
2556
2557 // sanity check
2558 if (results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2559 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2560 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Binning mismatch between signal and background histos"
2561 << Endl;
2562 return -1.0;
2563 }
2564
2565 // create histogram
2566
2567 // first, get efficiency histograms for signal and background
2568 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2569 Double_t xmin = effhist->GetXaxis()->GetXmin();
2570 Double_t xmax = effhist->GetXaxis()->GetXmax();
2571
2572 // first round ? --> create and fill histograms
2573 if (results->DoesExist("MVA_TRAIN_S")==0) {
2574
2575 // classifier response distributions for test sample
2576 Double_t sxmax = fXmax+0.00001;
2577
2578 // MVA plots on the training sample (check for overtraining)
2579 TH1* mva_s_tr = new TH1D( GetTestvarName() + "_Train_S",GetTestvarName() + "_Train_S", fNbinsMVAoutput, fXmin, sxmax );
2580 TH1* mva_b_tr = new TH1D( GetTestvarName() + "_Train_B",GetTestvarName() + "_Train_B", fNbinsMVAoutput, fXmin, sxmax );
2581 results->Store(mva_s_tr, "MVA_TRAIN_S");
2582 results->Store(mva_b_tr, "MVA_TRAIN_B");
2583 mva_s_tr->Sumw2();
2584 mva_b_tr->Sumw2();
2585
2586 // Training efficiency plots
2587 TH1* mva_eff_tr_s = new TH1D( GetTestvarName() + "_trainingEffS", GetTestvarName() + " (signal)",
2588 fNbinsH, xmin, xmax );
2589 TH1* mva_eff_tr_b = new TH1D( GetTestvarName() + "_trainingEffB", GetTestvarName() + " (background)",
2590 fNbinsH, xmin, xmax );
2591 results->Store(mva_eff_tr_s, "MVA_TRAINEFF_S");
2592 results->Store(mva_eff_tr_b, "MVA_TRAINEFF_B");
2593
2594 // sign if cut
2595 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2596
2597 std::vector<Double_t> mvaValues = GetMvaValues(0,Data()->GetNEvents());
2598 assert( (Long64_t) mvaValues.size() == Data()->GetNEvents());
2599
2600 // this method is unbinned
2601 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2602
2603 Data()->SetCurrentEvent(ievt);
2604 const Event* ev = GetEvent();
2605
2606 Double_t theVal = mvaValues[ievt];
2607 Double_t theWeight = ev->GetWeight();
2608
2609 TH1* theEffHist = DataInfo().IsSignal(ev) ? mva_eff_tr_s : mva_eff_tr_b;
2610 TH1* theClsHist = DataInfo().IsSignal(ev) ? mva_s_tr : mva_b_tr;
2611
2612 theClsHist->Fill( theVal, theWeight );
2613
2614 TAxis* axis = theEffHist->GetXaxis();
2615 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2616 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2617 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2618 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2619 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2620
2621 if (sign > 0) for (Int_t ibin=1; ibin<=maxbin; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2622 else for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2623 }
2624
2625 // normalise output distributions
2626 // uncomment those (and several others if you want unnormalized output
2627 gTools().NormHist( mva_s_tr );
2628 gTools().NormHist( mva_b_tr );
2629
2630 // renormalise to maximum
2631 mva_eff_tr_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_s->GetMaximum()) );
2632 mva_eff_tr_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_b->GetMaximum()) );
2633
2634 // Training background efficiency versus signal efficiency
2635 TH1* eff_bvss = new TH1D( GetTestvarName() + "_trainingEffBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2636 // Training background rejection (=1-eff.) versus signal efficiency
2637 TH1* rej_bvss = new TH1D( GetTestvarName() + "_trainingRejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2638 results->Store(eff_bvss, "EFF_BVSS_TR");
2639 results->Store(rej_bvss, "REJ_BVSS_TR");
2640
2641 // use root finder
2642 // spline background efficiency plot
2643 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2645 if (fSplTrainRefS) delete fSplTrainRefS;
2646 if (fSplTrainRefB) delete fSplTrainRefB;
2647 fSplTrainRefS = new TSpline1( "spline2_signal", new TGraph( mva_eff_tr_s ) );
2648 fSplTrainRefB = new TSpline1( "spline2_background", new TGraph( mva_eff_tr_b ) );
2649
2650 // verify spline sanity
2651 gTools().CheckSplines( mva_eff_tr_s, fSplTrainRefS );
2652 gTools().CheckSplines( mva_eff_tr_b, fSplTrainRefB );
2653 }
2654
2655 // make the background-vs-signal efficiency plot
2656
2657 // create root finder
2658 RootFinder rootFinder(this, fXmin, fXmax );
2659
2660 Double_t effB = 0;
2661 fEffS = results->GetHist("MVA_TRAINEFF_S");
2662 for (Int_t bini=1; bini<=fNbins; bini++) {
2663
2664 // find cut value corresponding to a given signal efficiency
2665 Double_t effS = eff_bvss->GetBinCenter( bini );
2666
2667 Double_t cut = rootFinder.Root( effS );
2668
2669 // retrieve background efficiency for given cut
2670 if (Use_Splines_for_Eff_) effB = fSplTrainRefB->Eval( cut );
2671 else effB = mva_eff_tr_b->GetBinContent( mva_eff_tr_b->FindBin( cut ) );
2672
2673 // and fill histograms
2674 eff_bvss->SetBinContent( bini, effB );
2675 rej_bvss->SetBinContent( bini, 1.0-effB );
2676 }
2677 fEffS = 0;
2678
2679 // create splines for histogram
2680 fSplTrainEffBvsS = new TSpline1( "effBvsS", new TGraph( eff_bvss ) );
2681 }
2682
2683 // must exist...
2684 if (0 == fSplTrainEffBvsS) return 0.0;
2685
2686 // now find signal efficiency that corresponds to required background efficiency
2687 Double_t effS = 0., effB, effS_ = 0., effB_ = 0.;
2688 Int_t nbins_ = 1000;
2689 for (Int_t bini=1; bini<=nbins_; bini++) {
2690
2691 // get corresponding signal and background efficiencies
2692 effS = (bini - 0.5)/Float_t(nbins_);
2693 effB = fSplTrainEffBvsS->Eval( effS );
2694
2695 // find signal efficiency that corresponds to required background efficiency
2696 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2697 effS_ = effS;
2698 effB_ = effB;
2699 }
2700
2701 return 0.5*(effS + effS_); // the mean between bin above and bin below
2702}
2703
2704////////////////////////////////////////////////////////////////////////////////
2705
2706std::vector<Float_t> TMVA::MethodBase::GetMulticlassEfficiency(std::vector<std::vector<Float_t> >& purity)
2707{
2708 Data()->SetCurrentType(Types::kTesting);
2709 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
2710 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in GetMulticlassEfficiency, exiting."<<Endl;
2711
2712 purity.push_back(resMulticlass->GetAchievablePur());
2713 return resMulticlass->GetAchievableEff();
2714}
2715
2716////////////////////////////////////////////////////////////////////////////////
2717
2718std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity)
2719{
2720 Data()->SetCurrentType(Types::kTraining);
2721 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTraining, Types::kMulticlass));
2722 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in GetMulticlassTrainingEfficiency, exiting."<<Endl;
2723
2724 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for training data..." << Endl;
2725 for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
2726 resMulticlass->GetBestMultiClassCuts(icls);
2727 }
2728
2729 purity.push_back(resMulticlass->GetAchievablePur());
2730 return resMulticlass->GetAchievableEff();
2731}
2732
2733////////////////////////////////////////////////////////////////////////////////
2734/// Construct a confusion matrix for a multiclass classifier. The confusion
2735/// matrix compares, in turn, each class against all other classes in a pair-wise
2736/// fashion. In rows with index \f$ k_r = 0 ... K \f$, \f$ k_r \f$ is
2737/// considered signal for the sake of comparison and for each column
2738/// \f$ k_c = 0 ... K \f$ the corresponding class is considered background.
2739///
2740/// Note that the diagonal elements will be returned as NaN since this will
2741/// compare a class against itself.
2742///
2743/// \see TMVA::ResultsMulticlass::GetConfusionMatrix
2744///
2745/// \param[in] effB The background efficiency for which to evaluate.
2746/// \param[in] type The data set on which to evaluate (training, testing ...).
2747///
2748/// \return A matrix containing signal efficiencies for the given background
2749/// efficiency. The diagonal elements are NaN since this measure is
2750/// meaningless (comparing a class against itself).
2751///
2752
2754{
2755 if (GetAnalysisType() != Types::kMulticlass) {
2756 Log() << kFATAL << "Cannot get confusion matrix for non-multiclass analysis." << std::endl;
2757 return TMatrixD(0, 0);
2758 }
2759
2760 Data()->SetCurrentType(type);
2761 ResultsMulticlass *resMulticlass =
2762 dynamic_cast<ResultsMulticlass *>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
2763
2764 if (resMulticlass == nullptr) {
2765 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
2766 << "unable to create pointer in GetMulticlassEfficiency, exiting." << Endl;
2767 return TMatrixD(0, 0);
2768 }
2769
2770 return resMulticlass->GetConfusionMatrix(effB);
2771}
2772
2773////////////////////////////////////////////////////////////////////////////////
2774/// compute significance of mean difference
2775/// \f[
2776/// significance = \frac{|\langle S \rangle - \langle B \rangle|}{\sqrt{RMS_{S}^{2} + RMS_{B}^{2}}}
2777/// \f]
2778
2780{
2781 Double_t rms = sqrt( fRmsS*fRmsS + fRmsB*fRmsB );
2782
2783 return (rms > 0) ? TMath::Abs(fMeanS - fMeanB)/rms : 0;
2784}
2785
2786////////////////////////////////////////////////////////////////////////////////
2787/// compute "separation" defined as
2788/// \f[
2789/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2790/// \f]
2791
2793{
2794 return gTools().GetSeparation( histoS, histoB );
2795}
2796
2797////////////////////////////////////////////////////////////////////////////////
2798/// compute "separation" defined as
2799/// \f[
2800/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2801/// \f]
2802
2804{
2805 // note, if zero pointers given, use internal pdf
2806 // sanity check first
2807 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2808 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2809 if (!pdfS) pdfS = fSplS;
2810 if (!pdfB) pdfB = fSplB;
2811
2812 if (!fSplS || !fSplB) {
2813 Log()<<kDEBUG<<Form("[%s] : ",DataInfo().GetName())<< "could not calculate the separation, distributions"
2814 << " fSplS or fSplB are not yet filled" << Endl;
2815 return 0;
2816 }else{
2817 return gTools().GetSeparation( *pdfS, *pdfB );
2818 }
2819}
2820
2821////////////////////////////////////////////////////////////////////////////////
2822/// calculate the area (integral) under the ROC curve as an
2823/// overall quality measure of the classification
2824
2826{
2827 // note, if zero pointers given, use internal pdf
2828 // sanity check first
2829 if ((!histS && histB) || (histS && !histB))
2830 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetROCIntegral(TH1D*, TH1D*)> Mismatch in hists" << Endl;
2831
2832 if (histS==0 || histB==0) return 0.;
2833
2834 TMVA::PDF *pdfS = new TMVA::PDF( " PDF Sig", histS, TMVA::PDF::kSpline3 );
2835 TMVA::PDF *pdfB = new TMVA::PDF( " PDF Bkg", histB, TMVA::PDF::kSpline3 );
2836
2837
2838 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2839 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2840
2841 Double_t integral = 0;
2842 UInt_t nsteps = 1000;
2843 Double_t step = (xmax-xmin)/Double_t(nsteps);
2844 Double_t cut = xmin;
2845 for (UInt_t i=0; i<nsteps; i++) {
2846 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2847 cut+=step;
2848 }
2849 delete pdfS;
2850 delete pdfB;
2851 return integral*step;
2852}
2853
2854
2855////////////////////////////////////////////////////////////////////////////////
2856/// calculate the area (integral) under the ROC curve as an
2857/// overall quality measure of the classification
2858
2860{
2861 // note, if zero pointers given, use internal pdf
2862 // sanity check first
2863 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2864 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2865 if (!pdfS) pdfS = fSplS;
2866 if (!pdfB) pdfB = fSplB;
2867
2868 if (pdfS==0 || pdfB==0) return 0.;
2869
2870 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2871 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2872
2873 Double_t integral = 0;
2874 UInt_t nsteps = 1000;
2875 Double_t step = (xmax-xmin)/Double_t(nsteps);
2876 Double_t cut = xmin;
2877 for (UInt_t i=0; i<nsteps; i++) {
2878 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2879 cut+=step;
2880 }
2881 return integral*step;
2882}
2883
2884////////////////////////////////////////////////////////////////////////////////
2885/// plot significance, \f$ \frac{S}{\sqrt{S + B}} \f$, curve for given number
2886/// of signal and background events; returns cut for maximum significance
2887/// also returned via reference is the maximum significance
2888
2890 Double_t BackgroundEvents,
2891 Double_t& max_significance_value ) const
2892{
2893 Results* results = Data()->GetResults( GetMethodName(), Types::kTesting, Types::kMaxAnalysisType );
2894
2895 Double_t max_significance(0);
2896 Double_t effS(0),effB(0),significance(0);
2897 TH1D *temp_histogram = new TH1D("temp", "temp", fNbinsH, fXmin, fXmax );
2898
2899 if (SignalEvents <= 0 || BackgroundEvents <= 0) {
2900 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetMaximumSignificance> "
2901 << "Number of signal or background events is <= 0 ==> abort"
2902 << Endl;
2903 }
2904
2905 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Using ratio SignalEvents/BackgroundEvents = "
2906 << SignalEvents/BackgroundEvents << Endl;
2907
2908 TH1* eff_s = results->GetHist("MVA_EFF_S");
2909 TH1* eff_b = results->GetHist("MVA_EFF_B");
2910
2911 if ( (eff_s==0) || (eff_b==0) ) {
2912 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Efficiency histograms empty !" << Endl;
2913 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "no maximum cut found, return 0" << Endl;
2914 return 0;
2915 }
2916
2917 for (Int_t bin=1; bin<=fNbinsH; bin++) {
2918 effS = eff_s->GetBinContent( bin );
2919 effB = eff_b->GetBinContent( bin );
2920
2921 // put significance into a histogram
2922 significance = sqrt(SignalEvents)*( effS )/sqrt( effS + ( BackgroundEvents / SignalEvents) * effB );
2923
2924 temp_histogram->SetBinContent(bin,significance);
2925 }
2926
2927 // find maximum in histogram
2928 max_significance = temp_histogram->GetBinCenter( temp_histogram->GetMaximumBin() );
2929 max_significance_value = temp_histogram->GetBinContent( temp_histogram->GetMaximumBin() );
2930
2931 // delete
2932 delete temp_histogram;
2933
2934 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Optimal cut at : " << max_significance << Endl;
2935 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "Maximum significance: " << max_significance_value << Endl;
2936
2937 return max_significance;
2938}
2939
2940////////////////////////////////////////////////////////////////////////////////
2941/// calculates rms, mean, xmin, xmax of the event variable
2942/// this can be either done for the variables as they are or for
2943/// normalised variables (in the range of 0-1) if "norm" is set to kTRUE
2944
2946 Double_t& meanS, Double_t& meanB,
2947 Double_t& rmsS, Double_t& rmsB,
2949{
2950 Types::ETreeType previousTreeType = Data()->GetCurrentType();
2951 Data()->SetCurrentType(treeType);
2952
2953 Long64_t entries = Data()->GetNEvents();
2954
2955 // sanity check
2956 if (entries <=0)
2957 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CalculateEstimator> Wrong tree type: " << treeType << Endl;
2958
2959 // index of the wanted variable
2960 UInt_t varIndex = DataInfo().FindVarIndex( theVarName );
2961
2962 // first fill signal and background in arrays before analysis
2963 xmin = +DBL_MAX;
2964 xmax = -DBL_MAX;
2965 Long64_t nEventsS = -1;
2966 Long64_t nEventsB = -1;
2967
2968 // take into account event weights
2969 meanS = 0;
2970 meanB = 0;
2971 rmsS = 0;
2972 rmsB = 0;
2973 Double_t sumwS = 0, sumwB = 0;
2974
2975 // loop over all training events
2976 for (Int_t ievt = 0; ievt < entries; ievt++) {
2977
2978 const Event* ev = GetEvent(ievt);
2979
2980 Double_t theVar = ev->GetValue(varIndex);
2981 Double_t weight = ev->GetWeight();
2982
2983 if (DataInfo().IsSignal(ev)) {
2984 sumwS += weight;
2985 meanS += weight*theVar;
2986 rmsS += weight*theVar*theVar;
2987 }
2988 else {
2989 sumwB += weight;
2990 meanB += weight*theVar;
2991 rmsB += weight*theVar*theVar;
2992 }
2993 xmin = TMath::Min( xmin, theVar );
2994 xmax = TMath::Max( xmax, theVar );
2995 }
2996 ++nEventsS;
2997 ++nEventsB;
2998
2999 meanS = meanS/sumwS;
3000 meanB = meanB/sumwB;
3001 rmsS = TMath::Sqrt( rmsS/sumwS - meanS*meanS );
3002 rmsB = TMath::Sqrt( rmsB/sumwB - meanB*meanB );
3003
3004 Data()->SetCurrentType(previousTreeType);
3005}
3006
3007////////////////////////////////////////////////////////////////////////////////
3008/// create reader class for method (classification only at present)
3009
3010void TMVA::MethodBase::MakeClass( const TString& theClassFileName ) const
3011{
3012 // the default consists of
3013 TString classFileName = "";
3014 if (theClassFileName == "")
3015 classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class.C";
3016 else
3017 classFileName = theClassFileName;
3018
3019 TString className = TString("Read") + GetMethodName();
3020
3021 TString tfname( classFileName );
3022 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
3023 << "Creating standalone class: "
3024 << gTools().Color("lightblue") << classFileName << gTools().Color("reset") << Endl;
3025
3026 std::ofstream fout( classFileName );
3027 if (!fout.good()) { // file could not be opened --> Error
3028 Log() << kFATAL << "<MakeClass> Unable to open file: " << classFileName << Endl;
3029 }
3030
3031 // now create the class
3032 // preamble
3033 fout << "// Class: " << className << std::endl;
3034 fout << "// Automatically generated by MethodBase::MakeClass" << std::endl << "//" << std::endl;
3035
3036 // print general information and configuration state
3037 fout << std::endl;
3038 fout << "/* configuration options =====================================================" << std::endl << std::endl;
3039 WriteStateToStream( fout );
3040 fout << std::endl;
3041 fout << "============================================================================ */" << std::endl;
3042
3043 // generate the class
3044 fout << "" << std::endl;
3045 fout << "#include <array>" << std::endl;
3046 fout << "#include <vector>" << std::endl;
3047 fout << "#include <cmath>" << std::endl;
3048 fout << "#include <string>" << std::endl;
3049 fout << "#include <iostream>" << std::endl;
3050 fout << "" << std::endl;
3051 // now if the classifier needs to write some additional classes for its response implementation
3052 // this code goes here: (at least the header declarations need to come before the main class
3053 this->MakeClassSpecificHeader( fout, className );
3054
3055 fout << "#ifndef IClassifierReader__def" << std::endl;
3056 fout << "#define IClassifierReader__def" << std::endl;
3057 fout << std::endl;
3058 fout << "class IClassifierReader {" << std::endl;
3059 fout << std::endl;
3060 fout << " public:" << std::endl;
3061 fout << std::endl;
3062 fout << " // constructor" << std::endl;
3063 fout << " IClassifierReader() : fStatusIsClean( true ) {}" << std::endl;
3064 fout << " virtual ~IClassifierReader() {}" << std::endl;
3065 fout << std::endl;
3066 fout << " // return classifier response" << std::endl;
3067 if(GetAnalysisType() == Types::kMulticlass) {
3068 fout << " virtual std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3069 } else {
3070 fout << " virtual double GetMvaValue( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3071 }
3072 fout << std::endl;
3073 fout << " // returns classifier status" << std::endl;
3074 fout << " bool IsStatusClean() const { return fStatusIsClean; }" << std::endl;
3075 fout << std::endl;
3076 fout << " protected:" << std::endl;
3077 fout << std::endl;
3078 fout << " bool fStatusIsClean;" << std::endl;
3079 fout << "};" << std::endl;
3080 fout << std::endl;
3081 fout << "#endif" << std::endl;
3082 fout << std::endl;
3083 fout << "class " << className << " : public IClassifierReader {" << std::endl;
3084 fout << std::endl;
3085 fout << " public:" << std::endl;
3086 fout << std::endl;
3087 fout << " // constructor" << std::endl;
3088 fout << " " << className << "( std::vector<std::string>& theInputVars )" << std::endl;
3089 fout << " : IClassifierReader()," << std::endl;
3090 fout << " fClassName( \"" << className << "\" )," << std::endl;
3091 fout << " fNvars( " << GetNvar() << " )" << std::endl;
3092 fout << " {" << std::endl;
3093 fout << " // the training input variables" << std::endl;
3094 fout << " const char* inputVars[] = { ";
3095 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3096 fout << "\"" << GetOriginalVarName(ivar) << "\"";
3097 if (ivar<GetNvar()-1) fout << ", ";
3098 }
3099 fout << " };" << std::endl;
3100 fout << std::endl;
3101 fout << " // sanity checks" << std::endl;
3102 fout << " if (theInputVars.size() <= 0) {" << std::endl;
3103 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": empty input vector\" << std::endl;" << std::endl;
3104 fout << " fStatusIsClean = false;" << std::endl;
3105 fout << " }" << std::endl;
3106 fout << std::endl;
3107 fout << " if (theInputVars.size() != fNvars) {" << std::endl;
3108 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in number of input values: \"" << std::endl;
3109 fout << " << theInputVars.size() << \" != \" << fNvars << std::endl;" << std::endl;
3110 fout << " fStatusIsClean = false;" << std::endl;
3111 fout << " }" << std::endl;
3112 fout << std::endl;
3113 fout << " // validate input variables" << std::endl;
3114 fout << " for (size_t ivar = 0; ivar < theInputVars.size(); ivar++) {" << std::endl;
3115 fout << " if (theInputVars[ivar] != inputVars[ivar]) {" << std::endl;
3116 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in input variable names\" << std::endl" << std::endl;
3117 fout << " << \" for variable [\" << ivar << \"]: \" << theInputVars[ivar].c_str() << \" != \" << inputVars[ivar] << std::endl;" << std::endl;
3118 fout << " fStatusIsClean = false;" << std::endl;
3119 fout << " }" << std::endl;
3120 fout << " }" << std::endl;
3121 fout << std::endl;
3122 fout << " // initialize min and max vectors (for normalisation)" << std::endl;
3123 for (UInt_t ivar = 0; ivar < GetNvar(); ivar++) {
3124 fout << " fVmin[" << ivar << "] = " << std::setprecision(15) << GetXmin( ivar ) << ";" << std::endl;
3125 fout << " fVmax[" << ivar << "] = " << std::setprecision(15) << GetXmax( ivar ) << ";" << std::endl;
3126 }
3127 fout << std::endl;
3128 fout << " // initialize input variable types" << std::endl;
3129 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3130 fout << " fType[" << ivar << "] = \'" << DataInfo().GetVariableInfo(ivar).GetVarType() << "\';" << std::endl;
3131 }
3132 fout << std::endl;
3133 fout << " // initialize constants" << std::endl;
3134 fout << " Initialize();" << std::endl;
3135 fout << std::endl;
3136 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
3137 fout << " // initialize transformation" << std::endl;
3138 fout << " InitTransform();" << std::endl;
3139 }
3140 fout << " }" << std::endl;
3141 fout << std::endl;
3142 fout << " // destructor" << std::endl;
3143 fout << " virtual ~" << className << "() {" << std::endl;
3144 fout << " Clear(); // method-specific" << std::endl;
3145 fout << " }" << std::endl;
3146 fout << std::endl;
3147 fout << " // the classifier response" << std::endl;
3148 fout << " // \"inputValues\" is a vector of input values in the same order as the" << std::endl;
3149 fout << " // variables given to the constructor" << std::endl;
3150 if(GetAnalysisType() == Types::kMulticlass) {
3151 fout << " std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const override;" << std::endl;
3152 } else {
3153 fout << " double GetMvaValue( const std::vector<double>& inputValues ) const override;" << std::endl;
3154 }
3155 fout << std::endl;
3156 fout << " private:" << std::endl;
3157 fout << std::endl;
3158 fout << " // method-specific destructor" << std::endl;
3159 fout << " void Clear();" << std::endl;
3160 fout << std::endl;
3161 if (GetTransformationHandler().GetTransformationList().GetSize()!=0) {
3162 fout << " // input variable transformation" << std::endl;
3163 GetTransformationHandler().MakeFunction(fout, className,1);
3164 fout << " void InitTransform();" << std::endl;
3165 fout << " void Transform( std::vector<double> & iv, int sigOrBgd ) const;" << std::endl;
3166 fout << std::endl;
3167 }
3168 fout << " // common member variables" << std::endl;
3169 fout << " const char* fClassName;" << std::endl;
3170 fout << std::endl;
3171 fout << " const size_t fNvars;" << std::endl;
3172 fout << " size_t GetNvar() const { return fNvars; }" << std::endl;
3173 fout << " char GetType( int ivar ) const { return fType[ivar]; }" << std::endl;
3174 fout << std::endl;
3175 fout << " // normalisation of input variables" << std::endl;
3176 fout << " double fVmin[" << GetNvar() << "];" << std::endl;
3177 fout << " double fVmax[" << GetNvar() << "];" << std::endl;
3178 fout << " double NormVariable( double x, double xmin, double xmax ) const {" << std::endl;
3179 fout << " // normalise to output range: [-1, 1]" << std::endl;
3180 fout << " return 2*(x - xmin)/(xmax - xmin) - 1.0;" << std::endl;
3181 fout << " }" << std::endl;
3182 fout << std::endl;
3183 fout << " // type of input variable: 'F' or 'I'" << std::endl;
3184 fout << " char fType[" << GetNvar() << "];" << std::endl;
3185 fout << std::endl;
3186 fout << " // initialize internal variables" << std::endl;
3187 fout << " void Initialize();" << std::endl;
3188 if(GetAnalysisType() == Types::kMulticlass) {
3189 fout << " std::vector<double> GetMulticlassValues__( const std::vector<double>& inputValues ) const;" << std::endl;
3190 } else {
3191 fout << " double GetMvaValue__( const std::vector<double>& inputValues ) const;" << std::endl;
3192 }
3193 fout << "" << std::endl;
3194 fout << " // private members (method specific)" << std::endl;
3195
3196 // call the classifier specific output (the classifier must close the class !)
3197 MakeClassSpecific( fout, className );
3198
3199 if(GetAnalysisType() == Types::kMulticlass) {
3200 fout << "inline std::vector<double> " << className << "::GetMulticlassValues( const std::vector<double>& inputValues ) const" << std::endl;
3201 } else {
3202 fout << "inline double " << className << "::GetMvaValue( const std::vector<double>& inputValues ) const" << std::endl;
3203 }
3204 fout << "{" << std::endl;
3205 fout << " // classifier response value" << std::endl;
3206 if(GetAnalysisType() == Types::kMulticlass) {
3207 fout << " std::vector<double> retval;" << std::endl;
3208 } else {
3209 fout << " double retval = 0;" << std::endl;
3210 }
3211 fout << std::endl;
3212 fout << " // classifier response, sanity check first" << std::endl;
3213 fout << " if (!IsStatusClean()) {" << std::endl;
3214 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": cannot return classifier response\"" << std::endl;
3215 fout << " << \" because status is dirty\" << std::endl;" << std::endl;
3216 fout << " }" << std::endl;
3217 fout << " else {" << std::endl;
3218 if (IsNormalised()) {
3219 fout << " // normalise variables" << std::endl;
3220 fout << " std::vector<double> iV;" << std::endl;
3221 fout << " iV.reserve(inputValues.size());" << std::endl;
3222 fout << " int ivar = 0;" << std::endl;
3223 fout << " for (std::vector<double>::const_iterator varIt = inputValues.begin();" << std::endl;
3224 fout << " varIt != inputValues.end(); varIt++, ivar++) {" << std::endl;
3225 fout << " iV.push_back(NormVariable( *varIt, fVmin[ivar], fVmax[ivar] ));" << std::endl;
3226 fout << " }" << std::endl;
3227 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3228 GetMethodType() != Types::kHMatrix) {
3229 fout << " Transform( iV, -1 );" << std::endl;
3230 }
3231
3232 if(GetAnalysisType() == Types::kMulticlass) {
3233 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3234 } else {
3235 fout << " retval = GetMvaValue__( iV );" << std::endl;
3236 }
3237 } else {
3238 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3239 GetMethodType() != Types::kHMatrix) {
3240 fout << " std::vector<double> iV(inputValues);" << std::endl;
3241 fout << " Transform( iV, -1 );" << std::endl;
3242 if(GetAnalysisType() == Types::kMulticlass) {
3243 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3244 } else {
3245 fout << " retval = GetMvaValue__( iV );" << std::endl;
3246 }
3247 } else {
3248 if(GetAnalysisType() == Types::kMulticlass) {
3249 fout << " retval = GetMulticlassValues__( inputValues );" << std::endl;
3250 } else {
3251 fout << " retval = GetMvaValue__( inputValues );" << std::endl;
3252 }
3253 }
3254 }
3255 fout << " }" << std::endl;
3256 fout << std::endl;
3257 fout << " return retval;" << std::endl;
3258 fout << "}" << std::endl;
3259
3260 // create output for transformation - if any
3261 if (GetTransformationHandler().GetTransformationList().GetSize()!=0)
3262 GetTransformationHandler().MakeFunction(fout, className,2);
3263
3264 // close the file
3265 fout.close();
3266}
3267
3268////////////////////////////////////////////////////////////////////////////////
3269/// prints out method-specific help method
3270
3272{
3273 // if options are written to reference file, also append help info
3274 std::streambuf* cout_sbuf = std::cout.rdbuf(); // save original sbuf
3275 std::ofstream* o = 0;
3276 if (gConfig().WriteOptionsReference()) {
3277 Log() << kINFO << "Print Help message for class " << GetName() << " into file: " << GetReferenceFile() << Endl;
3278 o = new std::ofstream( GetReferenceFile(), std::ios::app );
3279 if (!o->good()) { // file could not be opened --> Error
3280 Log() << kFATAL << "<PrintHelpMessage> Unable to append to output file: " << GetReferenceFile() << Endl;
3281 }
3282 std::cout.rdbuf( o->rdbuf() ); // redirect 'std::cout' to file
3283 }
3284
3285 // "|--------------------------------------------------------------|"
3286 if (!o) {
3287 Log() << kINFO << Endl;
3288 Log() << gTools().Color("bold")
3289 << "================================================================"
3290 << gTools().Color( "reset" )
3291 << Endl;
3292 Log() << gTools().Color("bold")
3293 << "H e l p f o r M V A m e t h o d [ " << GetName() << " ] :"
3294 << gTools().Color( "reset" )
3295 << Endl;
3296 }
3297 else {
3298 Log() << "Help for MVA method [ " << GetName() << " ] :" << Endl;
3299 }
3300
3301 // print method-specific help message
3302 GetHelpMessage();
3303
3304 if (!o) {
3305 Log() << Endl;
3306 Log() << "<Suppress this message by specifying \"!H\" in the booking option>" << Endl;
3307 Log() << gTools().Color("bold")
3308 << "================================================================"
3309 << gTools().Color( "reset" )
3310 << Endl;
3311 Log() << Endl;
3312 }
3313 else {
3314 // indicate END
3315 Log() << "# End of Message___" << Endl;
3316 }
3317
3318 std::cout.rdbuf( cout_sbuf ); // restore the original stream buffer
3319 if (o) o->close();
3320}
3321
3322// ----------------------- r o o t f i n d i n g ----------------------------
3323
3324////////////////////////////////////////////////////////////////////////////////
3325/// returns efficiency as function of cut
3326
3328{
3329 Double_t retval=0;
3330
3331 // retrieve the class object
3333 retval = fSplRefS->Eval( theCut );
3334 }
3335 else retval = fEffS->GetBinContent( fEffS->FindBin( theCut ) );
3336
3337 // caution: here we take some "forbidden" action to hide a problem:
3338 // in some cases, in particular for likelihood, the binned efficiency distributions
3339 // do not equal 1, at xmin, and 0 at xmax; of course, in principle we have the
3340 // unbinned information available in the trees, but the unbinned minimization is
3341 // too slow, and we don't need to do a precision measurement here. Hence, we force
3342 // this property.
3343 Double_t eps = 1.0e-5;
3344 if (theCut-fXmin < eps) retval = (GetCutOrientation() == kPositive) ? 1.0 : 0.0;
3345 else if (fXmax-theCut < eps) retval = (GetCutOrientation() == kPositive) ? 0.0 : 1.0;
3346
3347 return retval;
3348}
3349
3350////////////////////////////////////////////////////////////////////////////////
3351/// returns the event collection (i.e. the dataset) TRANSFORMED using the
3352/// classifiers specific Variable Transformation (e.g. Decorr or Decorr:Gauss:Decorr)
3353
3355{
3356 // if there's no variable transformation for this classifier, just hand back the
3357 // event collection of the data set
3358 if (GetTransformationHandler().GetTransformationList().GetEntries() <= 0) {
3359 return (Data()->GetEventCollection(type));
3360 }
3361
3362 // otherwise, transform ALL the events and hand back the vector of the pointers to the
3363 // transformed events. If the pointer is already != 0, i.e. the whole thing has been
3364 // done before, I don't need to do it again, but just "hand over" the pointer to those events.
3365 Int_t idx = Data()->TreeIndex(type); //index indicating Training,Testing,... events/datasets
3366 if (fEventCollections.at(idx) == 0) {
3367 fEventCollections.at(idx) = &(Data()->GetEventCollection(type));
3368 fEventCollections.at(idx) = GetTransformationHandler().CalcTransformations(*(fEventCollections.at(idx)),kTRUE);
3369 }
3370 return *(fEventCollections.at(idx));
3371}
3372
3373////////////////////////////////////////////////////////////////////////////////
3374/// calculates the TMVA version string from the training version code on the fly
3375
3377{
3378 UInt_t a = GetTrainingTMVAVersionCode() & 0xff0000; a>>=16;
3379 UInt_t b = GetTrainingTMVAVersionCode() & 0x00ff00; b>>=8;
3380 UInt_t c = GetTrainingTMVAVersionCode() & 0x0000ff;
3381
3382 return TString(Form("%i.%i.%i",a,b,c));
3383}
3384
3385////////////////////////////////////////////////////////////////////////////////
3386/// calculates the ROOT version string from the training version code on the fly
3387
3389{
3390 UInt_t a = GetTrainingROOTVersionCode() & 0xff0000; a>>=16;
3391 UInt_t b = GetTrainingROOTVersionCode() & 0x00ff00; b>>=8;
3392 UInt_t c = GetTrainingROOTVersionCode() & 0x0000ff;
3393
3394 return TString(Form("%i.%02i/%02i",a,b,c));
3395}
3396
3397////////////////////////////////////////////////////////////////////////////////
3398
3400 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
3401 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
3402
3403 if (mvaRes != NULL) {
3404 TH1D *mva_s = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_S"));
3405 TH1D *mva_b = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_B"));
3406 TH1D *mva_s_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_S"));
3407 TH1D *mva_b_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_B"));
3408
3409 if ( !mva_s || !mva_b || !mva_s_tr || !mva_b_tr) return -1;
3410
3411 if (SorB == 's' || SorB == 'S')
3412 return mva_s->KolmogorovTest( mva_s_tr, opt.Data() );
3413 else
3414 return mva_b->KolmogorovTest( mva_b_tr, opt.Data() );
3415 }
3416 return -1;
3417}
const Bool_t Use_Splines_for_Eff_
Definition: MethodBase.cxx:130
const Int_t NBIN_HIST_HIGH
Definition: MethodBase.cxx:133
ROOT::R::TRInterface & r
Definition: Object.C:4
#define d(i)
Definition: RSha256.hxx:102
#define b(i)
Definition: RSha256.hxx:100
#define c(i)
Definition: RSha256.hxx:101
#define s1(x)
Definition: RSha256.hxx:91
#define ROOT_VERSION_CODE
Definition: RVersion.h:21
int Int_t
Definition: RtypesCore.h:45
int Ssiz_t
Definition: RtypesCore.h:67
char Char_t
Definition: RtypesCore.h:33
unsigned int UInt_t
Definition: RtypesCore.h:46
const Bool_t kFALSE
Definition: RtypesCore.h:101
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
long long Long64_t
Definition: RtypesCore.h:80
float Float_t
Definition: RtypesCore.h:57
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassImp(name)
Definition: Rtypes.h:364
char name[80]
Definition: TGX11.cxx:110
int type
Definition: TGX11.cxx:121
float xmin
Definition: THbookFile.cxx:95
float xmax
Definition: THbookFile.cxx:95
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:22
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:559
#define TMVA_VERSION_CODE
Definition: Version.h:47
Class to manage histogram axis.
Definition: TAxis.h:30
Double_t GetXmax() const
Definition: TAxis.h:134
Double_t GetXmin() const
Definition: TAxis.h:133
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
Definition: TCollection.h:184
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write all objects in this collection.
This class stores the date and time with a precision of one second in an unsigned 32 bit word (950130...
Definition: TDatime.h:37
const char * AsString() const
Return the date & time as a string (ctime() format).
Definition: TDatime.cxx:102
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
Describe directory structure in memory.
Definition: TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:407
virtual TDirectory * mkdir(const char *name, const char *title="", Bool_t returnExistingDirectory=kFALSE)
Create a sub-directory "a" or a hierarchy of sub-directories "a/b/c/...".
virtual Bool_t cd(const char *path=nullptr)
Change current directory to "this" directory.
Definition: TDirectory.cxx:504
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:4011
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:889
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
1-D histogram with a double per channel (see TH1 documentation)}
Definition: TH1.h:618
1-D histogram with a float per channel (see TH1 documentation)}
Definition: TH1.h:575
TH1 is the base class of all histogram classes in ROOT.
Definition: TH1.h:58
virtual Double_t GetBinCenter(Int_t bin) const
Return bin center for 1D histogram.
Definition: TH1.cxx:8984
virtual Int_t GetQuantiles(Int_t nprobSum, Double_t *q, const Double_t *probSum=0)
Compute Quantiles for this histogram Quantile x_q of a probability distribution Function F is defined...
Definition: TH1.cxx:4544
virtual void AddBinContent(Int_t bin)
Increment bin content by 1.
Definition: TH1.cxx:1258
virtual Double_t GetMean(Int_t axis=1) const
For axis = 1,2 or 3 returns the mean value of the histogram along X,Y or Z axis.
Definition: TH1.cxx:7423
virtual void SetXTitle(const char *title)
Definition: TH1.h:413
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1283
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:320
virtual Double_t GetMaximum(Double_t maxval=FLT_MAX) const
Return maximum value smaller than maxval of bins in the range, unless the value has been overridden b...
Definition: TH1.cxx:8388
virtual Int_t GetNbinsX() const
Definition: TH1.h:296
virtual Int_t Fill(Double_t x)
Increment bin with abscissa X by 1.
Definition: TH1.cxx:3351
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition: TH1.cxx:9065
virtual Int_t GetMaximumBin() const
Return location of bin with maximum value in the range.
Definition: TH1.cxx:8420
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition: TH1.cxx:4994
virtual void SetYTitle(const char *title)
Definition: TH1.h:414
virtual void Scale(Double_t c1=1, Option_t *option="")
Multiply this histogram by a constant c1.
Definition: TH1.cxx:6566
virtual Int_t FindBin(Double_t x, Double_t y=0, Double_t z=0)
Return Global bin number corresponding to x,y,z.
Definition: TH1.cxx:3681
virtual Double_t KolmogorovTest(const TH1 *h2, Option_t *option="") const
Statistical test of compatibility in shape between this histogram and h2, using Kolmogorov test.
Definition: TH1.cxx:8063
virtual void Sumw2(Bool_t flag=kTRUE)
Create structure to store sum of squares of weights.
Definition: TH1.cxx:8863
static Bool_t AddDirectoryStatus()
Static function: cannot be inlined on Windows/NT.
Definition: TH1.cxx:751
2-D histogram with a float per channel (see TH1 documentation)}
Definition: TH2.h:251
Int_t Fill(Double_t)
Invalid Fill method.
Definition: TH2.cxx:358
A doubly linked list.
Definition: TList.h:44
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:357
Class that contains all the information of a class.
Definition: ClassInfo.h:49
UInt_t GetNumber() const
Definition: ClassInfo.h:65
TString fWeightFileExtension
Definition: Config.h:125
VariablePlotting & GetVariablePlotting()
Definition: Config.h:97
class TMVA::Config::VariablePlotting fVariablePlotting
IONames & GetIONames()
Definition: Config.h:98
MsgLogger * fLogger
Definition: Configurable.h:128
Class that contains all the data information.
Definition: DataSetInfo.h:62
Class that contains all the data information.
Definition: DataSet.h:58
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:381
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:391
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
static void SetIgnoreNegWeightsInTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:400
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
void Init(std::vector< TString > &graphTitles)
This function gets some title and it creates a TGraph for every title.
Definition: MethodBase.cxx:169
IPythonInteractive()
standard constructor
Definition: MethodBase.cxx:146
~IPythonInteractive()
standard destructor
Definition: MethodBase.cxx:154
void ClearGraphs()
This function sets the point number to 0 for all graphs.
Definition: MethodBase.cxx:193
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in 2 TGraph case, and it will add new data points to graphs.
Definition: MethodBase.cxx:207
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
TDirectory * MethodBaseDir() const
returns the ROOT directory where all instances of the corresponding MVA method are stored
virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X")
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
Definition: MethodBase.cxx:237
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void ReadClassesFromXML(void *clsnode)
read number of classes from XML
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
void DeclareBaseOptions()
define the options (their key words) that can be set in the option string here the options valid for ...
Definition: MethodBase.cxx:509
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:991
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:596
virtual Double_t GetSignificance() const
compute significance of mean difference
virtual Double_t GetProba(const Event *ev)
const char * GetName() const
Definition: MethodBase.h:334
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
const std::vector< TMVA::Event * > & GetEventCollection(Types::ETreeType type)
returns the event collection (i.e.
virtual std::vector< Double_t > GetDataMvaValues(DataSet *data=nullptr, Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the given Data type
Definition: MethodBase.cxx:939
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:406
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
void AddInfoItem(void *gi, const TString &name, const TString &value) const
xml writing
virtual void AddClassifierOutputProb(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:950
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
TString GetTrainingTMVAVersionString() const
calculates the TMVA version string from the training version code on the fly
void Statistics(Types::ETreeType treeType, const TString &theVarName, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &)
calculates rms,mean, xmin, xmax of the event variable this can be either done for the variables as th...
Bool_t GetLine(std::istream &fin, char *buf)
reads one line from the input stream checks for certain keywords and interprets the line if keywords ...
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:423
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodBase.cxx:897
virtual Bool_t IsSignalLike()
uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation) for...
Definition: MethodBase.cxx:854
virtual ~MethodBase()
destructor
Definition: MethodBase.cxx:364
virtual Double_t GetMaximumSignificance(Double_t SignalEvents, Double_t BackgroundEvents, Double_t &optimal_significance_value) const
plot significance, , curve for given number of signal and background events; returns cut for maximum ...
virtual Double_t GetTrainingEfficiency(const TString &)
void SetWeightFileName(TString)
set the weight file name (deprecated)
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
TString GetWeightFileName() const
retrieve weight file name
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
virtual void WriteMonitoringHistosToFile() const
write special monitoring histograms to file (dummy implementation here)
virtual void AddRegressionOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:744
void InitBase()
default initialization called by all constructors
Definition: MethodBase.cxx:441
virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t &stddev, Double_t &stddev90Percent) const
Definition: MethodBase.cxx:724
void ReadStateFromXMLString(const char *xmlstr)
for reading from memory
void CreateMVAPdfs()
Create PDFs of the MVA output variables.
TString GetTrainingROOTVersionString() const
calculates the ROOT version string from the training version code on the fly
virtual Double_t GetValueForRoot(Double_t)
returns efficiency as function of cut
void ReadStateFromFile()
Function to read options and weights from file.
void WriteVarsToStream(std::ostream &tf, const TString &prefix="") const
write the list of variables (name, min, max) for a given data transformation method to the stream
void ReadVarsFromStream(std::istream &istr)
Read the variables (name, min, max) for a given data transformation method from the stream.
void ReadSpectatorsFromXML(void *specnode)
read spectator info from XML
virtual Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)=0
void SetTestvarName(const TString &v="")
Definition: MethodBase.h:341
void ReadVariablesFromXML(void *varnode)
read variable info from XML
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:623
virtual std::vector< Float_t > GetMulticlassTrainingEfficiency(std::vector< std::vector< Float_t > > &purity)
void WriteStateToStream(std::ostream &tf) const
general method used in writing the header of the weight files where the used variables,...
virtual Double_t GetRarity(Double_t mvaVal, Types::ESBType reftype=Types::kBackground) const
compute rarity:
virtual void SetTuneParameters(std::map< TString, Double_t > tuneParameters)
set the tuning parameters according to the argument This is just a dummy .
Definition: MethodBase.cxx:644
void ReadStateFromStream(std::istream &tf)
read the header from the weight files of the different MVA methods
void AddVarsXMLTo(void *parent) const
write variable info to XML
void AddTargetsXMLTo(void *parent) const
write target info to XML
void ReadTargetsFromXML(void *tarnode)
read target info from XML
void ProcessBaseOptions()
the option string is decoded, for available options see "DeclareOptions"
Definition: MethodBase.cxx:540
void ReadStateFromXML(void *parent)
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:836
void WriteStateToFile() const
write options and weights to file note that each one text file for the main configuration information...
void AddClassesXMLTo(void *parent) const
write class info to XML
virtual void AddClassifierOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:868
void AddSpectatorsXMLTo(void *parent) const
write spectator info to XML
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as an overall quality measure of the classification
virtual void AddMulticlassOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:794
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:433
void SetSource(const std::string &source)
Definition: MsgLogger.h:68
PDF wrapper for histograms; uses user-defined spline interpolation.
Definition: PDF.h:63
Double_t GetXmin() const
Definition: PDF.h:104
Double_t GetXmax() const
Definition: PDF.h:105
Double_t GetVal(Double_t x) const
returns value PDF(x)
Definition: PDF.cxx:701
@ kSpline3
Definition: PDF.h:70
@ kSpline2
Definition: PDF.h:70
Double_t GetIntegral(Double_t xmin, Double_t xmax)
computes PDF integral within given ranges
Definition: PDF.cxx:654
Class that is the base-class for a vector of result.
std::vector< Float_t > * GetValueVector()
void SetValue(Float_t value, Int_t ievt, Bool_t type)
set MVA response
Class which takes the results of a multiclass classification.
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
Float_t GetAchievablePur(UInt_t cls)
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
Float_t GetAchievableEff(UInt_t cls)
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier a multiclass setting.
Class that is the base-class for a vector of result.
Class that is the base-class for a vector of result.
Definition: Results.h:57
Bool_t DoesExist(const TString &alias) const
Returns true if there is an object stored in the result for a given alias, false otherwise.
Definition: Results.cxx:127
TH1 * GetHist(const TString &alias) const
Definition: Results.cxx:136
TList * GetStorage() const
Definition: Results.h:73
void Store(TObject *obj, const char *alias=0)
Definition: Results.cxx:86
Root finding using Brents algorithm (translated from CERNLIB function RZERO)
Definition: RootFinder.h:48
Double_t Root(Double_t refValue)
Root finding using Brents algorithm; taken from CERNLIB function RZERO.
Definition: RootFinder.cxx:71
Linear interpolation of TGraph.
Definition: TSpline1.h:43
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
Double_t ElapsedSeconds(void)
computes elapsed time in seconds
Definition: Timer.cxx:137
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:146
void DrawProgressBar(Int_t, const TString &comment="")
draws progress bar in color or B&W caution:
Definition: Timer.cxx:202
void ComputeStat(const std::vector< TMVA::Event * > &, std::vector< Float_t > *, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Int_t signalClass, Bool_t norm=kFALSE)
sanity check
Definition: Tools.cxx:202
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition: Tools.cxx:401
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1162
Double_t GetSeparation(TH1 *S, TH1 *B) const
compute "separation" defined as
Definition: Tools.cxx:121
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1124
Double_t GetMutualInformation(const TH2F &)
Mutual Information method for non-linear correlations estimates in 2D histogram Author: Moritz Backes...
Definition: Tools.cxx:589
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:828
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1150
TXMLEngine & xmlengine()
Definition: Tools.h:262
Bool_t CheckSplines(const TH1 *, const TSpline *)
check quality of splining by comparing splines and histograms in each bin
Definition: Tools.cxx:479
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:329
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:347
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:383
Singleton class for Global types used by TMVA.
Definition: Types.h:71
@ kSignal
Definition: Types.h:135
@ kBackground
Definition: Types.h:136
@ kLikelihood
Definition: Types.h:79
@ kHMatrix
Definition: Types.h:81
EAnalysisType
Definition: Types.h:126
@ kMulticlass
Definition: Types.h:129
@ kNoAnalysisType
Definition: Types.h:130
@ kClassification
Definition: Types.h:127
@ kMaxAnalysisType
Definition: Types.h:131
@ kRegression
Definition: Types.h:128
@ kTraining
Definition: Types.h:143
@ kTesting
Definition: Types.h:144
Linear interpolation class.
Gaussian Transformation of input variables.
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
void ReadFromXML(void *varnode)
read VariableInfo from stream
const TString & GetExpression() const
Definition: VariableInfo.h:57
char GetVarType() const
Definition: VariableInfo.h:61
void ReadFromStream(std::istream &istr)
read VariableInfo from stream
void AddToXML(void *varnode)
write class to XML
void SetExternalLink(void *p)
Definition: VariableInfo.h:75
void * GetExternalLink() const
Definition: VariableInfo.h:83
void BuildTransformationFromVarInfo(const std::vector< TMVA::VariableInfo > &var)
this method is only used when building a normalization transformation from old text files in this cas...
Linear interpolation class.
Linear interpolation class.
virtual void ReadTransformationFromStream(std::istream &istr, const TString &classname="")=0
@ kDEBUG
Definition: Types.h:56
@ kVERBOSE
Definition: Types.h:57
@ kHEADER
Definition: Types.h:63
@ kERROR
Definition: Types.h:60
@ kINFO
Definition: Types.h:58
@ kWARNING
Definition: Types.h:59
@ kFATAL
Definition: Types.h:61
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:36
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Collectable string class.
Definition: TObjString.h:28
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TObject.cxx:798
Basic string class.
Definition: TString.h:136
Ssiz_t Length() const
Definition: TString.h:410
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1150
Int_t Atoi() const
Return integer value of string.
Definition: TString.cxx:1946
Bool_t EndsWith(const char *pat, ECaseCompare cmp=kExact) const
Return true if string ends with the specified string.
Definition: TString.cxx:2202
TSubString Strip(EStripType s=kTrailing, char c=' ') const
Return a substring of self stripped at beginning and/or end.
Definition: TString.cxx:1131
const char * Data() const
Definition: TString.h:369
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:692
@ kLeading
Definition: TString.h:267
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:916
Bool_t IsNull() const
Definition: TString.h:407
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2336
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:639
virtual const char * GetBuildNode() const
Return the build node name.
Definition: TSystem.cxx:3894
virtual int mkdir(const char *name, Bool_t recursive=kFALSE)
Make a file system directory.
Definition: TSystem.cxx:907
virtual const char * WorkingDirectory()
Return working directory.
Definition: TSystem.cxx:872
virtual UserGroup_t * GetUserInfo(Int_t uid)
Returns all user info in the UserGroup_t structure.
Definition: TSystem.cxx:1599
void SaveDoc(XMLDocPointer_t xmldoc, const char *filename, Int_t layout=1)
store document content to file if layout<=0, no any spaces or newlines will be placed between xmlnode...
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
XMLDocPointer_t NewDoc(const char *version="1.0")
creates new xml document with provided version
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLDocPointer_t ParseString(const char *xmlstring)
parses content of string and tries to produce xml structures
void DocSetRootElement(XMLDocPointer_t xmldoc, XMLNodePointer_t xmlnode)
set main (root) node for document
TLine * line
Double_t x[n]
Definition: legend1.C:17
TH1F * h1
Definition: legend1.C:5
VecExpr< UnaryOp< Sqrt< T >, VecExpr< A, T, D >, T >, T, D > sqrt(const VecExpr< A, T, D > &rhs)
bool BeginsWith(const std::string &theString, const std::string &theSubstring)
void Init(TClassEdit::TInterpreterLookupHelper *helper)
Definition: TClassEdit.cxx:154
static constexpr double s
static constexpr double m2
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:342
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:148
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:208
Double_t Log(Double_t x)
Definition: TMath.h:710
Double_t Sqrt(Double_t x)
Definition: TMath.h:641
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:176
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
TString fUser
Definition: TSystem.h:141
auto * a
Definition: textangle.C:12
REAL epsilon
Definition: triangle.c:618