Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodBase *
8 * *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
16 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * U. of Victoria, Canada *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (see tmva/doc/LICENSE) *
31 * *
32 **********************************************************************************/
33
34/*! \class TMVA::MethodBase
35\ingroup TMVA
36
 Virtual base class for all MVA methods
38
39 MethodBase hosts several specific evaluation methods.
40
41 The kind of MVA that provides optimal performance in an analysis strongly
42 depends on the particular application. The evaluation factory provides a
43 number of numerical benchmark results to directly assess the performance
44 of the MVA training on the independent test sample. These are:
45
46 - The _signal efficiency_ at three representative background efficiencies
47 (which is 1 &minus; rejection).
48 - The _significance_ of an MVA estimator, defined by the difference
49 between the MVA mean values for signal and background, divided by the
50 quadratic sum of their root mean squares.
51 - The _separation_ of an MVA _x_, defined by the integral
52 \f[
53 \frac{1}{2} \int \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx
54 \f]
55 where
56 \f$ S(x) \f$ and \f$ B(x) \f$ are the signal and background distributions,
57 respectively. The separation is zero for identical signal and background MVA
58 shapes, and it is one for disjunctive shapes.
59 - The average, \f$ \int x \mu (S(x)) dx \f$, of the signal \f$ \mu_{transform} \f$.
60 The \f$ \mu_{transform} \f$ of an MVA denotes the transformation that yields
61 a uniform background distribution. In this way, the signal distributions
62 \f$ S(x) \f$ can be directly compared among the various MVAs. The stronger
63 \f$ S(x) \f$ peaks towards one, the better is the discrimination of the MVA.
64 The \f$ \mu_{transform} \f$ is
65 [documented here](http://tel.ccsd.cnrs.fr/documents/archives0/00/00/29/91/index_fr.html).
66
67 The MVA standard output also prints the linear correlation coefficients between
68 signal and background, which can be useful to eliminate variables that exhibit too
69 strong correlations.
70*/
71
72#include "TMVA/MethodBase.h"
73
74#include "TMVA/Config.h"
75#include "TMVA/Configurable.h"
76#include "TMVA/DataSetInfo.h"
77#include "TMVA/DataSet.h"
78#include "TMVA/Factory.h"
79#include "TMVA/IMethod.h"
80#include "TMVA/MsgLogger.h"
81#include "TMVA/PDF.h"
82#include "TMVA/Ranking.h"
83#include "TMVA/DataLoader.h"
84#include "TMVA/Tools.h"
85#include "TMVA/Results.h"
89#include "TMVA/RootFinder.h"
90#include "TMVA/Timer.h"
91#include "TMVA/TSpline1.h"
92#include "TMVA/Types.h"
96#include "TMVA/VariableInfo.h"
100#include "TMVA/Version.h"
101
102#include "TROOT.h"
103#include "TSystem.h"
104#include "TObjString.h"
105#include "TQObject.h"
106#include "TSpline.h"
107#include "TMatrix.h"
108#include "TMath.h"
109#include "TH1F.h"
110#include "TH2F.h"
111#include "TFile.h"
112#include "TGraph.h"
113#include "TXMLEngine.h"
114
115#include <iomanip>
116#include <iostream>
117#include <fstream>
118#include <sstream>
119#include <cstdlib>
120#include <algorithm>
121#include <limits>
122
123
125
126using std::endl;
127using std::atof;
128
129//const Int_t MethodBase_MaxIterations_ = 200;
131
132//const Int_t NBIN_HIST_PLOT = 100;
133const Int_t NBIN_HIST_HIGH = 10000;
134
135#ifdef _WIN32
136/* Disable warning C4355: 'this' : used in base member initializer list */
137#pragma warning ( disable : 4355 )
138#endif
139
140
141#include "TMultiGraph.h"
142
143////////////////////////////////////////////////////////////////////////////////
144/// standard constructor
145
147{
148 fNumGraphs = 0;
149 fIndex = 0;
150}
151
152////////////////////////////////////////////////////////////////////////////////
153/// standard destructor
155{
156 if (fMultiGraph){
157 delete fMultiGraph;
158 fMultiGraph = nullptr;
159 }
160 return;
161}
162
163////////////////////////////////////////////////////////////////////////////////
164/// This function gets some title and it creates a TGraph for every title.
165/// It also sets up the style for every TGraph. All graphs are added to a single TMultiGraph.
166///
167/// \param[in] graphTitles vector of titles
168
169void TMVA::IPythonInteractive::Init(std::vector<TString>& graphTitles)
170{
171 if (fNumGraphs!=0){
172 std::cerr << kERROR << "IPythonInteractive::Init: already initialized..." << std::endl;
173 return;
174 }
175 Int_t color = 2;
176 for(auto& title : graphTitles){
177 fGraphs.push_back( new TGraph() );
178 fGraphs.back()->SetTitle(title);
179 fGraphs.back()->SetName(title);
180 fGraphs.back()->SetFillColor(color);
181 fGraphs.back()->SetLineColor(color);
182 fGraphs.back()->SetMarkerColor(color);
183 fMultiGraph->Add(fGraphs.back());
184 color += 2;
185 fNumGraphs += 1;
186 }
187 return;
188}
189
190////////////////////////////////////////////////////////////////////////////////
191/// This function sets the point number to 0 for all graphs.
192
194{
195 for(Int_t i=0; i<fNumGraphs; i++){
196 fGraphs[i]->Set(0);
197 }
198}
199
200////////////////////////////////////////////////////////////////////////////////
/// This function is used only in the two-TGraph case; it adds a new data point to each of the two graphs.
202///
203/// \param[in] x the x coordinate
204/// \param[in] y1 the y coordinate for the first TGraph
205/// \param[in] y2 the y coordinate for the second TGraph
206
208{
209 fGraphs[0]->Set(fIndex+1);
210 fGraphs[1]->Set(fIndex+1);
211 fGraphs[0]->SetPoint(fIndex, x, y1);
212 fGraphs[1]->SetPoint(fIndex, x, y2);
213 fIndex++;
214 return;
215}
216
217////////////////////////////////////////////////////////////////////////////////
218/// This function can add data points to as many TGraphs as we have.
219///
220/// \param[in] dat vector of data points. The dat[0] contains the x coordinate,
221/// dat[1] contains the y coordinate for first TGraph, dat[2] for second, ...
222
223void TMVA::IPythonInteractive::AddPoint(std::vector<Double_t>& dat)
224{
225 for(Int_t i=0; i<fNumGraphs;i++){
226 fGraphs[i]->Set(fIndex+1);
227 fGraphs[i]->SetPoint(fIndex, dat[0], dat[i+1]);
228 }
229 fIndex++;
230 return;
231}
232
233
234////////////////////////////////////////////////////////////////////////////////
235/// standard constructor
236
238 Types::EMVA methodType,
239 const TString& methodTitle,
240 DataSetInfo& dsi,
241 const TString& theOption) :
242 IMethod(),
243 Configurable ( theOption ),
244 fTmpEvent ( 0 ),
245 fRanking ( 0 ),
246 fInputVars ( 0 ),
247 fAnalysisType ( Types::kNoAnalysisType ),
248 fRegressionReturnVal ( 0 ),
249 fMulticlassReturnVal ( 0 ),
250 fDataSetInfo ( dsi ),
251 fSignalReferenceCut ( 0.5 ),
252 fSignalReferenceCutOrientation( 1. ),
253 fVariableTransformType ( Types::kSignal ),
254 fJobName ( jobName ),
255 fMethodName ( methodTitle ),
256 fMethodType ( methodType ),
257 fTestvar ( "" ),
258 fTMVATrainingVersion ( TMVA_VERSION_CODE ),
259 fROOTTrainingVersion ( ROOT_VERSION_CODE ),
260 fConstructedFromWeightFile ( kFALSE ),
261 fBaseDir ( 0 ),
262 fMethodBaseDir ( 0 ),
263 fFile ( 0 ),
264 fSilentFile (kFALSE),
265 fModelPersistence (kTRUE),
266 fWeightFile ( "" ),
267 fEffS ( 0 ),
268 fDefaultPDF ( 0 ),
269 fMVAPdfS ( 0 ),
270 fMVAPdfB ( 0 ),
271 fSplS ( 0 ),
272 fSplB ( 0 ),
273 fSpleffBvsS ( 0 ),
274 fSplTrainS ( 0 ),
275 fSplTrainB ( 0 ),
276 fSplTrainEffBvsS ( 0 ),
277 fVarTransformString ( "None" ),
278 fTransformationPointer ( 0 ),
279 fTransformation ( dsi, methodTitle ),
280 fVerbose ( kFALSE ),
281 fVerbosityLevelString ( "Default" ),
282 fHelp ( kFALSE ),
283 fHasMVAPdfs ( kFALSE ),
284 fIgnoreNegWeightsInTraining( kFALSE ),
285 fSignalClass ( 0 ),
286 fBackgroundClass ( 0 ),
287 fSplRefS ( 0 ),
288 fSplRefB ( 0 ),
289 fSplTrainRefS ( 0 ),
290 fSplTrainRefB ( 0 ),
291 fSetupCompleted (kFALSE)
292{
295
296// // default extension for weight files
297}
298
299////////////////////////////////////////////////////////////////////////////////
300/// constructor used for Testing + Application of the MVA,
301/// only (no training), using given WeightFiles
302
304 DataSetInfo& dsi,
305 const TString& weightFile ) :
306 IMethod(),
307 Configurable(""),
308 fTmpEvent ( 0 ),
309 fRanking ( 0 ),
310 fInputVars ( 0 ),
311 fAnalysisType ( Types::kNoAnalysisType ),
312 fRegressionReturnVal ( 0 ),
313 fMulticlassReturnVal ( 0 ),
314 fDataSetInfo ( dsi ),
315 fSignalReferenceCut ( 0.5 ),
316 fVariableTransformType ( Types::kSignal ),
317 fJobName ( "" ),
318 fMethodName ( "MethodBase" ),
319 fMethodType ( methodType ),
320 fTestvar ( "" ),
321 fTMVATrainingVersion ( 0 ),
322 fROOTTrainingVersion ( 0 ),
323 fConstructedFromWeightFile ( kTRUE ),
324 fBaseDir ( 0 ),
325 fMethodBaseDir ( 0 ),
326 fFile ( 0 ),
327 fSilentFile (kFALSE),
328 fModelPersistence (kTRUE),
329 fWeightFile ( weightFile ),
330 fEffS ( 0 ),
331 fDefaultPDF ( 0 ),
332 fMVAPdfS ( 0 ),
333 fMVAPdfB ( 0 ),
334 fSplS ( 0 ),
335 fSplB ( 0 ),
336 fSpleffBvsS ( 0 ),
337 fSplTrainS ( 0 ),
338 fSplTrainB ( 0 ),
339 fSplTrainEffBvsS ( 0 ),
340 fVarTransformString ( "None" ),
341 fTransformationPointer ( 0 ),
342 fTransformation ( dsi, "" ),
343 fVerbose ( kFALSE ),
344 fVerbosityLevelString ( "Default" ),
345 fHelp ( kFALSE ),
346 fHasMVAPdfs ( kFALSE ),
347 fIgnoreNegWeightsInTraining( kFALSE ),
348 fSignalClass ( 0 ),
349 fBackgroundClass ( 0 ),
350 fSplRefS ( 0 ),
351 fSplRefB ( 0 ),
352 fSplTrainRefS ( 0 ),
353 fSplTrainRefB ( 0 ),
354 fSetupCompleted (kFALSE)
355{
357// // constructor used for Testing + Application of the MVA,
358// // only (no training), using given WeightFiles
359}
360
361////////////////////////////////////////////////////////////////////////////////
362/// destructor
363
365{
366 // destructor
367 if (!fSetupCompleted) Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling destructor of method which got never setup" << Endl;
368
369 // destructor
370 if (fInputVars != 0) { fInputVars->clear(); delete fInputVars; }
371 if (fRanking != 0) delete fRanking;
372
373 // PDFs
374 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
375 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
376 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
377
378 // Splines
379 if (fSplS) { delete fSplS; fSplS = 0; }
380 if (fSplB) { delete fSplB; fSplB = 0; }
381 if (fSpleffBvsS) { delete fSpleffBvsS; fSpleffBvsS = 0; }
382 if (fSplRefS) { delete fSplRefS; fSplRefS = 0; }
383 if (fSplRefB) { delete fSplRefB; fSplRefB = 0; }
384 if (fSplTrainRefS) { delete fSplTrainRefS; fSplTrainRefS = 0; }
385 if (fSplTrainRefB) { delete fSplTrainRefB; fSplTrainRefB = 0; }
386 if (fSplTrainEffBvsS) { delete fSplTrainEffBvsS; fSplTrainEffBvsS = 0; }
387
388 for (size_t i = 0; i < fEventCollections.size(); i++ ) {
389 if (fEventCollections.at(i)) {
390 for (std::vector<Event*>::const_iterator it = fEventCollections.at(i)->begin();
391 it != fEventCollections.at(i)->end(); ++it) {
392 delete (*it);
393 }
394 delete fEventCollections.at(i);
395 fEventCollections.at(i) = nullptr;
396 }
397 }
398
399 if (fRegressionReturnVal) delete fRegressionReturnVal;
400 if (fMulticlassReturnVal) delete fMulticlassReturnVal;
401}
402
403////////////////////////////////////////////////////////////////////////////////
404/// setup of methods
405
407{
408 // setup of methods
409
410 if (fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling SetupMethod for the second time" << Endl;
411 InitBase();
412 DeclareBaseOptions();
413 Init();
414 DeclareOptions();
415 fSetupCompleted = kTRUE;
416}
417
418////////////////////////////////////////////////////////////////////////////////
419/// process all options
420/// the "CheckForUnusedOptions" is done in an independent call, since it may be overridden by derived class
421/// (sometimes, eg, fitters are used which can only be implemented during training phase)
422
424{
425 ProcessBaseOptions();
426 ProcessOptions();
427}
428
429////////////////////////////////////////////////////////////////////////////////
430/// check may be overridden by derived class
431/// (sometimes, eg, fitters are used which can only be implemented during training phase)
432
434{
435 CheckForUnusedOptions();
436}
437
438////////////////////////////////////////////////////////////////////////////////
439/// default initialization called by all constructors
440
442{
443 SetConfigDescription( "Configuration options for classifier architecture and tuning" );
444
446 fNbinsMVAoutput = gConfig().fVariablePlotting.fNbinsMVAoutput;
447 fNbinsH = NBIN_HIST_HIGH;
448
449 fSplTrainS = 0;
450 fSplTrainB = 0;
451 fSplTrainEffBvsS = 0;
452 fMeanS = -1;
453 fMeanB = -1;
454 fRmsS = -1;
455 fRmsB = -1;
456 fXmin = DBL_MAX;
457 fXmax = -DBL_MAX;
458 fTxtWeightsOnly = kTRUE;
459 fSplRefS = 0;
460 fSplRefB = 0;
461
462 fTrainTime = -1.;
463 fTestTime = -1.;
464
465 fRanking = 0;
466
467 // temporary until the move to DataSet is complete
468 fInputVars = new std::vector<TString>;
469 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
470 fInputVars->push_back(DataInfo().GetVariableInfo(ivar).GetLabel());
471 }
472 fRegressionReturnVal = 0;
473 fMulticlassReturnVal = 0;
474
475 fEventCollections.resize( 2 );
476 fEventCollections.at(0) = 0;
477 fEventCollections.at(1) = 0;
478
479 // retrieve signal and background class index
480 if (DataInfo().GetClassInfo("Signal") != 0) {
481 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
482 }
483 if (DataInfo().GetClassInfo("Background") != 0) {
484 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
485 }
486
487 SetConfigDescription( "Configuration options for MVA method" );
488 SetConfigName( TString("Method") + GetMethodTypeName() );
489}
490
491////////////////////////////////////////////////////////////////////////////////
492/// define the options (their key words) that can be set in the option string
493/// here the options valid for ALL MVA methods are declared.
494///
/// known options:
496///
497/// - VariableTransform=None,Decorrelated,PCA to use transformed variables
498/// instead of the original ones
499/// - VariableTransformType=Signal,Background which decorrelation matrix to use
500/// in the method. Only the Likelihood
501/// Method can make proper use of independent
502/// transformations of signal and background
503/// - fNbinsMVAPdf = 50 Number of bins used to create a PDF of MVA
504/// - fNsmoothMVAPdf = 2 Number of times a histogram is smoothed before creating the PDF
505/// - fHasMVAPdfs create PDFs for the MVA outputs
/// - V for verbose output (!V for non-verbose)
507/// - H for Help message
508
510{
511 DeclareOptionRef( fVerbose, "V", "Verbose output (short form of \"VerbosityLevel\" below - overrides the latter one)" );
512
513 DeclareOptionRef( fVerbosityLevelString="Default", "VerbosityLevel", "Verbosity level" );
514 AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
515 AddPreDefVal( TString("Debug") );
516 AddPreDefVal( TString("Verbose") );
517 AddPreDefVal( TString("Info") );
518 AddPreDefVal( TString("Warning") );
519 AddPreDefVal( TString("Error") );
520 AddPreDefVal( TString("Fatal") );
521
522 // If True (default): write all training results (weights) as text files only;
523 // if False: write also in ROOT format (not available for all methods - will abort if not
524 fTxtWeightsOnly = kTRUE; // OBSOLETE !!!
525 fNormalise = kFALSE; // OBSOLETE !!!
526
527 DeclareOptionRef( fVarTransformString, "VarTransform", "List of variable transformations performed before training, e.g., \"D_Background,P_Signal,G,N_AllClasses\" for: \"Decorrelation, PCA-transformation, Gaussianisation, Normalisation, each for the given class of events ('AllClasses' denotes all events of all classes, if no class indication is given, 'All' is assumed)\"" );
528
529 DeclareOptionRef( fHelp, "H", "Print method-specific help message" );
530
531 DeclareOptionRef( fHasMVAPdfs, "CreateMVAPdfs", "Create PDFs for classifier outputs (signal and background)" );
532
533 DeclareOptionRef( fIgnoreNegWeightsInTraining, "IgnoreNegWeightsInTraining",
534 "Events with negative weights are ignored in the training (but are included for testing and performance evaluation)" );
535}
536
537////////////////////////////////////////////////////////////////////////////////
538/// the option string is decoded, for available options see "DeclareOptions"
539
541{
542 if (HasMVAPdfs()) {
543 // setting the default bin num... maybe should be static ? ==> Please no static (JS)
544 // You can't use the logger in the constructor!!! Log() << kINFO << "Create PDFs" << Endl;
545 // reading every PDF's definition and passing the option string to the next one to be read and marked
546 fDefaultPDF = new PDF( TString(GetName())+"_PDF", GetOptions(), "MVAPdf" );
547 fDefaultPDF->DeclareOptions();
548 fDefaultPDF->ParseOptions();
549 fDefaultPDF->ProcessOptions();
550 fMVAPdfB = new PDF( TString(GetName())+"_PDFBkg", fDefaultPDF->GetOptions(), "MVAPdfBkg", fDefaultPDF );
551 fMVAPdfB->DeclareOptions();
552 fMVAPdfB->ParseOptions();
553 fMVAPdfB->ProcessOptions();
554 fMVAPdfS = new PDF( TString(GetName())+"_PDFSig", fMVAPdfB->GetOptions(), "MVAPdfSig", fDefaultPDF );
555 fMVAPdfS->DeclareOptions();
556 fMVAPdfS->ParseOptions();
557 fMVAPdfS->ProcessOptions();
558
559 // the final marked option string is written back to the original methodbase
560 SetOptions( fMVAPdfS->GetOptions() );
561 }
562
563 TMVA::CreateVariableTransforms( fVarTransformString,
564 DataInfo(),
565 GetTransformationHandler(),
566 Log() );
567
568 if (!HasMVAPdfs()) {
569 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
570 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
571 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
572 }
573
574 if (fVerbose) { // overwrites other settings
575 fVerbosityLevelString = TString("Verbose");
576 Log().SetMinType( kVERBOSE );
577 }
578 else if (fVerbosityLevelString == "Debug" ) Log().SetMinType( kDEBUG );
579 else if (fVerbosityLevelString == "Verbose" ) Log().SetMinType( kVERBOSE );
580 else if (fVerbosityLevelString == "Info" ) Log().SetMinType( kINFO );
581 else if (fVerbosityLevelString == "Warning" ) Log().SetMinType( kWARNING );
582 else if (fVerbosityLevelString == "Error" ) Log().SetMinType( kERROR );
583 else if (fVerbosityLevelString == "Fatal" ) Log().SetMinType( kFATAL );
584 else if (fVerbosityLevelString != "Default" ) {
585 Log() << kFATAL << "<ProcessOptions> Verbosity level type '"
586 << fVerbosityLevelString << "' unknown." << Endl;
587 }
588 Event::SetIgnoreNegWeightsInTraining(fIgnoreNegWeightsInTraining);
589}
590
591////////////////////////////////////////////////////////////////////////////////
/// Options that are used ONLY by the READER to ensure backward compatibility.
/// They are hence without any effect (the reader is only reading the training
/// options that HAD been used at the training of the .xml weight file at hand).
595
597{
598 DeclareOptionRef( fNormalise=kFALSE, "Normalise", "Normalise input variables" ); // don't change the default !!!
599 DeclareOptionRef( fUseDecorr=kFALSE, "D", "Use-decorrelated-variables flag" );
600 DeclareOptionRef( fVariableTransformTypeString="Signal", "VarTransformType",
601 "Use signal or background events to derive for variable transformation (the transformation is applied on both types of, course)" );
602 AddPreDefVal( TString("Signal") );
603 AddPreDefVal( TString("Background") );
604 DeclareOptionRef( fTxtWeightsOnly=kTRUE, "TxtWeightFilesOnly", "If True: write all training results (weights) as text files (False: some are written in ROOT format)" );
605 // Why on earth ?? was this here? Was the verbosity level option meant to 'disappear? Not a good idea i think..
606 // DeclareOptionRef( fVerbosityLevelString="Default", "VerboseLevel", "Verbosity level" );
607 // AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
608 // AddPreDefVal( TString("Debug") );
609 // AddPreDefVal( TString("Verbose") );
610 // AddPreDefVal( TString("Info") );
611 // AddPreDefVal( TString("Warning") );
612 // AddPreDefVal( TString("Error") );
613 // AddPreDefVal( TString("Fatal") );
614 DeclareOptionRef( fNbinsMVAPdf = 60, "NbinsMVAPdf", "Number of bins used for the PDFs of classifier outputs" );
615 DeclareOptionRef( fNsmoothMVAPdf = 2, "NsmoothMVAPdf", "Number of smoothing iterations for classifier PDFs" );
616}
617
618
619////////////////////////////////////////////////////////////////////////////////
620/// call the Optimizer with the set of parameters and ranges that
621/// are meant to be tuned.
622
623std::map<TString,Double_t> TMVA::MethodBase::OptimizeTuningParameters(TString /* fomType */ , TString /* fitType */)
624{
625 // this is just a dummy... needs to be implemented for each method
626 // individually (as long as we don't have it automatized via the
627 // configuration string
628
629 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Parameter optimization is not yet implemented for method "
630 << GetName() << Endl;
631 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Currently we need to set hardcoded which parameter is tuned in which ranges"<<Endl;
632
633 std::map<TString,Double_t> tunedParameters;
634 tunedParameters.size(); // just to get rid of "unused" warning
635 return tunedParameters;
636
637}
638
639////////////////////////////////////////////////////////////////////////////////
640/// set the tuning parameters according to the argument
641/// This is just a dummy .. have a look at the MethodBDT how you could
642/// perhaps implement the same thing for the other Classifiers..
643
644void TMVA::MethodBase::SetTuneParameters(std::map<TString,Double_t> /* tuneParameters */)
645{
646}
647
648////////////////////////////////////////////////////////////////////////////////
649
651{
652 Data()->SetCurrentType(Types::kTraining);
653 Event::SetIsTraining(kTRUE); // used to set negative event weights to zero if chosen to do so
654
655 // train the MVA method
656 if (Help()) PrintHelpMessage();
657
658 // all histograms should be created in the method's subdirectory
659 if(!IsSilentFile()) BaseDir()->cd();
660
661 // once calculate all the transformation (e.g. the sequence of Decorr:Gauss:Decorr)
662 // needed for this classifier
663 GetTransformationHandler().CalcTransformations(Data()->GetEventCollection());
664
665 // call training of derived MVA
666 Log() << kDEBUG //<<Form("\tDataset[%s] : ",DataInfo().GetName())
667 << "Begin training" << Endl;
668 Long64_t nEvents = Data()->GetNEvents();
669 Timer traintimer( nEvents, GetName(), kTRUE );
670 Train();
671 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName()
672 << "\tEnd of training " << Endl;
673 SetTrainTime(traintimer.ElapsedSeconds());
674 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
675 << "Elapsed time for training with " << nEvents << " events: "
676 << traintimer.GetElapsedTime() << " " << Endl;
677
678 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
679 << "\tCreate MVA output for ";
680
681 // create PDFs for the signal and background MVA distributions (if required)
682 if (DoMulticlass()) {
683 Log() <<Form("[%s] : ",DataInfo().GetName())<< "Multiclass classification on training sample" << Endl;
684 AddMulticlassOutput(Types::kTraining);
685 }
686 else if (!DoRegression()) {
687
688 Log() <<Form("[%s] : ",DataInfo().GetName())<< "classification on training sample" << Endl;
689 AddClassifierOutput(Types::kTraining);
690 if (HasMVAPdfs()) {
691 CreateMVAPdfs();
692 AddClassifierOutputProb(Types::kTraining);
693 }
694
695 } else {
696
697 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "regression on training sample" << Endl;
698 AddRegressionOutput( Types::kTraining );
699
700 if (HasMVAPdfs() ) {
701 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create PDFs" << Endl;
702 CreateMVAPdfs();
703 }
704 }
705
706 // write the current MVA state into stream
707 // produced are one text file and one ROOT file
708 if (fModelPersistence ) WriteStateToFile();
709
710 // produce standalone make class (presently only supported for classification)
711 if ((!DoRegression()) && (fModelPersistence)) MakeClass();
712
713 // write additional monitoring histograms to main target file (not the weight file)
714 // again, make sure the histograms go into the method's subdirectory
715 if(!IsSilentFile())
716 {
717 BaseDir()->cd();
718 WriteMonitoringHistosToFile();
719 }
720}
721
722////////////////////////////////////////////////////////////////////////////////
723
725{
726 if (!DoRegression()) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Trying to use GetRegressionDeviation() with a classification job" << Endl;
727 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
728 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), Types::kTesting, Types::kRegression);
729 bool truncate = false;
730 TH1F* h1 = regRes->QuadraticDeviation( tgtNum , truncate, 1.);
731 stddev = sqrt(h1->GetMean());
732 truncate = true;
733 Double_t yq[1], xq[]={0.9};
734 h1->GetQuantiles(1,yq,xq);
735 TH1F* h2 = regRes->QuadraticDeviation( tgtNum , truncate, yq[0]);
736 stddev90Percent = sqrt(h2->GetMean());
737 delete h1;
738 delete h2;
739}
740
741////////////////////////////////////////////////////////////////////////////////
742/// prepare tree branch with the method's discriminating variable
743
745{
746 Data()->SetCurrentType(type);
747
748 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
749
750 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), type, Types::kRegression);
751
752 Long64_t nEvents = Data()->GetNEvents();
753
754 // use timer
755 Timer timer( nEvents, GetName(), kTRUE );
756 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
757 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
758
759 regRes->Resize( nEvents );
760
761 // Drawing the progress bar every event was causing a huge slowdown in the evaluation time
762 // So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
763
764 Int_t totalProgressDraws = 100; // total number of times to update the progress bar
765 Int_t drawProgressEvery = 1; // draw every nth event such that we have a total of totalProgressDraws
766 if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
767
768 for (Int_t ievt=0; ievt<nEvents; ievt++) {
769
770 Data()->SetCurrentEvent(ievt);
771 std::vector< Float_t > vals = GetRegressionValues();
772 regRes->SetValue( vals, ievt );
773
774 // Only draw the progress bar once in a while, doing this every event causes the evaluation to be ridiculously slow
775 if(ievt % drawProgressEvery == 0 || ievt==nEvents-1) timer.DrawProgressBar( ievt );
776 }
777
778 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
779 << "Elapsed time for evaluation of " << nEvents << " events: "
780 << timer.GetElapsedTime() << " " << Endl;
781
782 // store time used for testing
784 SetTestTime(timer.ElapsedSeconds());
785
786 TString histNamePrefix(GetTestvarName());
787 histNamePrefix += (type==Types::kTraining?"train":"test");
788 regRes->CreateDeviationHistograms( histNamePrefix );
789}
790
791////////////////////////////////////////////////////////////////////////////////
792/// prepare tree branch with the method's discriminating variable
793
795{
796 Data()->SetCurrentType(type);
797
798 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
799
800 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
801 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in AddMulticlassOutput, exiting."<<Endl;
802
803 Long64_t nEvents = Data()->GetNEvents();
804
805 // use timer
806 Timer timer( nEvents, GetName(), kTRUE );
807
808 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Multiclass evaluation of " << GetMethodName() << " on "
809 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
810
811 resMulticlass->Resize( nEvents );
812 Int_t modulo = Int_t(nEvents/100) + 1;
813 for (Int_t ievt=0; ievt<nEvents; ievt++) {
814 Data()->SetCurrentEvent(ievt);
815 std::vector< Float_t > vals = GetMulticlassValues();
816 resMulticlass->SetValue( vals, ievt );
817 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
818 }
819
820 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
821 << "Elapsed time for evaluation of " << nEvents << " events: "
822 << timer.GetElapsedTime() << " " << Endl;
823
824 // store time used for testing
826 SetTestTime(timer.ElapsedSeconds());
827
828 TString histNamePrefix(GetTestvarName());
829 histNamePrefix += (type==Types::kTraining?"_Train":"_Test");
830
831 resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbinsMVAoutput, fNbinsH );
832 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefix);
833}
834
835////////////////////////////////////////////////////////////////////////////////
836
837void TMVA::MethodBase::NoErrorCalc(Double_t* const err, Double_t* const errUpper) {
838 if (err) *err=-1;
839 if (errUpper) *errUpper=-1;
840}
841
842////////////////////////////////////////////////////////////////////////////////
843
844Double_t TMVA::MethodBase::GetMvaValue( const Event* const ev, Double_t* err, Double_t* errUpper ) {
845 fTmpEvent = ev;
846 Double_t val = GetMvaValue(err, errUpper);
847 fTmpEvent = 0;
848 return val;
849}
850
851////////////////////////////////////////////////////////////////////////////////
852/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
853/// for a quick determination if an event would be selected as signal or background
854
856 return GetMvaValue()*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
857}
858////////////////////////////////////////////////////////////////////////////////
859/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
860/// for a quick determination if an event with this mva output value would be selected as signal or background
861
863 return mvaVal*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
864}
865
866////////////////////////////////////////////////////////////////////////////////
867/// prepare tree branch with the method's discriminating variable
868
870{
871 Data()->SetCurrentType(type);
872
873 ResultsClassification* clRes =
874 (ResultsClassification*)Data()->GetResults(GetMethodName(), type, Types::kClassification );
875
876 Long64_t nEvents = Data()->GetNEvents();
877 clRes->Resize( nEvents );
878
879 // use timer
880 Timer timer( nEvents, GetName(), kTRUE );
881 std::vector<Double_t> mvaValues = GetMvaValues(0, nEvents, true);
882
883 // store time used for testing
885 SetTestTime(timer.ElapsedSeconds());
886
887 // load mva values and type to results object
888 for (Int_t ievt = 0; ievt < nEvents; ievt++) {
889 // note we do not need the transformed event to get the signal/background information
890 // by calling Data()->GetEvent instead of this->GetEvent we access the untransformed one
891 auto ev = Data()->GetEvent(ievt);
892 clRes->SetValue(mvaValues[ievt], ievt, DataInfo().IsSignal(ev));
893 }
894}
895
896////////////////////////////////////////////////////////////////////////////////
897/// get all the MVA values for the events of the current Data type
898std::vector<Double_t> TMVA::MethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
899{
900
901 Long64_t nEvents = Data()->GetNEvents();
902 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
903 if (firstEvt < 0) firstEvt = 0;
904 std::vector<Double_t> values(lastEvt-firstEvt);
905 // log in case of looping on all the events
906 nEvents = values.size();
907
908 // use timer
909 Timer timer( nEvents, GetName(), kTRUE );
910
911 if (logProgress)
912 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
913 << "Evaluation of " << GetMethodName() << " on "
914 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
915 << " sample (" << nEvents << " events)" << Endl;
916
917 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
918 Data()->SetCurrentEvent(ievt);
919 values[ievt] = GetMvaValue();
920
921 // print progress
922 if (logProgress) {
923 Int_t modulo = Int_t(nEvents/100);
924 if (modulo <= 0 ) modulo = 1;
925 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
926 }
927 }
928 if (logProgress) {
929 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
930 << "Elapsed time for evaluation of " << nEvents << " events: "
931 << timer.GetElapsedTime() << " " << Endl;
932 }
933
934 return values;
935}
936
937////////////////////////////////////////////////////////////////////////////////
938/// get all the MVA values for the events of the given Data type
939// (this is used by Method Category and it does not need to be re-implmented by derived classes )
940std::vector<Double_t> TMVA::MethodBase::GetDataMvaValues(DataSet * data, Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
941{
942 fTmpData = data;
943 auto result = GetMvaValues(firstEvt, lastEvt, logProgress);
944 fTmpData = nullptr;
945 return result;
946}
947
948////////////////////////////////////////////////////////////////////////////////
949/// prepare tree branch with the method's discriminating variable
950
952{
953 Data()->SetCurrentType(type);
954
955 ResultsClassification* mvaProb =
956 (ResultsClassification*)Data()->GetResults(TString("prob_")+GetMethodName(), type, Types::kClassification );
957
958 Long64_t nEvents = Data()->GetNEvents();
959
960 // use timer
961 Timer timer( nEvents, GetName(), kTRUE );
962
963 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
964 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
965
966 mvaProb->Resize( nEvents );
967 Int_t modulo = Int_t(nEvents/100);
968 if (modulo <= 0 ) modulo = 1;
969 for (Int_t ievt=0; ievt<nEvents; ievt++) {
970
971 Data()->SetCurrentEvent(ievt);
972 Float_t proba = ((Float_t)GetProba( GetMvaValue(), 0.5 ));
973 if (proba < 0) break;
974 mvaProb->SetValue( proba, ievt, DataInfo().IsSignal( Data()->GetEvent()) );
975
976 // print progress
977 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
978 }
979
980 Log() << kDEBUG <<Form("Dataset[%s] : ",DataInfo().GetName())
981 << "Elapsed time for evaluation of " << nEvents << " events: "
982 << timer.GetElapsedTime() << " " << Endl;
983}
984
985////////////////////////////////////////////////////////////////////////////////
986/// calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
987///
988/// - bias = average deviation
989/// - dev = average absolute deviation
990/// - rms = rms of deviation
991
993 Double_t& dev, Double_t& devT,
994 Double_t& rms, Double_t& rmsT,
995 Double_t& mInf, Double_t& mInfT,
996 Double_t& corr,
998{
999 Types::ETreeType savedType = Data()->GetCurrentType();
1000 Data()->SetCurrentType(type);
1001
1002 bias = 0; biasT = 0; dev = 0; devT = 0; rms = 0; rmsT = 0;
1003 Double_t sumw = 0;
1004 Double_t m1 = 0, m2 = 0, s1 = 0, s2 = 0, s12 = 0; // for correlation
1005 const Int_t nevt = GetNEvents();
1006 Float_t* rV = new Float_t[nevt];
1007 Float_t* tV = new Float_t[nevt];
1008 Float_t* wV = new Float_t[nevt];
1009 Float_t xmin = 1e30, xmax = -1e30;
1010 Log() << kINFO << "Calculate regression for all events" << Endl;
1011 Timer timer( nevt, GetName(), kTRUE );
1012 Long64_t modulo = Long64_t(nevt / 100) + 1;
1013 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1014
1015 const Event* ev = Data()->GetEvent(ievt); // NOTE: need untransformed event here !
1016 Float_t t = ev->GetTarget(0);
1017 Float_t w = ev->GetWeight();
1018 Float_t r = GetRegressionValues()[0];
1019 Float_t d = (r-t);
1020
1021 // find min/max
1024
1025 // store for truncated RMS computation
1026 rV[ievt] = r;
1027 tV[ievt] = t;
1028 wV[ievt] = w;
1029
1030 // compute deviation-squared
1031 sumw += w;
1032 bias += w * d;
1033 dev += w * TMath::Abs(d);
1034 rms += w * d * d;
1035
1036 // compute correlation between target and regression estimate
1037 m1 += t*w; s1 += t*t*w;
1038 m2 += r*w; s2 += r*r*w;
1039 s12 += t*r;
1040 // print progress
1041 if (ievt % modulo == 0)
1042 timer.DrawProgressBar(ievt);
1043 }
1044 timer.DrawProgressBar(nevt - 1);
1045 Log() << kINFO << "Elapsed time for evaluation of " << nevt << " events: "
1046 << timer.GetElapsedTime() << " " << Endl;
1047
1048 // standard quantities
1049 bias /= sumw;
1050 dev /= sumw;
1051 rms /= sumw;
1052 rms = TMath::Sqrt(rms - bias*bias);
1053
1054 // correlation
1055 m1 /= sumw;
1056 m2 /= sumw;
1057 corr = s12/sumw - m1*m2;
1058 corr /= TMath::Sqrt( (s1/sumw - m1*m1) * (s2/sumw - m2*m2) );
1059
1060 // create histogram required for computation of mutual information
1061 TH2F* hist = new TH2F( "hist", "hist", 150, xmin, xmax, 100, xmin, xmax );
1062 TH2F* histT = new TH2F( "histT", "histT", 150, xmin, xmax, 100, xmin, xmax );
1063
1064 // compute truncated RMS and fill histogram
1065 Double_t devMax = bias + 2*rms;
1066 Double_t devMin = bias - 2*rms;
1067 sumw = 0;
1068 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1069 Float_t d = (rV[ievt] - tV[ievt]);
1070 hist->Fill( rV[ievt], tV[ievt], wV[ievt] );
1071 if (d >= devMin && d <= devMax) {
1072 sumw += wV[ievt];
1073 biasT += wV[ievt] * d;
1074 devT += wV[ievt] * TMath::Abs(d);
1075 rmsT += wV[ievt] * d * d;
1076 histT->Fill( rV[ievt], tV[ievt], wV[ievt] );
1077 }
1078 }
1079 biasT /= sumw;
1080 devT /= sumw;
1081 rmsT /= sumw;
1082 rmsT = TMath::Sqrt(rmsT - biasT*biasT);
1083 mInf = gTools().GetMutualInformation( *hist );
1084 mInfT = gTools().GetMutualInformation( *histT );
1085
1086 delete hist;
1087 delete histT;
1088
1089 delete [] rV;
1090 delete [] tV;
1091 delete [] wV;
1092
1093 Data()->SetCurrentType(savedType);
1094}
1095
1096
1097////////////////////////////////////////////////////////////////////////////////
1098/// test multiclass classification
1099
1101{
1102 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
1103 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in TestMulticlass, exiting."<<Endl;
1104
1105 // GA evaluation of best cut for sig eff * sig pur. Slow, disabled for now.
1106 // Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for test
1107 // data..." << Endl; for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
1108 // resMulticlass->GetBestMultiClassCuts(icls);
1109 // }
1110
1111 // Create histograms for use in TMVA GUI
1112 TString histNamePrefix(GetTestvarName());
1113 TString histNamePrefixTest{histNamePrefix + "_Test"};
1114 TString histNamePrefixTrain{histNamePrefix + "_Train"};
1115
1116 resMulticlass->CreateMulticlassHistos(histNamePrefixTest, fNbinsMVAoutput, fNbinsH);
1117 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTest);
1118
1119 resMulticlass->CreateMulticlassHistos(histNamePrefixTrain, fNbinsMVAoutput, fNbinsH);
1120 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTrain);
1121}
1122
1123
1124////////////////////////////////////////////////////////////////////////////////
1125/// initialization
1126
1128{
1129 Data()->SetCurrentType(Types::kTesting);
1130
1131 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
1132 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
1133
1134 // sanity checks: tree must exist, and theVar must be in tree
1135 if (0==mvaRes && !(GetMethodTypeName().Contains("Cuts"))) {
1136 Log()<<Form("Dataset[%s] : ",DataInfo().GetName()) << "mvaRes " << mvaRes << " GetMethodTypeName " << GetMethodTypeName()
1137 << " contains " << !(GetMethodTypeName().Contains("Cuts")) << Endl;
1138 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<TestInit> Test variable " << GetTestvarName()
1139 << " not found in tree" << Endl;
1140 }
1141
1142 // basic statistics operations are made in base class
1143 gTools().ComputeStat( GetEventCollection(Types::kTesting), mvaRes->GetValueVector(),
1144 fMeanS, fMeanB, fRmsS, fRmsB, fXmin, fXmax, fSignalClass );
1145
1146 // choose reasonable histogram ranges, by removing outliers
1147 Double_t nrms = 10;
1148 fXmin = TMath::Max( TMath::Min( fMeanS - nrms*fRmsS, fMeanB - nrms*fRmsB ), fXmin );
1149 fXmax = TMath::Min( TMath::Max( fMeanS + nrms*fRmsS, fMeanB + nrms*fRmsB ), fXmax );
1150
1151 // determine cut orientation
1152 fCutOrientation = (fMeanS > fMeanB) ? kPositive : kNegative;
1153
1154 // fill 2 types of histograms for the various analyses
1155 // this one is for actual plotting
1156
1157 Double_t sxmax = fXmax+0.00001;
1158
1159 // classifier response distributions for training sample
1160 // MVA plots used for graphics representation (signal)
1161 TString TestvarName;
1162 if(IsSilentFile()) {
1163 TestvarName = TString::Format("[%s]%s",DataInfo().GetName(),GetTestvarName().Data());
1164 } else {
1165 TestvarName=GetTestvarName();
1166 }
1167 TH1* mva_s = new TH1D( TestvarName + "_S",TestvarName + "_S", fNbinsMVAoutput, fXmin, sxmax );
1168 TH1* mva_b = new TH1D( TestvarName + "_B",TestvarName + "_B", fNbinsMVAoutput, fXmin, sxmax );
1169 mvaRes->Store(mva_s, "MVA_S");
1170 mvaRes->Store(mva_b, "MVA_B");
1171 mva_s->Sumw2();
1172 mva_b->Sumw2();
1173
1174 TH1* proba_s = 0;
1175 TH1* proba_b = 0;
1176 TH1* rarity_s = 0;
1177 TH1* rarity_b = 0;
1178 if (HasMVAPdfs()) {
1179 // P(MVA) plots used for graphics representation
1180 proba_s = new TH1D( TestvarName + "_Proba_S", TestvarName + "_Proba_S", fNbinsMVAoutput, 0.0, 1.0 );
1181 proba_b = new TH1D( TestvarName + "_Proba_B", TestvarName + "_Proba_B", fNbinsMVAoutput, 0.0, 1.0 );
1182 mvaRes->Store(proba_s, "Prob_S");
1183 mvaRes->Store(proba_b, "Prob_B");
1184 proba_s->Sumw2();
1185 proba_b->Sumw2();
1186
1187 // R(MVA) plots used for graphics representation
1188 rarity_s = new TH1D( TestvarName + "_Rarity_S", TestvarName + "_Rarity_S", fNbinsMVAoutput, 0.0, 1.0 );
1189 rarity_b = new TH1D( TestvarName + "_Rarity_B", TestvarName + "_Rarity_B", fNbinsMVAoutput, 0.0, 1.0 );
1190 mvaRes->Store(rarity_s, "Rar_S");
1191 mvaRes->Store(rarity_b, "Rar_B");
1192 rarity_s->Sumw2();
1193 rarity_b->Sumw2();
1194 }
1195
1196 // MVA plots used for efficiency calculations (large number of bins)
1197 TH1* mva_eff_s = new TH1D( TestvarName + "_S_high", TestvarName + "_S_high", fNbinsH, fXmin, sxmax );
1198 TH1* mva_eff_b = new TH1D( TestvarName + "_B_high", TestvarName + "_B_high", fNbinsH, fXmin, sxmax );
1199 mvaRes->Store(mva_eff_s, "MVA_HIGHBIN_S");
1200 mvaRes->Store(mva_eff_b, "MVA_HIGHBIN_B");
1201 mva_eff_s->Sumw2();
1202 mva_eff_b->Sumw2();
1203
1204 // fill the histograms
1205
1206 ResultsClassification* mvaProb = dynamic_cast<ResultsClassification*>
1207 (Data()->GetResults( TString("prob_")+GetMethodName(), Types::kTesting, Types::kMaxAnalysisType ) );
1208
1209 Log() << kHEADER <<Form("[%s] : ",DataInfo().GetName())<< "Loop over test events and fill histograms with classifier response..." << Endl << Endl;
1210 if (mvaProb) Log() << kINFO << "Also filling probability and rarity histograms (on request)..." << Endl;
1211 //std::vector<Bool_t>* mvaResTypes = mvaRes->GetValueVectorTypes();
1212
1213 //LM: this is needed to avoid crashes in ROOCCURVE
1214 if ( mvaRes->GetSize() != GetNEvents() ) {
1215 Log() << kFATAL << TString::Format("Inconsistent result size %lld with number of events %u ", mvaRes->GetSize() , GetNEvents() ) << Endl;
1216 assert(mvaRes->GetSize() == GetNEvents());
1217 }
1218
1219 for (Long64_t ievt=0; ievt<GetNEvents(); ievt++) {
1220
1221 const Event* ev = GetEvent(ievt);
1222 Float_t v = (*mvaRes)[ievt][0];
1223 Float_t w = ev->GetWeight();
1224
1225 if (DataInfo().IsSignal(ev)) {
1226 //mvaResTypes->push_back(kTRUE);
1227 mva_s ->Fill( v, w );
1228 if (mvaProb) {
1229 proba_s->Fill( (*mvaProb)[ievt][0], w );
1230 rarity_s->Fill( GetRarity( v ), w );
1231 }
1232
1233 mva_eff_s ->Fill( v, w );
1234 }
1235 else {
1236 //mvaResTypes->push_back(kFALSE);
1237 mva_b ->Fill( v, w );
1238 if (mvaProb) {
1239 proba_b->Fill( (*mvaProb)[ievt][0], w );
1240 rarity_b->Fill( GetRarity( v ), w );
1241 }
1242 mva_eff_b ->Fill( v, w );
1243 }
1244 }
1245
1246 // uncomment those (and several others if you want unnormalized output
1247 gTools().NormHist( mva_s );
1248 gTools().NormHist( mva_b );
1249 gTools().NormHist( proba_s );
1250 gTools().NormHist( proba_b );
1251 gTools().NormHist( rarity_s );
1252 gTools().NormHist( rarity_b );
1253 gTools().NormHist( mva_eff_s );
1254 gTools().NormHist( mva_eff_b );
1255
1256 // create PDFs from histograms, using default splines, and no additional smoothing
1257 if (fSplS) { delete fSplS; fSplS = 0; }
1258 if (fSplB) { delete fSplB; fSplB = 0; }
1259 fSplS = new PDF( TString(GetName()) + " PDF Sig", mva_s, PDF::kSpline2 );
1260 fSplB = new PDF( TString(GetName()) + " PDF Bkg", mva_b, PDF::kSpline2 );
1261}
1262
1263////////////////////////////////////////////////////////////////////////////////
1264/// general method used in writing the header of the weight files where
1265/// the used variables, variable transformation type etc. is specified
1266
1267void TMVA::MethodBase::WriteStateToStream( std::ostream& tf ) const
1268{
1269 TString prefix = "";
1270 UserGroup_t * userInfo = gSystem->GetUserInfo();
1271
1272 tf << prefix << "#GEN -*-*-*-*-*-*-*-*-*-*-*- general info -*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1273 tf << prefix << "Method : " << GetMethodTypeName() << "::" << GetMethodName() << std::endl;
1274 tf.setf(std::ios::left);
1275 tf << prefix << "TMVA Release : " << std::setw(10) << GetTrainingTMVAVersionString() << " ["
1276 << GetTrainingTMVAVersionCode() << "]" << std::endl;
1277 tf << prefix << "ROOT Release : " << std::setw(10) << GetTrainingROOTVersionString() << " ["
1278 << GetTrainingROOTVersionCode() << "]" << std::endl;
1279 tf << prefix << "Creator : " << userInfo->fUser << std::endl;
1280 tf << prefix << "Date : "; TDatime *d = new TDatime; tf << d->AsString() << std::endl; delete d;
1281 tf << prefix << "Host : " << gSystem->GetBuildNode() << std::endl;
1282 tf << prefix << "Dir : " << gSystem->WorkingDirectory() << std::endl;
1283 tf << prefix << "Training events: " << Data()->GetNTrainingEvents() << std::endl;
1284
1285 TString analysisType(((const_cast<TMVA::MethodBase*>(this)->GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification"));
1286
1287 tf << prefix << "Analysis type : " << "[" << ((GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification") << "]" << std::endl;
1288 tf << prefix << std::endl;
1289
1290 delete userInfo;
1291
1292 // First write all options
1293 tf << prefix << std::endl << prefix << "#OPT -*-*-*-*-*-*-*-*-*-*-*-*- options -*-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1294 WriteOptionsToStream( tf, prefix );
1295 tf << prefix << std::endl;
1296
1297 // Second write variable info
1298 tf << prefix << std::endl << prefix << "#VAR -*-*-*-*-*-*-*-*-*-*-*-* variables *-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1299 WriteVarsToStream( tf, prefix );
1300 tf << prefix << std::endl;
1301}
1302
1303////////////////////////////////////////////////////////////////////////////////
1304/// xml writing
1305
1306void TMVA::MethodBase::AddInfoItem( void* gi, const TString& name, const TString& value) const
1307{
1308 void* it = gTools().AddChild(gi,"Info");
1309 gTools().AddAttr(it,"name", name);
1310 gTools().AddAttr(it,"value", value);
1311}
1312
1313////////////////////////////////////////////////////////////////////////////////
1314
1316 if (analysisType == Types::kRegression) {
1317 AddRegressionOutput( type );
1318 } else if (analysisType == Types::kMulticlass) {
1319 AddMulticlassOutput( type );
1320 } else {
1321 AddClassifierOutput( type );
1322 if (HasMVAPdfs())
1323 AddClassifierOutputProb( type );
1324 }
1325}
1326
1327////////////////////////////////////////////////////////////////////////////////
1328/// general method used in writing the header of the weight files where
1329/// the used variables, variable transformation type etc. is specified
1330
1331void TMVA::MethodBase::WriteStateToXML( void* parent ) const
1332{
1333 if (!parent) return;
1334
1335 UserGroup_t* userInfo = gSystem->GetUserInfo();
1336
1337 void* gi = gTools().AddChild(parent, "GeneralInfo");
1338 AddInfoItem( gi, "TMVA Release", GetTrainingTMVAVersionString() + " [" + gTools().StringFromInt(GetTrainingTMVAVersionCode()) + "]" );
1339 AddInfoItem( gi, "ROOT Release", GetTrainingROOTVersionString() + " [" + gTools().StringFromInt(GetTrainingROOTVersionCode()) + "]");
1340 AddInfoItem( gi, "Creator", userInfo->fUser);
1341 TDatime dt; AddInfoItem( gi, "Date", dt.AsString());
1342 AddInfoItem( gi, "Host", gSystem->GetBuildNode() );
1343 AddInfoItem( gi, "Dir", gSystem->WorkingDirectory());
1344 AddInfoItem( gi, "Training events", gTools().StringFromInt(Data()->GetNTrainingEvents()));
1345 AddInfoItem( gi, "TrainingTime", gTools().StringFromDouble(const_cast<TMVA::MethodBase*>(this)->GetTrainTime()));
1346
1347 Types::EAnalysisType aType = const_cast<TMVA::MethodBase*>(this)->GetAnalysisType();
1348 TString analysisType((aType==Types::kRegression) ? "Regression" :
1349 (aType==Types::kMulticlass ? "Multiclass" : "Classification"));
1350 AddInfoItem( gi, "AnalysisType", analysisType );
1351 delete userInfo;
1352
1353 // write options
1354 AddOptionsXMLTo( parent );
1355
1356 // write variable info
1357 AddVarsXMLTo( parent );
1358
1359 // write spectator info
1360 if (fModelPersistence)
1361 AddSpectatorsXMLTo( parent );
1362
1363 // write class info if in multiclass mode
1364 AddClassesXMLTo(parent);
1365
1366 // write target info if in regression mode
1367 if (DoRegression()) AddTargetsXMLTo(parent);
1368
1369 // write transformations
1370 GetTransformationHandler(false).AddXMLTo( parent );
1371
1372 // write MVA variable distributions
1373 void* pdfs = gTools().AddChild(parent, "MVAPdfs");
1374 if (fMVAPdfS) fMVAPdfS->AddXMLTo(pdfs);
1375 if (fMVAPdfB) fMVAPdfB->AddXMLTo(pdfs);
1376
1377 // write weights
1378 AddWeightsXMLTo( parent );
1379}
1380
1381////////////////////////////////////////////////////////////////////////////////
1382/// read reference MVA distributions (and other information)
1383/// from a ROOT type weight file
1384
1386{
1387 Bool_t addDirStatus = TH1::AddDirectoryStatus();
1388 TH1::AddDirectory( 0 ); // this avoids the binding of the hists in PDF to the current ROOT file
1389 fMVAPdfS = (TMVA::PDF*)rf.Get( "MVA_PDF_Signal" );
1390 fMVAPdfB = (TMVA::PDF*)rf.Get( "MVA_PDF_Background" );
1391
1392 TH1::AddDirectory( addDirStatus );
1393
1394 ReadWeightsFromStream( rf );
1395
1396 SetTestvarName();
1397}
1398
1399////////////////////////////////////////////////////////////////////////////////
1400/// write options and weights to file
1401/// note that each one text file for the main configuration information
1402/// and one ROOT file for ROOT objects are created
1403
1405{
1406 // ---- create the text file
1407 TString tfname( GetWeightFileName() );
1408
1409 // writing xml file
1410 TString xmlfname( tfname ); xmlfname.ReplaceAll( ".txt", ".xml" );
1411 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1412 << "Creating xml weight file: "
1413 << gTools().Color("lightblue") << xmlfname << gTools().Color("reset") << Endl;
1414 void* doc = gTools().xmlengine().NewDoc();
1415 void* rootnode = gTools().AddChild(0,"MethodSetup", "", true);
1416 gTools().xmlengine().DocSetRootElement(doc,rootnode);
1417 gTools().AddAttr(rootnode,"Method", GetMethodTypeName() + "::" + GetMethodName());
1418 WriteStateToXML(rootnode);
1419 gTools().xmlengine().SaveDoc(doc,xmlfname);
1420 gTools().xmlengine().FreeDoc(doc);
1421}
1422
1423////////////////////////////////////////////////////////////////////////////////
1424/// Function to read options and weights from file
1425
1427{
1428 // get the filename
1429
1430 TString tfname(GetWeightFileName());
1431
1432 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1433 << "Reading weight file: "
1434 << gTools().Color("lightblue") << tfname << gTools().Color("reset") << Endl;
1435
1436 if (tfname.EndsWith(".xml") ) {
1437 void* doc = gTools().xmlengine().ParseFile(tfname,gTools().xmlenginebuffersize()); // the default buffer size in TXMLEngine::ParseFile is 100k. Starting with ROOT 5.29 one can set the buffer size, see: http://savannah.cern.ch/bugs/?78864. This might be necessary for large XML files
1438 if (!doc) {
1439 Log() << kFATAL << "Error parsing XML file " << tfname << Endl;
1440 }
1441 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1442 ReadStateFromXML(rootnode);
1443 gTools().xmlengine().FreeDoc(doc);
1444 }
1445 else {
1446 std::filebuf fb;
1447 fb.open(tfname.Data(),std::ios::in);
1448 if (!fb.is_open()) { // file not found --> Error
1449 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ReadStateFromFile> "
1450 << "Unable to open input weight file: " << tfname << Endl;
1451 }
1452 std::istream fin(&fb);
1453 ReadStateFromStream(fin);
1454 fb.close();
1455 }
1456 if (!fTxtWeightsOnly) {
1457 // ---- read the ROOT file
1458 TString rfname( tfname ); rfname.ReplaceAll( ".txt", ".root" );
1459 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Reading root weight file: "
1460 << gTools().Color("lightblue") << rfname << gTools().Color("reset") << Endl;
1461 TFile* rfile = TFile::Open( rfname, "READ" );
1462 ReadStateFromStream( *rfile );
1463 rfile->Close();
1464 }
1465}
1466////////////////////////////////////////////////////////////////////////////////
1467/// for reading from memory
1468
1470 void* doc = gTools().xmlengine().ParseString(xmlstr);
1471 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1472 ReadStateFromXML(rootnode);
1473 gTools().xmlengine().FreeDoc(doc);
1474
1475 return;
1476}
1477
1478////////////////////////////////////////////////////////////////////////////////
1479
1481{
1482
1483 TString fullMethodName;
1484 gTools().ReadAttr( methodNode, "Method", fullMethodName );
1485
1486 fMethodName = fullMethodName(fullMethodName.Index("::")+2,fullMethodName.Length());
1487
1488 // update logger
1489 Log().SetSource( GetName() );
1490 Log() << kDEBUG//<<Form("Dataset[%s] : ",DataInfo().GetName())
1491 << "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1492
1493 // after the method name is read, the testvar can be set
1494 SetTestvarName();
1495
1496 TString nodeName("");
1497 void* ch = gTools().GetChild(methodNode);
1498 while (ch!=0) {
1499 nodeName = TString( gTools().GetName(ch) );
1500
1501 if (nodeName=="GeneralInfo") {
1502 // read analysis type
1503
1504 TString name(""),val("");
1505 void* antypeNode = gTools().GetChild(ch);
1506 while (antypeNode) {
1507 gTools().ReadAttr( antypeNode, "name", name );
1508
1509 if (name == "TrainingTime")
1510 gTools().ReadAttr( antypeNode, "value", fTrainTime );
1511
1512 if (name == "AnalysisType") {
1513 gTools().ReadAttr( antypeNode, "value", val );
1514 val.ToLower();
1515 if (val == "regression" ) SetAnalysisType( Types::kRegression );
1516 else if (val == "classification" ) SetAnalysisType( Types::kClassification );
1517 else if (val == "multiclass" ) SetAnalysisType( Types::kMulticlass );
1518 else Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Analysis type " << val << " is not known." << Endl;
1519 }
1520
1521 if (name == "TMVA Release" || name == "TMVA") {
1522 TString s;
1523 gTools().ReadAttr( antypeNode, "value", s);
1524 fTMVATrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1525 Log() << kDEBUG <<Form("[%s] : ",DataInfo().GetName()) << "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
1526 }
1527
1528 if (name == "ROOT Release" || name == "ROOT") {
1529 TString s;
1530 gTools().ReadAttr( antypeNode, "value", s);
1531 fROOTTrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1532 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
1533 << "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
1534 }
1535 antypeNode = gTools().GetNextChild(antypeNode);
1536 }
1537 }
1538 else if (nodeName=="Options") {
1539 ReadOptionsFromXML(ch);
1540 ParseOptions();
1541
1542 }
1543 else if (nodeName=="Variables") {
1544 ReadVariablesFromXML(ch);
1545 }
1546 else if (nodeName=="Spectators") {
1547 ReadSpectatorsFromXML(ch);
1548 }
1549 else if (nodeName=="Classes") {
1550 if (DataInfo().GetNClasses()==0) ReadClassesFromXML(ch);
1551 }
1552 else if (nodeName=="Targets") {
1553 if (DataInfo().GetNTargets()==0 && DoRegression()) ReadTargetsFromXML(ch);
1554 }
1555 else if (nodeName=="Transformations") {
1556 GetTransformationHandler().ReadFromXML(ch);
1557 }
1558 else if (nodeName=="MVAPdfs") {
1559 TString pdfname;
1560 if (fMVAPdfS) { delete fMVAPdfS; fMVAPdfS=0; }
1561 if (fMVAPdfB) { delete fMVAPdfB; fMVAPdfB=0; }
1562 void* pdfnode = gTools().GetChild(ch);
1563 if (pdfnode) {
1564 gTools().ReadAttr(pdfnode, "Name", pdfname);
1565 fMVAPdfS = new PDF(pdfname);
1566 fMVAPdfS->ReadXML(pdfnode);
1567 pdfnode = gTools().GetNextChild(pdfnode);
1568 gTools().ReadAttr(pdfnode, "Name", pdfname);
1569 fMVAPdfB = new PDF(pdfname);
1570 fMVAPdfB->ReadXML(pdfnode);
1571 }
1572 }
1573 else if (nodeName=="Weights") {
1574 ReadWeightsFromXML(ch);
1575 }
1576 else {
1577 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Unparsed XML node: '" << nodeName << "'" << Endl;
1578 }
1579 ch = gTools().GetNextChild(ch);
1580
1581 }
1582
1583 // update transformation handler
1584 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1585}
1586
1587////////////////////////////////////////////////////////////////////////////////
1588/// read the header from the weight files of the different MVA methods
1589
1591{
1592 char buf[512];
1593
1594 // when reading from stream, we assume the files are produced with TMVA<=397
1595 SetAnalysisType(Types::kClassification);
1596
1597
1598 // first read the method name
1599 GetLine(fin,buf);
1600 while (!TString(buf).BeginsWith("Method")) GetLine(fin,buf);
1601 TString namestr(buf);
1602
1603 TString methodType = namestr(0,namestr.Index("::"));
1604 methodType = methodType(methodType.Last(' '),methodType.Length());
1605 methodType = methodType.Strip(TString::kLeading);
1606
1607 TString methodName = namestr(namestr.Index("::")+2,namestr.Length());
1608 methodName = methodName.Strip(TString::kLeading);
1609 if (methodName == "") methodName = methodType;
1610 fMethodName = methodName;
1611
1612 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1613
1614 // update logger
1615 Log().SetSource( GetName() );
1616
1617 // now the question is whether to read the variables first or the options (well, of course the order
1618 // of writing them needs to agree)
1619 //
1620 // the option "Decorrelation" is needed to decide if the variables we
1621 // read are decorrelated or not
1622 //
1623 // the variables are needed by some methods (TMLP) to build the NN
1624 // which is done in ProcessOptions so for the time being we first Read and Parse the options then
1625 // we read the variables, and then we process the options
1626
1627 // now read all options
1628 GetLine(fin,buf);
1629 while (!TString(buf).BeginsWith("#OPT")) GetLine(fin,buf);
1630 ReadOptionsFromStream(fin);
1631 ParseOptions();
1632
1633 // Now read variable info
1634 fin.getline(buf,512);
1635 while (!TString(buf).BeginsWith("#VAR")) fin.getline(buf,512);
1636 ReadVarsFromStream(fin);
1637
1638 // now we process the options (of the derived class)
1639 ProcessOptions();
1640
1641 if (IsNormalised()) {
1643 GetTransformationHandler().AddTransformation( new VariableNormalizeTransform(DataInfo()), -1 );
1644 norm->BuildTransformationFromVarInfo( DataInfo().GetVariableInfos() );
1645 }
1646 VariableTransformBase *varTrafo(0), *varTrafo2(0);
1647 if ( fVarTransformString == "None") {
1648 if (fUseDecorr)
1649 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1650 } else if ( fVarTransformString == "Decorrelate" ) {
1651 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1652 } else if ( fVarTransformString == "PCA" ) {
1653 varTrafo = GetTransformationHandler().AddTransformation( new VariablePCATransform(DataInfo()), -1 );
1654 } else if ( fVarTransformString == "Uniform" ) {
1655 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo(),"Uniform"), -1 );
1656 } else if ( fVarTransformString == "Gauss" ) {
1657 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1658 } else if ( fVarTransformString == "GaussDecorr" ) {
1659 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1660 varTrafo2 = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1661 } else {
1662 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ProcessOptions> Variable transform '"
1663 << fVarTransformString << "' unknown." << Endl;
1664 }
1665 // Now read decorrelation matrix if available
1666 if (GetTransformationHandler().GetTransformationList().GetSize() > 0) {
1667 fin.getline(buf,512);
1668 while (!TString(buf).BeginsWith("#MAT")) fin.getline(buf,512);
1669 if (varTrafo) {
1670 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1671 varTrafo->ReadTransformationFromStream(fin, trafo );
1672 }
1673 if (varTrafo2) {
1674 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1675 varTrafo2->ReadTransformationFromStream(fin, trafo );
1676 }
1677 }
1678
1679
1680 if (HasMVAPdfs()) {
1681 // Now read the MVA PDFs
1682 fin.getline(buf,512);
1683 while (!TString(buf).BeginsWith("#MVAPDFS")) fin.getline(buf,512);
1684 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
1685 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
1686 fMVAPdfS = new PDF(TString(GetName()) + " MVA PDF Sig");
1687 fMVAPdfB = new PDF(TString(GetName()) + " MVA PDF Bkg");
1688 fMVAPdfS->SetReadingVersion( GetTrainingTMVAVersionCode() );
1689 fMVAPdfB->SetReadingVersion( GetTrainingTMVAVersionCode() );
1690
1691 fin >> *fMVAPdfS;
1692 fin >> *fMVAPdfB;
1693 }
1694
1695 // Now read weights
1696 fin.getline(buf,512);
1697 while (!TString(buf).BeginsWith("#WGT")) fin.getline(buf,512);
1698 fin.getline(buf,512);
1699 ReadWeightsFromStream( fin );;
1700
1701 // update transformation handler
1702 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1703
1704}
1705
1706////////////////////////////////////////////////////////////////////////////////
1707/// write the list of variables (name, min, max) for a given data
1708/// transformation method to the stream
1709
1710void TMVA::MethodBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const
1711{
1712 o << prefix << "NVar " << DataInfo().GetNVariables() << std::endl;
1713 std::vector<VariableInfo>::const_iterator varIt = DataInfo().GetVariableInfos().begin();
1714 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1715 o << prefix << "NSpec " << DataInfo().GetNSpectators() << std::endl;
1716 varIt = DataInfo().GetSpectatorInfos().begin();
1717 for (; varIt!=DataInfo().GetSpectatorInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1718}
1719
1720////////////////////////////////////////////////////////////////////////////////
1721/// Read the variables (name, min, max) for a given data
1722/// transformation method from the stream. In the stream we only
1723/// expect the limits which will be set
1724
1726{
1727 TString dummy;
1728 UInt_t readNVar;
1729 istr >> dummy >> readNVar;
1730
1731 if (readNVar!=DataInfo().GetNVariables()) {
1732 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1733 << " while there are " << readNVar << " variables declared in the file"
1734 << Endl;
1735 }
1736
1737 // we want to make sure all variables are read in the order they are defined
1738 VariableInfo varInfo;
1739 std::vector<VariableInfo>::iterator varIt = DataInfo().GetVariableInfos().begin();
1740 int varIdx = 0;
1741 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt, ++varIdx) {
1742 varInfo.ReadFromStream(istr);
1743 if (varIt->GetExpression() == varInfo.GetExpression()) {
1744 varInfo.SetExternalLink((*varIt).GetExternalLink());
1745 (*varIt) = varInfo;
1746 }
1747 else {
1748 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVarsFromStream>" << Endl;
1749 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1750 Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
1751 Log() << kINFO << "the correct working of the method):" << Endl;
1752 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
1753 Log() << kINFO << " var #" << varIdx <<" declared in file : " << varInfo.GetExpression() << Endl;
1754 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1755 }
1756 }
1757}
1758
1759////////////////////////////////////////////////////////////////////////////////
1760/// write variable info to XML
1761
1762void TMVA::MethodBase::AddVarsXMLTo( void* parent ) const
1763{
1764 void* vars = gTools().AddChild(parent, "Variables");
1765 gTools().AddAttr( vars, "NVar", gTools().StringFromInt(DataInfo().GetNVariables()) );
1766
1767 for (UInt_t idx=0; idx<DataInfo().GetVariableInfos().size(); idx++) {
1768 VariableInfo& vi = DataInfo().GetVariableInfos()[idx];
1769 void* var = gTools().AddChild( vars, "Variable" );
1770 gTools().AddAttr( var, "VarIndex", idx );
1771 vi.AddToXML( var );
1772 }
1773}
1774
1775////////////////////////////////////////////////////////////////////////////////
1776/// write spectator info to XML
1777
1779{
1780 void* specs = gTools().AddChild(parent, "Spectators");
1781
1782 UInt_t writeIdx=0;
1783 for (UInt_t idx=0; idx<DataInfo().GetSpectatorInfos().size(); idx++) {
1784
1785 VariableInfo& vi = DataInfo().GetSpectatorInfos()[idx];
1786
1787 // we do not want to write spectators that are category-cuts,
1788 // except if the method is the category method and the spectators belong to it
1789 if (vi.GetVarType()=='C') continue;
1790
1791 void* spec = gTools().AddChild( specs, "Spectator" );
1792 gTools().AddAttr( spec, "SpecIndex", writeIdx++ );
1793 vi.AddToXML( spec );
1794 }
1795 gTools().AddAttr( specs, "NSpec", gTools().StringFromInt(writeIdx) );
1796}
1797
1798////////////////////////////////////////////////////////////////////////////////
1799/// write class info to XML
1800
1801void TMVA::MethodBase::AddClassesXMLTo( void* parent ) const
1802{
1803 UInt_t nClasses=DataInfo().GetNClasses();
1804
1805 void* classes = gTools().AddChild(parent, "Classes");
1806 gTools().AddAttr( classes, "NClass", nClasses );
1807
1808 for (UInt_t iCls=0; iCls<nClasses; ++iCls) {
1809 ClassInfo *classInfo=DataInfo().GetClassInfo (iCls);
1810 TString className =classInfo->GetName();
1811 UInt_t classNumber=classInfo->GetNumber();
1812
1813 void* classNode=gTools().AddChild(classes, "Class");
1814 gTools().AddAttr( classNode, "Name", className );
1815 gTools().AddAttr( classNode, "Index", classNumber );
1816 }
1817}
1818////////////////////////////////////////////////////////////////////////////////
1819/// write target info to XML
1820
1821void TMVA::MethodBase::AddTargetsXMLTo( void* parent ) const
1822{
1823 void* targets = gTools().AddChild(parent, "Targets");
1824 gTools().AddAttr( targets, "NTrgt", gTools().StringFromInt(DataInfo().GetNTargets()) );
1825
1826 for (UInt_t idx=0; idx<DataInfo().GetTargetInfos().size(); idx++) {
1827 VariableInfo& vi = DataInfo().GetTargetInfos()[idx];
1828 void* tar = gTools().AddChild( targets, "Target" );
1829 gTools().AddAttr( tar, "TargetIndex", idx );
1830 vi.AddToXML( tar );
1831 }
1832}
1833
1834////////////////////////////////////////////////////////////////////////////////
1835/// read variable info from XML
1836
1838{
1839 UInt_t readNVar;
1840 gTools().ReadAttr( varnode, "NVar", readNVar);
1841
1842 if (readNVar!=DataInfo().GetNVariables()) {
1843 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1844 << " while there are " << readNVar << " variables declared in the file"
1845 << Endl;
1846 }
1847
1848 // we want to make sure all variables are read in the order they are defined
1849 VariableInfo readVarInfo, existingVarInfo;
1850 int varIdx = 0;
1851 void* ch = gTools().GetChild(varnode);
1852 while (ch) {
1853 gTools().ReadAttr( ch, "VarIndex", varIdx);
1854 existingVarInfo = DataInfo().GetVariableInfos()[varIdx];
1855 readVarInfo.ReadFromXML(ch);
1856
1857 if (existingVarInfo.GetExpression() == readVarInfo.GetExpression()) {
1858 readVarInfo.SetExternalLink(existingVarInfo.GetExternalLink());
1859 existingVarInfo = readVarInfo;
1860 }
1861 else {
1862 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVariablesFromXML>" << Endl;
1863 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1864 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1865 Log() << kINFO << "correct working of the method):" << Endl;
1866 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << existingVarInfo.GetExpression() << Endl;
1867 Log() << kINFO << " var #" << varIdx <<" declared in file : " << readVarInfo.GetExpression() << Endl;
1868 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1869 }
1870 ch = gTools().GetNextChild(ch);
1871 }
1872}
1873
1874////////////////////////////////////////////////////////////////////////////////
1875/// read spectator info from XML
1876
1878{
1879 UInt_t readNSpec;
1880 gTools().ReadAttr( specnode, "NSpec", readNSpec);
1881
1882 if (readNSpec!=DataInfo().GetNSpectators(kFALSE)) {
1883 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "You declared "<< DataInfo().GetNSpectators(kFALSE) << " spectators in the Reader"
1884 << " while there are " << readNSpec << " spectators declared in the file"
1885 << Endl;
1886 }
1887
1888 // we want to make sure all variables are read in the order they are defined
1889 VariableInfo readSpecInfo, existingSpecInfo;
1890 int specIdx = 0;
1891 void* ch = gTools().GetChild(specnode);
1892 while (ch) {
1893 gTools().ReadAttr( ch, "SpecIndex", specIdx);
1894 existingSpecInfo = DataInfo().GetSpectatorInfos()[specIdx];
1895 readSpecInfo.ReadFromXML(ch);
1896
1897 if (existingSpecInfo.GetExpression() == readSpecInfo.GetExpression()) {
1898 readSpecInfo.SetExternalLink(existingSpecInfo.GetExternalLink());
1899 existingSpecInfo = readSpecInfo;
1900 }
1901 else {
1902 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadSpectatorsFromXML>" << Endl;
1903 Log() << kINFO << "The definition (or the order) of the spectators found in the input file is" << Endl;
1904 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1905 Log() << kINFO << "correct working of the method):" << Endl;
1906 Log() << kINFO << " spec #" << specIdx <<" declared in Reader: " << existingSpecInfo.GetExpression() << Endl;
1907 Log() << kINFO << " spec #" << specIdx <<" declared in file : " << readSpecInfo.GetExpression() << Endl;
1908 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1909 }
1910 ch = gTools().GetNextChild(ch);
1911 }
1912}
1913
1914////////////////////////////////////////////////////////////////////////////////
1915/// read number of classes from XML
1916
1918{
1919 UInt_t readNCls;
1920 // coverity[tainted_data_argument]
1921 gTools().ReadAttr( clsnode, "NClass", readNCls);
1922
1923 TString className="";
1924 UInt_t classIndex=0;
1925 void* ch = gTools().GetChild(clsnode);
1926 if (!ch) {
1927 for (UInt_t icls = 0; icls<readNCls;++icls) {
1928 TString classname = TString::Format("class%i",icls);
1929 DataInfo().AddClass(classname);
1930
1931 }
1932 }
1933 else{
1934 while (ch) {
1935 gTools().ReadAttr( ch, "Index", classIndex);
1936 gTools().ReadAttr( ch, "Name", className );
1937 DataInfo().AddClass(className);
1938
1939 ch = gTools().GetNextChild(ch);
1940 }
1941 }
1942
1943 // retrieve signal and background class index
1944 if (DataInfo().GetClassInfo("Signal") != 0) {
1945 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
1946 }
1947 else
1948 fSignalClass=0;
1949 if (DataInfo().GetClassInfo("Background") != 0) {
1950 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
1951 }
1952 else
1953 fBackgroundClass=1;
1954}
1955
1956////////////////////////////////////////////////////////////////////////////////
1957/// read target info from XML
1958
1960{
1961 UInt_t readNTar;
1962 gTools().ReadAttr( tarnode, "NTrgt", readNTar);
1963
1964 int tarIdx = 0;
1965 TString expression;
1966 void* ch = gTools().GetChild(tarnode);
1967 while (ch) {
1968 gTools().ReadAttr( ch, "TargetIndex", tarIdx);
1969 gTools().ReadAttr( ch, "Expression", expression);
1970 DataInfo().AddTarget(expression,"","",0,0);
1971
1972 ch = gTools().GetNextChild(ch);
1973 }
1974}
1975
1976////////////////////////////////////////////////////////////////////////////////
1977/// returns the ROOT directory where info/histograms etc of the
1978/// corresponding MVA method instance are stored
1979
1981{
1982 if (fBaseDir != 0) return fBaseDir;
1983 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodName() << " not set yet --> check if already there.." <<Endl;
1984
1985 if (IsSilentFile()) {
1986 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
1987 << "MethodBase::BaseDir() - No directory exists when running a Method without output file. Enable the "
1988 "output when creating the factory"
1989 << Endl;
1990 }
1991
1992 TDirectory* methodDir = MethodBaseDir();
1993 if (methodDir==0)
1994 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MethodBase::BaseDir() - MethodBaseDir() return a NULL pointer!" << Endl;
1995
1996 TString defaultDir = GetMethodName();
1997 TDirectory *sdir = methodDir->GetDirectory(defaultDir.Data());
1998 if(!sdir)
1999 {
2000 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " does not exist yet--> created it" <<Endl;
2001 sdir = methodDir->mkdir(defaultDir);
2002 sdir->cd();
2003 // write weight file name into target file
2004 if (fModelPersistence) {
2005 TObjString wfilePath( gSystem->WorkingDirectory() );
2006 TObjString wfileName( GetWeightFileName() );
2007 wfilePath.Write( "TrainingPath" );
2008 wfileName.Write( "WeightFileName" );
2009 }
2010 }
2011
2012 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " existed, return it.." <<Endl;
2013 return sdir;
2014}
2015
2016////////////////////////////////////////////////////////////////////////////////
2017/// returns the ROOT directory where all instances of the
2018/// corresponding MVA method are stored
2019
2021{
2022 if (fMethodBaseDir != 0) {
2023 return fMethodBaseDir;
2024 }
2025
2026 const char *datasetName = DataInfo().GetName();
2027
2028 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodTypeName()
2029 << " not set yet --> check if already there.." << Endl;
2030
2031 TDirectory *factoryBaseDir = GetFile();
2032 if (!factoryBaseDir) return nullptr;
2033 fMethodBaseDir = factoryBaseDir->GetDirectory(datasetName);
2034 if (!fMethodBaseDir) {
2035 fMethodBaseDir = factoryBaseDir->mkdir(datasetName, TString::Format("Base directory for dataset %s", datasetName).Data());
2036 if (!fMethodBaseDir) {
2037 Log() << kFATAL << "Can not create dir " << datasetName;
2038 }
2039 }
2040 TString methodTypeDir = TString::Format("Method_%s", GetMethodTypeName().Data());
2041 fMethodBaseDir = fMethodBaseDir->GetDirectory(methodTypeDir.Data());
2042
2043 if (!fMethodBaseDir) {
2044 TDirectory *datasetDir = factoryBaseDir->GetDirectory(datasetName);
2045 TString methodTypeDirHelpStr = TString::Format("Directory for all %s methods", GetMethodTypeName().Data());
2046 fMethodBaseDir = datasetDir->mkdir(methodTypeDir.Data(), methodTypeDirHelpStr);
2047 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodName()
2048 << " does not exist yet--> created it" << Endl;
2049 }
2050
2051 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName)
2052 << "Return from MethodBaseDir() after creating base directory " << Endl;
2053 return fMethodBaseDir;
2054}
2055
2056////////////////////////////////////////////////////////////////////////////////
2057/// set directory of weight file
2058
2060{
// Remember the requested directory in fFileDir ...
2061 fFileDir = fileDir;
// ... and ask the system layer to create it. The result of mkdir is not checked.
// NOTE(review): kTRUE is presumably the "recursive" flag of TSystem::mkdir - confirm.
2062 gSystem->mkdir( fFileDir, kTRUE );
2063}
2064
2065////////////////////////////////////////////////////////////////////////////////
2066 /// set the weight file name (deprecated)
2067
2069{
// Store the name verbatim in fWeightFile; no validation is done here.
2070 fWeightFile = theWeightFile;
2071}
2072
2073////////////////////////////////////////////////////////////////////////////////
2074/// retrieve weight file name
2075
2077{
2078 if (fWeightFile!="") return fWeightFile;
2079
2080 // the default consists of
2081 // directory/jobname_methodname_suffix.extension.{root/txt}
2082 TString suffix = "";
2083 TString wFileDir(GetWeightFileDir());
2084 TString wFileName = GetJobName() + "_" + GetMethodName() +
2085 suffix + "." + gConfig().GetIONames().fWeightFileExtension + ".xml";
2086 if (wFileDir.IsNull() ) return wFileName;
2087 // add weight file directory of it is not null
2088 return ( wFileDir + (wFileDir[wFileDir.Length()-1]=='/' ? "" : "/")
2089 + wFileName );
2090}
2091////////////////////////////////////////////////////////////////////////////////
2092/// writes all MVA evaluation histograms to file
2093
2095{
2096 BaseDir()->cd();
2097
2098
2099 // write MVA PDFs to file - if exist
2100 if (0 != fMVAPdfS) {
2101 fMVAPdfS->GetOriginalHist()->Write();
2102 fMVAPdfS->GetSmoothedHist()->Write();
2103 fMVAPdfS->GetPDFHist()->Write();
2104 }
2105 if (0 != fMVAPdfB) {
2106 fMVAPdfB->GetOriginalHist()->Write();
2107 fMVAPdfB->GetSmoothedHist()->Write();
2108 fMVAPdfB->GetPDFHist()->Write();
2109 }
2110
2111 // write result-histograms
2112 Results* results = Data()->GetResults( GetMethodName(), treetype, Types::kMaxAnalysisType );
2113 if (!results)
2114 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<WriteEvaluationHistosToFile> Unknown result: "
2115 << GetMethodName() << (treetype==Types::kTraining?"/kTraining":"/kTesting")
2116 << "/kMaxAnalysisType" << Endl;
2117 results->GetStorage()->Write();
2118 if (treetype==Types::kTesting) {
2119 // skipping plotting of variables if too many (default is 200)
2120 if ((int) DataInfo().GetNVariables()< gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables)
2121 GetTransformationHandler().PlotVariables (GetEventCollection( Types::kTesting ), BaseDir() );
2122 else
2123 Log() << kINFO << TString::Format("Dataset[%s] : ",DataInfo().GetName())
2124 << " variable plots are not produces ! The number of variables is " << DataInfo().GetNVariables()
2125 << " , it is larger than " << gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables << Endl;
2126 }
2127}
2128
2129////////////////////////////////////////////////////////////////////////////////
2130/// write special monitoring histograms to file
2131/// dummy implementation here -----------------
2132
2134{
// Intentionally empty: the base class provides only this dummy implementation.
2135}
2136
2137////////////////////////////////////////////////////////////////////////////////
2138/// reads one line from the input stream
2139/// checks for certain keywords and interprets
2140/// the line if keywords are found
2141
2142Bool_t TMVA::MethodBase::GetLine(std::istream& fin, char* buf )
2143{
2144 fin.getline(buf,512);
2145 TString line(buf);
2146 if (line.BeginsWith("TMVA Release")) {
2147 Ssiz_t start = line.First('[')+1;
2148 Ssiz_t length = line.Index("]",start)-start;
2149 TString code = line(start,length);
2150 std::stringstream s(code.Data());
2151 s >> fTMVATrainingVersion;
2152 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
2153 }
2154 if (line.BeginsWith("ROOT Release")) {
2155 Ssiz_t start = line.First('[')+1;
2156 Ssiz_t length = line.Index("]",start)-start;
2157 TString code = line(start,length);
2158 std::stringstream s(code.Data());
2159 s >> fROOTTrainingVersion;
2160 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
2161 }
2162 if (line.BeginsWith("Analysis type")) {
2163 Ssiz_t start = line.First('[')+1;
2164 Ssiz_t length = line.Index("]",start)-start;
2165 TString code = line(start,length);
2166 std::stringstream s(code.Data());
2167 std::string analysisType;
2168 s >> analysisType;
2169 if (analysisType == "regression" || analysisType == "Regression") SetAnalysisType( Types::kRegression );
2170 else if (analysisType == "classification" || analysisType == "Classification") SetAnalysisType( Types::kClassification );
2171 else if (analysisType == "multiclass" || analysisType == "Multiclass") SetAnalysisType( Types::kMulticlass );
2172 else Log() << kFATAL << "Analysis type " << analysisType << " from weight-file not known!" << std::endl;
2173
2174 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Method was trained for "
2175 << (GetAnalysisType() == Types::kRegression ? "Regression" :
2176 (GetAnalysisType() == Types::kMulticlass ? "Multiclass" : "Classification")) << Endl;
2177 }
2178
2179 return true;
2180}
2181
2182////////////////////////////////////////////////////////////////////////////////
2183/// Create PDFs of the MVA output variables
2184
2186{
2187 Data()->SetCurrentType(Types::kTraining);
2188
2189 // the PDF's are stored as results ONLY if the corresponding "results" are booked,
2190 // otherwise they will be only used 'online'
2191 ResultsClassification * mvaRes = dynamic_cast<ResultsClassification*>
2192 ( Data()->GetResults(GetMethodName(), Types::kTraining, Types::kClassification) );
2193
2194 if (mvaRes==0 || mvaRes->GetSize()==0) {
2195 Log() << kERROR<<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CreateMVAPdfs> No result of classifier testing available" << Endl;
2196 }
2197
2198 Double_t minVal = *std::min_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2199 Double_t maxVal = *std::max_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2200
2201 // create histograms that serve as basis to create the MVA Pdfs
2202 TH1* histMVAPdfS = new TH1D( GetMethodTypeName() + "_tr_S", GetMethodTypeName() + "_tr_S",
2203 fMVAPdfS->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2204 TH1* histMVAPdfB = new TH1D( GetMethodTypeName() + "_tr_B", GetMethodTypeName() + "_tr_B",
2205 fMVAPdfB->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2206
2207
2208 // compute sum of weights properly
2209 histMVAPdfS->Sumw2();
2210 histMVAPdfB->Sumw2();
2211
2212 // fill histograms
2213 for (UInt_t ievt=0; ievt<mvaRes->GetSize(); ievt++) {
2214 Double_t theVal = mvaRes->GetValueVector()->at(ievt);
2215 Double_t theWeight = Data()->GetEvent(ievt)->GetWeight();
2216
2217 if (DataInfo().IsSignal(Data()->GetEvent(ievt))) histMVAPdfS->Fill( theVal, theWeight );
2218 else histMVAPdfB->Fill( theVal, theWeight );
2219 }
2220
2221 gTools().NormHist( histMVAPdfS );
2222 gTools().NormHist( histMVAPdfB );
2223
2224 // momentary hack for ROOT problem
2225 if(!IsSilentFile())
2226 {
2227 histMVAPdfS->Write();
2228 histMVAPdfB->Write();
2229 }
2230 // create PDFs
2231 fMVAPdfS->BuildPDF ( histMVAPdfS );
2232 fMVAPdfB->BuildPDF ( histMVAPdfB );
2233 fMVAPdfS->ValidatePDF( histMVAPdfS );
2234 fMVAPdfB->ValidatePDF( histMVAPdfB );
2235
2236 if (DataInfo().GetNClasses() == 2) { // TODO: this is an ugly hack.. adapt this to new framework
2237 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())
2238 << TString::Format( "<CreateMVAPdfs> Separation from histogram (PDF): %1.3f (%1.3f)",
2239 GetSeparation( histMVAPdfS, histMVAPdfB ), GetSeparation( fMVAPdfS, fMVAPdfB ) )
2240 << Endl;
2241 }
2242
2243 delete histMVAPdfS;
2244 delete histMVAPdfB;
2245}
2246
// NOTE(review): the signature line and the opening brace of this function are
// not part of this listing; the name is taken from the "<GetProba>" log text.
2248 // the simple one, automatically calculates the mvaVal and uses the
2249 // SAME sig/bkg ratio as given in the training sample (typically 50/50
2250 // .. (NormMode=EqualNumEvents) but can be different)
2251 if (!fMVAPdfS || !fMVAPdfB) {
2252 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<GetProba> MVA PDFs for Signal and Background don't exist yet, we'll create them on demand" << Endl;
2253 CreateMVAPdfs();
2254 }
// a-priori signal fraction = signal share of the summed training weights
2255 Double_t sigFraction = DataInfo().GetTrainingSumSignalWeights() / (DataInfo().GetTrainingSumSignalWeights() + DataInfo().GetTrainingSumBackgrWeights() );
2256 Double_t mvaVal = GetMvaValue(ev);
2257
// delegate to the (mvaVal, signal fraction) overload
2258 return GetProba(mvaVal,sigFraction);
2259
2260}
2261////////////////////////////////////////////////////////////////////////////////
2262/// compute likelihood ratio
2263
2265{
2266 if (!fMVAPdfS || !fMVAPdfB) {
2267 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetProba> MVA PDFs for Signal and Background don't exist" << Endl;
2268 return -1.0;
2269 }
2270 Double_t p_s = fMVAPdfS->GetVal( mvaVal );
2271 Double_t p_b = fMVAPdfB->GetVal( mvaVal );
2272
2273 Double_t denom = p_s*ap_sig + p_b*(1 - ap_sig);
2274
2275 return (denom > 0) ? (p_s*ap_sig) / denom : -1;
2276}
2277
2278////////////////////////////////////////////////////////////////////////////////
2279/// compute rarity:
2280/// \f[
2281/// R(x) = \int_{[-\infty..x]} { PDF(x') dx' }
2282/// \f]
2283/// where PDF(x) is the PDF of the classifier's signal or background distribution
2284
2286{
2287 if ((reftype == Types::kSignal && !fMVAPdfS) || (reftype == Types::kBackground && !fMVAPdfB)) {
2288 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetRarity> Required MVA PDF for Signal or Background does not exist: "
2289 << "select option \"CreateMVAPdfs\"" << Endl;
2290 return 0.0;
2291 }
2292
2293 PDF* thePdf = ((reftype == Types::kSignal) ? fMVAPdfS : fMVAPdfB);
2294
2295 return thePdf->GetIntegral( thePdf->GetXmin(), mvaVal );
2296}
2297
2298////////////////////////////////////////////////////////////////////////////////
2299/// fill background efficiency (resp. rejection) versus signal efficiency plots
2300/// returns signal efficiency at background efficiency indicated in theString
2301
2303{
2304 Data()->SetCurrentType(type);
2305 Results* results = Data()->GetResults( GetMethodName(), type, Types::kClassification );
2306 std::vector<Float_t>* mvaRes = dynamic_cast<ResultsClassification*>(results)->GetValueVector();
2307
2308 // parse input string for required background efficiency
2309 TList* list = gTools().ParseFormatLine( theString );
2310
2311 // sanity check
2312 Bool_t computeArea = kFALSE;
2313 if (!list || list->GetSize() < 2) computeArea = kTRUE; // the area is computed
2314 else if (list->GetSize() > 2) {
2315 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Wrong number of arguments"
2316 << " in string: " << theString
2317 << " | required format, e.g., Efficiency:0.05, or empty string" << Endl;
2318 delete list;
2319 return -1;
2320 }
2321
2322 // sanity check
2323 if ( results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2324 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2325 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Binning mismatch between signal and background histos" << Endl;
2326 delete list;
2327 return -1.0;
2328 }
2329
2330 // create histograms
2331
2332 // first, get efficiency histograms for signal and background
2333 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2334 Double_t xmin = effhist->GetXaxis()->GetXmin();
2335 Double_t xmax = effhist->GetXaxis()->GetXmax();
2336
2337 TTHREAD_TLS(Double_t) nevtS;
2338
2339 // first round ? --> create histograms
2340 if (results->DoesExist("MVA_EFF_S")==0) {
2341
2342 // for efficiency plot
2343 TH1* eff_s = new TH1D( GetTestvarName() + "_effS", GetTestvarName() + " (signal)", fNbinsH, xmin, xmax );
2344 TH1* eff_b = new TH1D( GetTestvarName() + "_effB", GetTestvarName() + " (background)", fNbinsH, xmin, xmax );
2345 results->Store(eff_s, "MVA_EFF_S");
2346 results->Store(eff_b, "MVA_EFF_B");
2347
2348 // sign if cut
2349 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2350
2351 // this method is unbinned
2352 nevtS = 0;
2353 for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2354
2355 // read the tree
2356 Bool_t isSignal = DataInfo().IsSignal(GetEvent(ievt));
2357 Float_t theWeight = GetEvent(ievt)->GetWeight();
2358 Float_t theVal = (*mvaRes)[ievt];
2359
2360 // select histogram depending on if sig or bgd
2361 TH1* theHist = isSignal ? eff_s : eff_b;
2362
2363 // count signal and background events in tree
2364 if (isSignal) nevtS+=theWeight;
2365
2366 TAxis* axis = theHist->GetXaxis();
2367 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2368 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2369 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2370 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2371 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2372
2373 if (sign > 0)
2374 for (Int_t ibin=1; ibin<=maxbin; ibin++) theHist->AddBinContent( ibin , theWeight);
2375 else if (sign < 0)
2376 for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theHist->AddBinContent( ibin , theWeight );
2377 else
2378 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Mismatch in sign" << Endl;
2379 }
2380
2381 // renormalise maximum to <=1
2382 // eff_s->Scale( 1.0/TMath::Max(1.,eff_s->GetMaximum()) );
2383 // eff_b->Scale( 1.0/TMath::Max(1.,eff_b->GetMaximum()) );
2384
2385 eff_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(),eff_s->GetMaximum()) );
2386 eff_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(),eff_b->GetMaximum()) );
2387
2388 // background efficiency versus signal efficiency
2389 TH1* eff_BvsS = new TH1D( GetTestvarName() + "_effBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2390 results->Store(eff_BvsS, "MVA_EFF_BvsS");
2391 eff_BvsS->SetXTitle( "Signal eff" );
2392 eff_BvsS->SetYTitle( "Backgr eff" );
2393
2394 // background rejection (=1-eff.) versus signal efficiency
2395 TH1* rej_BvsS = new TH1D( GetTestvarName() + "_rejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2396 results->Store(rej_BvsS);
2397 rej_BvsS->SetXTitle( "Signal eff" );
2398 rej_BvsS->SetYTitle( "Backgr rejection (1-eff)" );
2399
2400 // inverse background eff (1/eff.) versus signal efficiency
2401 TH1* inveff_BvsS = new TH1D( GetTestvarName() + "_invBeffvsSeff",
2402 GetTestvarName(), fNbins, 0, 1 );
2403 results->Store(inveff_BvsS);
2404 inveff_BvsS->SetXTitle( "Signal eff" );
2405 inveff_BvsS->SetYTitle( "Inverse backgr. eff (1/eff)" );
2406
2407 // use root finder
2408 // spline background efficiency plot
2409 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2411 fSplRefS = new TSpline1( "spline2_signal", new TGraph( eff_s ) );
2412 fSplRefB = new TSpline1( "spline2_background", new TGraph( eff_b ) );
2413
2414 // verify spline sanity
2415 gTools().CheckSplines( eff_s, fSplRefS );
2416 gTools().CheckSplines( eff_b, fSplRefB );
2417 }
2418
2419 // make the background-vs-signal efficiency plot
2420
2421 // create root finder
2422 RootFinder rootFinder( this, fXmin, fXmax );
2423
2424 Double_t effB = 0;
2425 fEffS = eff_s; // to be set for the root finder
2426 for (Int_t bini=1; bini<=fNbins; bini++) {
2427
2428 // find cut value corresponding to a given signal efficiency
2429 Double_t effS = eff_BvsS->GetBinCenter( bini );
2430 Double_t cut = rootFinder.Root( effS );
2431
2432 // retrieve background efficiency for given cut
2433 if (Use_Splines_for_Eff_) effB = fSplRefB->Eval( cut );
2434 else effB = eff_b->GetBinContent( eff_b->FindBin( cut ) );
2435
2436 // and fill histograms
2437 eff_BvsS->SetBinContent( bini, effB );
2438 rej_BvsS->SetBinContent( bini, 1.0-effB );
2439 if (effB>std::numeric_limits<double>::epsilon())
2440 inveff_BvsS->SetBinContent( bini, 1.0/effB );
2441 }
2442
2443 // create splines for histogram
2444 fSpleffBvsS = new TSpline1( "effBvsS", new TGraph( eff_BvsS ) );
2445
2446 // search for overlap point where, when cutting on it,
2447 // one would obtain: eff_S = rej_B = 1 - eff_B
2448 Double_t effS = 0., rejB, effS_ = 0., rejB_ = 0.;
2449 Int_t nbins_ = 5000;
2450 for (Int_t bini=1; bini<=nbins_; bini++) {
2451
2452 // get corresponding signal and background efficiencies
2453 effS = (bini - 0.5)/Float_t(nbins_);
2454 rejB = 1.0 - fSpleffBvsS->Eval( effS );
2455
2456 // find signal efficiency that corresponds to required background efficiency
2457 if ((effS - rejB)*(effS_ - rejB_) < 0) break;
2458 effS_ = effS;
2459 rejB_ = rejB;
2460 }
2461
2462 // find cut that corresponds to signal efficiency and update signal-like criterion
2463 Double_t cut = rootFinder.Root( 0.5*(effS + effS_) );
2464 SetSignalReferenceCut( cut );
2465 fEffS = 0;
2466 }
2467
2468 // must exist...
2469 if (0 == fSpleffBvsS) {
2470 delete list;
2471 return 0.0;
2472 }
2473
2474 // now find signal efficiency that corresponds to required background efficiency
2475 Double_t effS = 0, effB = 0, effS_ = 0, effB_ = 0;
2476 Int_t nbins_ = 1000;
2477
2478 if (computeArea) {
2479
2480 // compute area of rej-vs-eff plot
2481 Double_t integral = 0;
2482 for (Int_t bini=1; bini<=nbins_; bini++) {
2483
2484 // get corresponding signal and background efficiencies
2485 effS = (bini - 0.5)/Float_t(nbins_);
2486 effB = fSpleffBvsS->Eval( effS );
2487 integral += (1.0 - effB);
2488 }
2489 integral /= nbins_;
2490
2491 delete list;
2492 return integral;
2493 }
2494 else {
2495
2496 // that will be the value of the efficiency returned (does not affect
2497 // the efficiency-vs-bkg plot, which is done anyway)
2498 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2499
2500 // find precise efficiency value
2501 for (Int_t bini=1; bini<=nbins_; bini++) {
2502
2503 // get corresponding signal and background efficiencies
2504 effS = (bini - 0.5)/Float_t(nbins_);
2505 effB = fSpleffBvsS->Eval( effS );
2506
2507 // find signal efficiency that corresponds to required background efficiency
2508 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2509 effS_ = effS;
2510 effB_ = effB;
2511 }
2512
2513 // take mean between bin above and bin below
2514 effS = 0.5*(effS + effS_);
2515
2516 effSerr = 0;
2517 if (nevtS > 0) effSerr = TMath::Sqrt( effS*(1.0 - effS)/nevtS );
2518
2519 delete list;
2520 return effS;
2521 }
2522
2523 return -1;
2524}
2525
2526////////////////////////////////////////////////////////////////////////////////
2527
2529{
2530 Data()->SetCurrentType(Types::kTraining);
2531
2532 Results* results = Data()->GetResults(GetMethodName(), Types::kTesting, Types::kNoAnalysisType);
2533
2534 // fill background efficiency (resp. rejection) versus signal efficiency plots
2535 // returns signal efficiency at background efficiency indicated in theString
2536
2537 // parse input string for required background efficiency
2538 TList* list = gTools().ParseFormatLine( theString );
2539 // sanity check
2540
2541 if (list->GetSize() != 2) {
2542 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Wrong number of arguments"
2543 << " in string: " << theString
2544 << " | required format, e.g., Efficiency:0.05" << Endl;
2545 delete list;
2546 return -1;
2547 }
2548 // that will be the value of the efficiency returned (does not affect
2549 // the efficiency-vs-bkg plot, which is done anyway)
2550 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2551
2552 delete list;
2553
2554 // sanity check
2555 if (results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2556 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2557 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Binning mismatch between signal and background histos"
2558 << Endl;
2559 return -1.0;
2560 }
2561
2562 // create histogram
2563
2564 // first, get efficiency histograms for signal and background
2565 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2566 Double_t xmin = effhist->GetXaxis()->GetXmin();
2567 Double_t xmax = effhist->GetXaxis()->GetXmax();
2568
2569 // first round ? --> create and fill histograms
2570 if (results->DoesExist("MVA_TRAIN_S")==0) {
2571
2572 // classifier response distributions for test sample
2573 Double_t sxmax = fXmax+0.00001;
2574
2575 // MVA plots on the training sample (check for overtraining)
2576 TH1* mva_s_tr = new TH1D( GetTestvarName() + "_Train_S",GetTestvarName() + "_Train_S", fNbinsMVAoutput, fXmin, sxmax );
2577 TH1* mva_b_tr = new TH1D( GetTestvarName() + "_Train_B",GetTestvarName() + "_Train_B", fNbinsMVAoutput, fXmin, sxmax );
2578 results->Store(mva_s_tr, "MVA_TRAIN_S");
2579 results->Store(mva_b_tr, "MVA_TRAIN_B");
2580 mva_s_tr->Sumw2();
2581 mva_b_tr->Sumw2();
2582
2583 // Training efficiency plots
2584 TH1* mva_eff_tr_s = new TH1D( GetTestvarName() + "_trainingEffS", GetTestvarName() + " (signal)",
2585 fNbinsH, xmin, xmax );
2586 TH1* mva_eff_tr_b = new TH1D( GetTestvarName() + "_trainingEffB", GetTestvarName() + " (background)",
2587 fNbinsH, xmin, xmax );
2588 results->Store(mva_eff_tr_s, "MVA_TRAINEFF_S");
2589 results->Store(mva_eff_tr_b, "MVA_TRAINEFF_B");
2590
2591 // sign of cut
2592 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2593
2594 std::vector<Double_t> mvaValues = GetMvaValues(0,Data()->GetNEvents());
2595 assert( (Long64_t) mvaValues.size() == Data()->GetNEvents());
2596
2597 // this method is unbinned
2598 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2599
2600 Data()->SetCurrentEvent(ievt);
2601 const Event* ev = GetEvent();
2602
2603 Double_t theVal = mvaValues[ievt];
2604 Double_t theWeight = ev->GetWeight();
2605
2606 TH1* theEffHist = DataInfo().IsSignal(ev) ? mva_eff_tr_s : mva_eff_tr_b;
2607 TH1* theClsHist = DataInfo().IsSignal(ev) ? mva_s_tr : mva_b_tr;
2608
2609 theClsHist->Fill( theVal, theWeight );
2610
2611 TAxis* axis = theEffHist->GetXaxis();
2612 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2613 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2614 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2615 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2616 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2617
2618 if (sign > 0) for (Int_t ibin=1; ibin<=maxbin; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2619 else for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2620 }
2621
2622 // normalise output distributions
2623 // comment out those (and several others) if you want unnormalised output
2624 gTools().NormHist( mva_s_tr );
2625 gTools().NormHist( mva_b_tr );
2626
2627 // renormalise to maximum
2628 mva_eff_tr_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_s->GetMaximum()) );
2629 mva_eff_tr_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_b->GetMaximum()) );
2630
2631 // Training background efficiency versus signal efficiency
2632 TH1* eff_bvss = new TH1D( GetTestvarName() + "_trainingEffBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2633 // Training background rejection (=1-eff.) versus signal efficiency
2634 TH1* rej_bvss = new TH1D( GetTestvarName() + "_trainingRejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2635 results->Store(eff_bvss, "EFF_BVSS_TR");
2636 results->Store(rej_bvss, "REJ_BVSS_TR");
2637
2638 // use root finder
2639 // spline background efficiency plot
2640 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2642 if (fSplTrainRefS) delete fSplTrainRefS;
2643 if (fSplTrainRefB) delete fSplTrainRefB;
2644 fSplTrainRefS = new TSpline1( "spline2_signal", new TGraph( mva_eff_tr_s ) );
2645 fSplTrainRefB = new TSpline1( "spline2_background", new TGraph( mva_eff_tr_b ) );
2646
2647 // verify spline sanity
2648 gTools().CheckSplines( mva_eff_tr_s, fSplTrainRefS );
2649 gTools().CheckSplines( mva_eff_tr_b, fSplTrainRefB );
2650 }
2651
2652 // make the background-vs-signal efficiency plot
2653
2654 // create root finder
2655 RootFinder rootFinder(this, fXmin, fXmax );
2656
2657 Double_t effB = 0;
2658 fEffS = results->GetHist("MVA_TRAINEFF_S");
2659 for (Int_t bini=1; bini<=fNbins; bini++) {
2660
2661 // find cut value corresponding to a given signal efficiency
2662 Double_t effS = eff_bvss->GetBinCenter( bini );
2663
2664 Double_t cut = rootFinder.Root( effS );
2665
2666 // retrieve background efficiency for given cut
2667 if (Use_Splines_for_Eff_) effB = fSplTrainRefB->Eval( cut );
2668 else effB = mva_eff_tr_b->GetBinContent( mva_eff_tr_b->FindBin( cut ) );
2669
2670 // and fill histograms
2671 eff_bvss->SetBinContent( bini, effB );
2672 rej_bvss->SetBinContent( bini, 1.0-effB );
2673 }
2674 fEffS = 0;
2675
2676 // create splines for histogram
2677 fSplTrainEffBvsS = new TSpline1( "effBvsS", new TGraph( eff_bvss ) );
2678 }
2679
2680 // must exist...
2681 if (0 == fSplTrainEffBvsS) return 0.0;
2682
2683 // now find signal efficiency that corresponds to required background efficiency
2684 Double_t effS = 0., effB, effS_ = 0., effB_ = 0.;
2685 Int_t nbins_ = 1000;
2686 for (Int_t bini=1; bini<=nbins_; bini++) {
2687
2688 // get corresponding signal and background efficiencies
2689 effS = (bini - 0.5)/Float_t(nbins_);
2690 effB = fSplTrainEffBvsS->Eval( effS );
2691
2692 // find signal efficiency that corresponds to required background efficiency
2693 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2694 effS_ = effS;
2695 effB_ = effB;
2696 }
2697
2698 return 0.5*(effS + effS_); // the mean between bin above and bin below
2699}
2700
2701////////////////////////////////////////////////////////////////////////////////
2702
2703std::vector<Float_t> TMVA::MethodBase::GetMulticlassEfficiency(std::vector<std::vector<Float_t> >& purity)
2704{
2705 Data()->SetCurrentType(Types::kTesting);
2706 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
2707 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in GetMulticlassEfficiency, exiting."<<Endl;
2708
2709 purity.push_back(resMulticlass->GetAchievablePur());
2710 return resMulticlass->GetAchievableEff();
2711}
2712
2713////////////////////////////////////////////////////////////////////////////////
2714
2715std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity)
2716{
2717 Data()->SetCurrentType(Types::kTraining);
2718 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTraining, Types::kMulticlass));
2719 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in GetMulticlassTrainingEfficiency, exiting."<<Endl;
2720
2721 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for training data..." << Endl;
2722 for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
2723 resMulticlass->GetBestMultiClassCuts(icls);
2724 }
2725
2726 purity.push_back(resMulticlass->GetAchievablePur());
2727 return resMulticlass->GetAchievableEff();
2728}
2729
2730////////////////////////////////////////////////////////////////////////////////
2731/// Construct a confusion matrix for a multiclass classifier. The confusion
2732/// matrix compares, in turn, each class against all other classes in a pair-wise
2733/// fashion. In rows with index \f$ k_r = 0 ... K \f$, \f$ k_r \f$ is
2734/// considered signal for the sake of comparison and for each column
2735/// \f$ k_c = 0 ... K \f$ the corresponding class is considered background.
2736///
2737/// Note that the diagonal elements will be returned as NaN since this will
2738/// compare a class against itself.
2739///
2740/// \see TMVA::ResultsMulticlass::GetConfusionMatrix
2741///
2742/// \param[in] effB The background efficiency for which to evaluate.
2743/// \param[in] type The data set on which to evaluate (training, testing ...).
2744///
2745/// \return A matrix containing signal efficiencies for the given background
2746/// efficiency. The diagonal elements are NaN since this measure is
2747/// meaningless (comparing a class against itself).
2748///
2749
2751{
2752 if (GetAnalysisType() != Types::kMulticlass) {
2753 Log() << kFATAL << "Cannot get confusion matrix for non-multiclass analysis." << std::endl;
2754 return TMatrixD(0, 0);
2755 }
2756
2757 Data()->SetCurrentType(type);
2758 ResultsMulticlass *resMulticlass =
2759 dynamic_cast<ResultsMulticlass *>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
2760
2761 if (resMulticlass == nullptr) {
2762 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
2763 << "unable to create pointer in GetMulticlassEfficiency, exiting." << Endl;
2764 return TMatrixD(0, 0);
2765 }
2766
2767 return resMulticlass->GetConfusionMatrix(effB);
2768}
2769
2770////////////////////////////////////////////////////////////////////////////////
2771/// compute significance of mean difference
2772/// \f[
2773/// significance = \frac{|<S> - <B>|}{\sqrt{RMS_{S}^2 + RMS_{B}^2}}
2774/// \f]
2775
2777{
2778 Double_t rms = sqrt( fRmsS*fRmsS + fRmsB*fRmsB );
2779
2780 return (rms > 0) ? TMath::Abs(fMeanS - fMeanB)/rms : 0;
2781}
2782
2783////////////////////////////////////////////////////////////////////////////////
2784/// compute "separation" defined as
2785/// \f[
2786/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2787/// \f]
2788
2790{
2791 return gTools().GetSeparation( histoS, histoB );
2792}
2793
2794////////////////////////////////////////////////////////////////////////////////
2795/// compute "separation" defined as
2796/// \f[
2797/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2798/// \f]
2799
2801{
2802 // note, if zero pointers given, use internal pdf
2803 // sanity check first
2804 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2805 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2806 if (!pdfS) pdfS = fSplS;
2807 if (!pdfB) pdfB = fSplB;
2808
2809 if (!fSplS || !fSplB) {
2810 Log()<<kDEBUG<<Form("[%s] : ",DataInfo().GetName())<< "could not calculate the separation, distributions"
2811 << " fSplS or fSplB are not yet filled" << Endl;
2812 return 0;
2813 }else{
2814 return gTools().GetSeparation( *pdfS, *pdfB );
2815 }
2816}
2817
2818////////////////////////////////////////////////////////////////////////////////
2819/// calculate the area (integral) under the ROC curve as a
2820/// overall quality measure of the classification
2821
2823{
2824 // note, if zero pointers given, use internal pdf
2825 // sanity check first
2826 if ((!histS && histB) || (histS && !histB))
2827 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetROCIntegral(TH1D*, TH1D*)> Mismatch in hists" << Endl;
2828
2829 if (histS==0 || histB==0) return 0.;
2830
2831 TMVA::PDF *pdfS = new TMVA::PDF( " PDF Sig", histS, TMVA::PDF::kSpline3 );
2832 TMVA::PDF *pdfB = new TMVA::PDF( " PDF Bkg", histB, TMVA::PDF::kSpline3 );
2833
2834
2835 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2836 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2837
2838 Double_t integral = 0;
2839 UInt_t nsteps = 1000;
2840 Double_t step = (xmax-xmin)/Double_t(nsteps);
2841 Double_t cut = xmin;
2842 for (UInt_t i=0; i<nsteps; i++) {
2843 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2844 cut+=step;
2845 }
2846 delete pdfS;
2847 delete pdfB;
2848 return integral*step;
2849}
2850
2851
2852////////////////////////////////////////////////////////////////////////////////
2853/// calculate the area (integral) under the ROC curve as a
2854/// overall quality measure of the classification
2855
2857{
2858 // note, if zero pointers given, use internal pdf
2859 // sanity check first
2860 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2861 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2862 if (!pdfS) pdfS = fSplS;
2863 if (!pdfB) pdfB = fSplB;
2864
2865 if (pdfS==0 || pdfB==0) return 0.;
2866
2867 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2868 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2869
2870 Double_t integral = 0;
2871 UInt_t nsteps = 1000;
2872 Double_t step = (xmax-xmin)/Double_t(nsteps);
2873 Double_t cut = xmin;
2874 for (UInt_t i=0; i<nsteps; i++) {
2875 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2876 cut+=step;
2877 }
2878 return integral*step;
2879}
2880
2881////////////////////////////////////////////////////////////////////////////////
2882/// plot significance, \f$ \frac{S}{\sqrt{S + B}} \f$, curve for given number
2883/// of signal and background events; returns cut for maximum significance
2884/// also returned via reference is the maximum significance
2885
2887 Double_t BackgroundEvents,
2888 Double_t& max_significance_value ) const
2889{
2890 Results* results = Data()->GetResults( GetMethodName(), Types::kTesting, Types::kMaxAnalysisType );
2891
2892 Double_t max_significance(0);
2893 Double_t effS(0),effB(0),significance(0);
2894 TH1D *temp_histogram = new TH1D("temp", "temp", fNbinsH, fXmin, fXmax );
2895
2896 if (SignalEvents <= 0 || BackgroundEvents <= 0) {
2897 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetMaximumSignificance> "
2898 << "Number of signal or background events is <= 0 ==> abort"
2899 << Endl;
2900 }
2901
2902 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Using ratio SignalEvents/BackgroundEvents = "
2903 << SignalEvents/BackgroundEvents << Endl;
2904
2905 TH1* eff_s = results->GetHist("MVA_EFF_S");
2906 TH1* eff_b = results->GetHist("MVA_EFF_B");
2907
2908 if ( (eff_s==0) || (eff_b==0) ) {
2909 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Efficiency histograms empty !" << Endl;
2910 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "no maximum cut found, return 0" << Endl;
2911 return 0;
2912 }
2913
2914 for (Int_t bin=1; bin<=fNbinsH; bin++) {
2915 effS = eff_s->GetBinContent( bin );
2916 effB = eff_b->GetBinContent( bin );
2917
2918 // put significance into a histogram
2919 significance = sqrt(SignalEvents)*( effS )/sqrt( effS + ( BackgroundEvents / SignalEvents) * effB );
2920
2921 temp_histogram->SetBinContent(bin,significance);
2922 }
2923
2924 // find maximum in histogram
2925 max_significance = temp_histogram->GetBinCenter( temp_histogram->GetMaximumBin() );
2926 max_significance_value = temp_histogram->GetBinContent( temp_histogram->GetMaximumBin() );
2927
2928 // delete
2929 delete temp_histogram;
2930
2931 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Optimal cut at : " << max_significance << Endl;
2932 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "Maximum significance: " << max_significance_value << Endl;
2933
2934 return max_significance;
2935}
2936
2937////////////////////////////////////////////////////////////////////////////////
2938/// calculates rms,mean, xmin, xmax of the event variable
2939/// this can be either done for the variables as they are or for
2940/// normalised variables (in the range of 0-1) if "norm" is set to kTRUE
2941
2943 Double_t& meanS, Double_t& meanB,
2944 Double_t& rmsS, Double_t& rmsB,
2946{
2947 Types::ETreeType previousTreeType = Data()->GetCurrentType();
2948 Data()->SetCurrentType(treeType);
2949
2950 Long64_t entries = Data()->GetNEvents();
2951
2952 // sanity check
2953 if (entries <=0)
2954 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CalculateEstimator> Wrong tree type: " << treeType << Endl;
2955
2956 // index of the wanted variable
2957 UInt_t varIndex = DataInfo().FindVarIndex( theVarName );
2958
2959 // first fill signal and background in arrays before analysis
2960 xmin = +DBL_MAX;
2961 xmax = -DBL_MAX;
2962
2963 // take into account event weights
2964 meanS = 0;
2965 meanB = 0;
2966 rmsS = 0;
2967 rmsB = 0;
2968 Double_t sumwS = 0, sumwB = 0;
2969
2970 // loop over all training events
2971 for (Int_t ievt = 0; ievt < entries; ievt++) {
2972
2973 const Event* ev = GetEvent(ievt);
2974
2975 Double_t theVar = ev->GetValue(varIndex);
2976 Double_t weight = ev->GetWeight();
2977
2978 if (DataInfo().IsSignal(ev)) {
2979 sumwS += weight;
2980 meanS += weight*theVar;
2981 rmsS += weight*theVar*theVar;
2982 }
2983 else {
2984 sumwB += weight;
2985 meanB += weight*theVar;
2986 rmsB += weight*theVar*theVar;
2987 }
2988 xmin = TMath::Min( xmin, theVar );
2989 xmax = TMath::Max( xmax, theVar );
2990 }
2991
2992 meanS = meanS/sumwS;
2993 meanB = meanB/sumwB;
2994 rmsS = TMath::Sqrt( rmsS/sumwS - meanS*meanS );
2995 rmsB = TMath::Sqrt( rmsB/sumwB - meanB*meanB );
2996
2997 Data()->SetCurrentType(previousTreeType);
2998}
2999
3000////////////////////////////////////////////////////////////////////////////////
3001/// create reader class for method (classification only at present)
3002
3003void TMVA::MethodBase::MakeClass( const TString& theClassFileName ) const
3004{
3005 // the default consists of
3006 TString classFileName = "";
3007 if (theClassFileName == "")
3008 classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class.C";
3009 else
3010 classFileName = theClassFileName;
3011
3012 TString className = TString("Read") + GetMethodName();
3013
3014 TString tfname( classFileName );
3015 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
3016 << "Creating standalone class: "
3017 << gTools().Color("lightblue") << classFileName << gTools().Color("reset") << Endl;
3018
3019 std::ofstream fout( classFileName );
3020 if (!fout.good()) { // file could not be opened --> Error
3021 Log() << kFATAL << "<MakeClass> Unable to open file: " << classFileName << Endl;
3022 }
3023
3024 // now create the class
3025 // preamble
3026 fout << "// Class: " << className << std::endl;
3027 fout << "// Automatically generated by MethodBase::MakeClass" << std::endl << "//" << std::endl;
3028
3029 // print general information and configuration state
3030 fout << std::endl;
3031 fout << "/* configuration options =====================================================" << std::endl << std::endl;
3032 WriteStateToStream( fout );
3033 fout << std::endl;
3034 fout << "============================================================================ */" << std::endl;
3035
3036 // generate the class
3037 fout << "" << std::endl;
3038 fout << "#include <array>" << std::endl;
3039 fout << "#include <vector>" << std::endl;
3040 fout << "#include <cmath>" << std::endl;
3041 fout << "#include <string>" << std::endl;
3042 fout << "#include <iostream>" << std::endl;
3043 fout << "" << std::endl;
3044 // now if the classifier needs to write some additional classes for its response implementation
3045 // this code goes here: (at least the header declarations need to come before the main class
3046 this->MakeClassSpecificHeader( fout, className );
3047
3048 fout << "#ifndef IClassifierReader__def" << std::endl;
3049 fout << "#define IClassifierReader__def" << std::endl;
3050 fout << std::endl;
3051 fout << "class IClassifierReader {" << std::endl;
3052 fout << std::endl;
3053 fout << " public:" << std::endl;
3054 fout << std::endl;
3055 fout << " // constructor" << std::endl;
3056 fout << " IClassifierReader() : fStatusIsClean( true ) {}" << std::endl;
3057 fout << " virtual ~IClassifierReader() {}" << std::endl;
3058 fout << std::endl;
3059 fout << " // return classifier response" << std::endl;
3060 if(GetAnalysisType() == Types::kMulticlass) {
3061 fout << " virtual std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3062 } else {
3063 fout << " virtual double GetMvaValue( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3064 }
3065 fout << std::endl;
3066 fout << " // returns classifier status" << std::endl;
3067 fout << " bool IsStatusClean() const { return fStatusIsClean; }" << std::endl;
3068 fout << std::endl;
3069 fout << " protected:" << std::endl;
3070 fout << std::endl;
3071 fout << " bool fStatusIsClean;" << std::endl;
3072 fout << "};" << std::endl;
3073 fout << std::endl;
3074 fout << "#endif" << std::endl;
3075 fout << std::endl;
3076 fout << "class " << className << " : public IClassifierReader {" << std::endl;
3077 fout << std::endl;
3078 fout << " public:" << std::endl;
3079 fout << std::endl;
3080 fout << " // constructor" << std::endl;
3081 fout << " " << className << "( std::vector<std::string>& theInputVars )" << std::endl;
3082 fout << " : IClassifierReader()," << std::endl;
3083 fout << " fClassName( \"" << className << "\" )," << std::endl;
3084 fout << " fNvars( " << GetNvar() << " )" << std::endl;
3085 fout << " {" << std::endl;
3086 fout << " // the training input variables" << std::endl;
3087 fout << " const char* inputVars[] = { ";
3088 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3089 fout << "\"" << GetOriginalVarName(ivar) << "\"";
3090 if (ivar<GetNvar()-1) fout << ", ";
3091 }
3092 fout << " };" << std::endl;
3093 fout << std::endl;
3094 fout << " // sanity checks" << std::endl;
3095 fout << " if (theInputVars.size() <= 0) {" << std::endl;
3096 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": empty input vector\" << std::endl;" << std::endl;
3097 fout << " fStatusIsClean = false;" << std::endl;
3098 fout << " }" << std::endl;
3099 fout << std::endl;
3100 fout << " if (theInputVars.size() != fNvars) {" << std::endl;
3101 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in number of input values: \"" << std::endl;
3102 fout << " << theInputVars.size() << \" != \" << fNvars << std::endl;" << std::endl;
3103 fout << " fStatusIsClean = false;" << std::endl;
3104 fout << " }" << std::endl;
3105 fout << std::endl;
3106 fout << " // validate input variables" << std::endl;
3107 fout << " for (size_t ivar = 0; ivar < theInputVars.size(); ivar++) {" << std::endl;
3108 fout << " if (theInputVars[ivar] != inputVars[ivar]) {" << std::endl;
3109 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in input variable names\" << std::endl" << std::endl;
3110 fout << " << \" for variable [\" << ivar << \"]: \" << theInputVars[ivar].c_str() << \" != \" << inputVars[ivar] << std::endl;" << std::endl;
3111 fout << " fStatusIsClean = false;" << std::endl;
3112 fout << " }" << std::endl;
3113 fout << " }" << std::endl;
3114 fout << std::endl;
3115 fout << " // initialize min and max vectors (for normalisation)" << std::endl;
3116 for (UInt_t ivar = 0; ivar < GetNvar(); ivar++) {
3117 fout << " fVmin[" << ivar << "] = " << std::setprecision(15) << GetXmin( ivar ) << ";" << std::endl;
3118 fout << " fVmax[" << ivar << "] = " << std::setprecision(15) << GetXmax( ivar ) << ";" << std::endl;
3119 }
3120 fout << std::endl;
3121 fout << " // initialize input variable types" << std::endl;
3122 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3123 fout << " fType[" << ivar << "] = \'" << DataInfo().GetVariableInfo(ivar).GetVarType() << "\';" << std::endl;
3124 }
3125 fout << std::endl;
3126 fout << " // initialize constants" << std::endl;
3127 fout << " Initialize();" << std::endl;
3128 fout << std::endl;
3129 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
3130 fout << " // initialize transformation" << std::endl;
3131 fout << " InitTransform();" << std::endl;
3132 }
3133 fout << " }" << std::endl;
3134 fout << std::endl;
3135 fout << " // destructor" << std::endl;
3136 fout << " virtual ~" << className << "() {" << std::endl;
3137 fout << " Clear(); // method-specific" << std::endl;
3138 fout << " }" << std::endl;
3139 fout << std::endl;
3140 fout << " // the classifier response" << std::endl;
3141 fout << " // \"inputValues\" is a vector of input values in the same order as the" << std::endl;
3142 fout << " // variables given to the constructor" << std::endl;
3143 if(GetAnalysisType() == Types::kMulticlass) {
3144 fout << " std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const override;" << std::endl;
3145 } else {
3146 fout << " double GetMvaValue( const std::vector<double>& inputValues ) const override;" << std::endl;
3147 }
3148 fout << std::endl;
3149 fout << " private:" << std::endl;
3150 fout << std::endl;
3151 fout << " // method-specific destructor" << std::endl;
3152 fout << " void Clear();" << std::endl;
3153 fout << std::endl;
3154 if (GetTransformationHandler().GetTransformationList().GetSize()!=0) {
3155 fout << " // input variable transformation" << std::endl;
3156 GetTransformationHandler().MakeFunction(fout, className,1);
3157 fout << " void InitTransform();" << std::endl;
3158 fout << " void Transform( std::vector<double> & iv, int sigOrBgd ) const;" << std::endl;
3159 fout << std::endl;
3160 }
3161 fout << " // common member variables" << std::endl;
3162 fout << " const char* fClassName;" << std::endl;
3163 fout << std::endl;
3164 fout << " const size_t fNvars;" << std::endl;
3165 fout << " size_t GetNvar() const { return fNvars; }" << std::endl;
3166 fout << " char GetType( int ivar ) const { return fType[ivar]; }" << std::endl;
3167 fout << std::endl;
3168 fout << " // normalisation of input variables" << std::endl;
3169 fout << " double fVmin[" << GetNvar() << "];" << std::endl;
3170 fout << " double fVmax[" << GetNvar() << "];" << std::endl;
3171 fout << " double NormVariable( double x, double xmin, double xmax ) const {" << std::endl;
3172 fout << " // normalise to output range: [-1, 1]" << std::endl;
3173 fout << " return 2*(x - xmin)/(xmax - xmin) - 1.0;" << std::endl;
3174 fout << " }" << std::endl;
3175 fout << std::endl;
3176 fout << " // type of input variable: 'F' or 'I'" << std::endl;
3177 fout << " char fType[" << GetNvar() << "];" << std::endl;
3178 fout << std::endl;
3179 fout << " // initialize internal variables" << std::endl;
3180 fout << " void Initialize();" << std::endl;
3181 if(GetAnalysisType() == Types::kMulticlass) {
3182 fout << " std::vector<double> GetMulticlassValues__( const std::vector<double>& inputValues ) const;" << std::endl;
3183 } else {
3184 fout << " double GetMvaValue__( const std::vector<double>& inputValues ) const;" << std::endl;
3185 }
3186 fout << "" << std::endl;
3187 fout << " // private members (method specific)" << std::endl;
3188
3189 // call the classifier specific output (the classifier must close the class !)
3190 MakeClassSpecific( fout, className );
3191
3192 if(GetAnalysisType() == Types::kMulticlass) {
3193 fout << "inline std::vector<double> " << className << "::GetMulticlassValues( const std::vector<double>& inputValues ) const" << std::endl;
3194 } else {
3195 fout << "inline double " << className << "::GetMvaValue( const std::vector<double>& inputValues ) const" << std::endl;
3196 }
3197 fout << "{" << std::endl;
3198 fout << " // classifier response value" << std::endl;
3199 if(GetAnalysisType() == Types::kMulticlass) {
3200 fout << " std::vector<double> retval;" << std::endl;
3201 } else {
3202 fout << " double retval = 0;" << std::endl;
3203 }
3204 fout << std::endl;
3205 fout << " // classifier response, sanity check first" << std::endl;
3206 fout << " if (!IsStatusClean()) {" << std::endl;
3207 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": cannot return classifier response\"" << std::endl;
3208 fout << " << \" because status is dirty\" << std::endl;" << std::endl;
3209 fout << " }" << std::endl;
3210 fout << " else {" << std::endl;
3211 if (IsNormalised()) {
3212 fout << " // normalise variables" << std::endl;
3213 fout << " std::vector<double> iV;" << std::endl;
3214 fout << " iV.reserve(inputValues.size());" << std::endl;
3215 fout << " int ivar = 0;" << std::endl;
3216 fout << " for (std::vector<double>::const_iterator varIt = inputValues.begin();" << std::endl;
3217 fout << " varIt != inputValues.end(); varIt++, ivar++) {" << std::endl;
3218 fout << " iV.push_back(NormVariable( *varIt, fVmin[ivar], fVmax[ivar] ));" << std::endl;
3219 fout << " }" << std::endl;
3220 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3221 GetMethodType() != Types::kHMatrix) {
3222 fout << " Transform( iV, -1 );" << std::endl;
3223 }
3224
3225 if(GetAnalysisType() == Types::kMulticlass) {
3226 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3227 } else {
3228 fout << " retval = GetMvaValue__( iV );" << std::endl;
3229 }
3230 } else {
3231 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3232 GetMethodType() != Types::kHMatrix) {
3233 fout << " std::vector<double> iV(inputValues);" << std::endl;
3234 fout << " Transform( iV, -1 );" << std::endl;
3235 if(GetAnalysisType() == Types::kMulticlass) {
3236 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3237 } else {
3238 fout << " retval = GetMvaValue__( iV );" << std::endl;
3239 }
3240 } else {
3241 if(GetAnalysisType() == Types::kMulticlass) {
3242 fout << " retval = GetMulticlassValues__( inputValues );" << std::endl;
3243 } else {
3244 fout << " retval = GetMvaValue__( inputValues );" << std::endl;
3245 }
3246 }
3247 }
3248 fout << " }" << std::endl;
3249 fout << std::endl;
3250 fout << " return retval;" << std::endl;
3251 fout << "}" << std::endl;
3252
3253 // create output for transformation - if any
3254 if (GetTransformationHandler().GetTransformationList().GetSize()!=0)
3255 GetTransformationHandler().MakeFunction(fout, className,2);
3256
3257 // close the file
3258 fout.close();
3259}
3260
3261////////////////////////////////////////////////////////////////////////////////
3262/// Prints the method-specific help message; when gConfig().WriteOptionsReference() is set, std::cout is temporarily redirected to a stream opened in append mode on GetReferenceFile().
3263
3265{
3266 // if options are written to reference file, also append help info
3267 std::streambuf* cout_sbuf = std::cout.rdbuf(); // save original sbuf
3268 std::ofstream* o = 0; // stays null unless the reference file is in use
3269 if (gConfig().WriteOptionsReference()) {
3270 Log() << kINFO << "Print Help message for class " << GetName() << " into file: " << GetReferenceFile() << Endl;
3271 o = new std::ofstream( GetReferenceFile(), std::ios::app ); // NOTE(review): no matching delete is visible in this body
3272 if (!o->good()) { // file could not be opened --> Error
3273 Log() << kFATAL << "<PrintHelpMessage> Unable to append to output file: " << GetReferenceFile() << Endl;
3274 }
3275 std::cout.rdbuf( o->rdbuf() ); // redirect 'std::cout' to file
3276 }
3277
3278 // header: decorated banner when logging normally, plain title line when writing to the reference file
3279 if (!o) {
3280 Log() << kINFO << Endl;
3281 Log() << gTools().Color("bold")
3282 << "================================================================"
3283 << gTools().Color( "reset" )
3284 << Endl;
3285 Log() << gTools().Color("bold")
3286 << "H e l p f o r M V A m e t h o d [ " << GetName() << " ] :"
3287 << gTools().Color( "reset" )
3288 << Endl;
3289 }
3290 else {
3291 Log() << "Help for MVA method [ " << GetName() << " ] :" << Endl;
3292 }
3293
3294 // print method-specific help message
3295 GetHelpMessage();
3296
3297 // footer: hint on how to suppress the message plus closing banner, or an end marker in reference-file mode
3298 Log() << Endl;
3299 Log() << "<Suppress this message by specifying \"!H\" in the booking option>" << Endl;
3300 Log() << gTools().Color("bold")
3301 << "================================================================"
3302 << gTools().Color( "reset" )
3303 << Endl;
3304 Log() << Endl;
3305 }
3306 else {
3307 // indicate END
3308 Log() << "# End of Message___" << Endl;
3309 }
3310
3311 std::cout.rdbuf( cout_sbuf ); // restore the original stream buffer
3312 if (o) o->close(); // closes the file stream; the ofstream object itself is not freed here
3313}
3314
3315// ----------------------- r o o t f i n d i n g ----------------------------
3316
3317////////////////////////////////////////////////////////////////////////////////
3318/// Returns the efficiency at cut value theCut, taken from the spline fSplRefS or from the fEffS bin containing theCut, and forced to 0 or 1 within eps of fXmin / fXmax.
3319
3321{
3322 Double_t retval=0;
3323
3324 // evaluate the efficiency: spline branch first, histogram lookup otherwise (NOTE(review): the branch condition on listing line 3325 is not present in this extract)
3326 retval = fSplRefS->Eval( theCut );
3327 }
3328 else retval = fEffS->GetBinContent( fEffS->FindBin( theCut ) ); // content of the bin that contains theCut
3329
3330 // caution: here we take some "forbidden" action to hide a problem:
3331 // in some cases, in particular for likelihood, the binned efficiency distributions
3332 // do not equal 1, at xmin, and 0 at xmax; of course, in principle we have the
3333 // unbinned information available in the trees, but the unbinned minimization is
3334 // too slow, and we don't need to do a precision measurement here. Hence, we force
3335 // this property.
3336 Double_t eps = 1.0e-5; // tolerance used to decide that theCut sits on a range boundary
3337 if (theCut-fXmin < eps) retval = (GetCutOrientation() == kPositive) ? 1.0 : 0.0; // lower edge: 1 for positive cut orientation, 0 otherwise
3338 else if (fXmax-theCut < eps) retval = (GetCutOrientation() == kPositive) ? 0.0 : 1.0; // upper edge: the mirrored values
3339
3340 return retval;
3341}
3342
3343////////////////////////////////////////////////////////////////////////////////
3344/// Returns the event collection (i.e. the dataset) for the given tree type, TRANSFORMED using the
3345/// classifier's specific variable transformation (e.g. Decorr or Decorr:Gauss:Decorr); the transformed collection is cached in fEventCollections.
3346
3348{
3349 // if there's no variable transformation for this classifier, just hand back the
3350 // event collection of the data set
3351 if (GetTransformationHandler().GetTransformationList().GetEntries() <= 0) {
3352 return (Data()->GetEventCollection(type));
3353 }
3354
3355 // otherwise, transform ALL the events and hand back the vector of the pointers to the
3356 // transformed events. If the pointer is already != 0, i.e. the whole thing has been
3357 // done before, I don't need to do it again, but just "hand over" the pointer to those events.
3358 Int_t idx = Data()->TreeIndex(type); //index indicating Training,Testing,... events/datasets
3359 if (fEventCollections.at(idx) == 0) { // nothing cached yet for this tree type
3360 fEventCollections.at(idx) = &(Data()->GetEventCollection(type)); // temporarily point at the untransformed events
3361 fEventCollections.at(idx) = GetTransformationHandler().CalcTransformations(*(fEventCollections.at(idx)),kTRUE); // replace the slot with the result of CalcTransformations
3362 }
3363 return *(fEventCollections.at(idx));
3364}
3365
3366////////////////////////////////////////////////////////////////////////////////
3367/// Builds the TMVA version string "a.b.c" on the fly from the three low bytes of the training version code.
3368
3370{
3371 UInt_t a = GetTrainingTMVAVersionCode() & 0xff0000; a>>=16; // bits 16-23
3372 UInt_t b = GetTrainingTMVAVersionCode() & 0x00ff00; b>>=8; // bits 8-15
3373 UInt_t c = GetTrainingTMVAVersionCode() & 0x0000ff; // bits 0-7
3374
3375 return TString::Format("%i.%i.%i",a,b,c);
3376}
3377
3378////////////////////////////////////////////////////////////////////////////////
3379/// Builds the ROOT version string "a.bb/cc" on the fly from the three low bytes of the training version code; b and c are zero-padded to two digits.
3380
3382{
3383 UInt_t a = GetTrainingROOTVersionCode() & 0xff0000; a>>=16; // bits 16-23
3384 UInt_t b = GetTrainingROOTVersionCode() & 0x00ff00; b>>=8; // bits 8-15
3385 UInt_t c = GetTrainingROOTVersionCode() & 0x0000ff; // bits 0-7
3386
3387 return TString::Format("%i.%02i/%02i",a,b,c);
3388}
3389
3390////////////////////////////////////////////////////////////////////////////////
3391/// Runs TH1::KolmogorovTest between the "MVA_S" and "MVA_TRAIN_S" histograms (SorB is 's' or 'S') or between "MVA_B" and "MVA_TRAIN_B" (any other SorB), passing opt as the test option; returns -1 if the results or any histogram is unavailable.
3393 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
3394 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) ); // null if the stored results are not a ResultsClassification
3395
3396 if (mvaRes != NULL) {
3397 TH1D *mva_s = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_S"));
3398 TH1D *mva_b = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_B"));
3399 TH1D *mva_s_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_S"));
3400 TH1D *mva_b_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_B"));
3401
3402 if ( !mva_s || !mva_b || !mva_s_tr || !mva_b_tr) return -1; // all four histograms are required, whichever pair is compared
3403
3404 if (SorB == 's' || SorB == 'S')
3405 return mva_s->KolmogorovTest( mva_s_tr, opt.Data() );
3406 else
3407 return mva_b->KolmogorovTest( mva_b_tr, opt.Data() );
3408 }
3409 return -1; // no classification results available
3410}
const Bool_t Use_Splines_for_Eff_
const Int_t NBIN_HIST_HIGH
#define d(i)
Definition RSha256.hxx:102
#define b(i)
Definition RSha256.hxx:100
#define c(i)
Definition RSha256.hxx:101
#define a(i)
Definition RSha256.hxx:99
#define s1(x)
Definition RSha256.hxx:91
#define ROOT_VERSION_CODE
Definition RVersion.hxx:24
bool Bool_t
Definition RtypesCore.h:63
int Int_t
Definition RtypesCore.h:45
char Char_t
Definition RtypesCore.h:37
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
double Double_t
Definition RtypesCore.h:59
long long Long64_t
Definition RtypesCore.h:80
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:377
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char y2
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Option_t Option_t TPoint TPoint const char y1
char name[80]
Definition TGX11.cxx:110
float xmin
float xmax
TMatrixT< Double_t > TMatrixD
Definition TMatrixDfwd.h:23
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
R__EXTERN TSystem * gSystem
Definition TSystem.h:555
#define TMVA_VERSION_CODE
Definition Version.h:47
Class to manage histogram axis.
Definition TAxis.h:31
Double_t GetXmax() const
Definition TAxis.h:140
Double_t GetXmin() const
Definition TAxis.h:139
Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0) override
Write all objects in this collection.
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
This class stores the date and time with a precision of one second in an unsigned 32 bit word (950130...
Definition TDatime.h:37
const char * AsString() const
Return the date & time as a string (ctime() format).
Definition TDatime.cxx:102
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
Describe directory structure in memory.
Definition TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
virtual Bool_t cd()
Change current directory to "this" directory.
virtual TDirectory * mkdir(const char *name, const char *title="", Bool_t returnExistingDirectory=kFALSE)
Create a sub-directory "a" or a hierarchy of sub-directories "a/b/c/...".
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4082
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:943
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
1-D histogram with a double per channel (see TH1 documentation)
Definition TH1.h:669
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:621
TH1 is the base class of all histogram classes in ROOT.
Definition TH1.h:59
virtual Double_t GetBinCenter(Int_t bin) const
Return bin center for 1D histogram.
Definition TH1.cxx:9109
virtual void AddBinContent(Int_t bin)
Increment bin content by 1.
Definition TH1.cxx:1268
virtual Double_t GetMean(Int_t axis=1) const
For axis = 1,2 or 3 returns the mean value of the histogram along X,Y or Z axis.
Definition TH1.cxx:7503
virtual void SetXTitle(const char *title)
Definition TH1.h:418
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition TH1.cxx:1294
TAxis * GetXaxis()
Definition TH1.h:324
virtual Double_t GetMaximum(Double_t maxval=FLT_MAX) const
Return maximum value smaller than maxval of bins in the range, unless the value has been overridden b...
Definition TH1.cxx:8513
virtual Int_t GetNbinsX() const
Definition TH1.h:297
virtual Int_t Fill(Double_t x)
Increment bin with abscissa X by 1.
Definition TH1.cxx:3344
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition TH1.cxx:9190
virtual Int_t GetMaximumBin() const
Return location of bin with maximum value in the range.
Definition TH1.cxx:8545
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition TH1.cxx:5029
virtual Int_t GetQuantiles(Int_t nprobSum, Double_t *q, const Double_t *probSum=nullptr)
Compute Quantiles for this histogram Quantile x_q of a probability distribution Function F is defined...
Definition TH1.cxx:4579
virtual void SetYTitle(const char *title)
Definition TH1.h:419
virtual void Scale(Double_t c1=1, Option_t *option="")
Multiply this histogram by a constant c1.
Definition TH1.cxx:6572
virtual Int_t FindBin(Double_t x, Double_t y=0, Double_t z=0)
Return Global bin number corresponding to x,y,z.
Definition TH1.cxx:3672
virtual Double_t KolmogorovTest(const TH1 *h2, Option_t *option="") const
Statistical test of compatibility in shape between this histogram and h2, using Kolmogorov test.
Definition TH1.cxx:8146
virtual void Sumw2(Bool_t flag=kTRUE)
Create structure to store sum of squares of weights.
Definition TH1.cxx:8988
static Bool_t AddDirectoryStatus()
Static function: cannot be inlined on Windows/NT.
Definition TH1.cxx:754
2-D histogram with a float per channel (see TH1 documentation)
Definition TH2.h:307
Int_t Fill(Double_t) override
Invalid Fill method.
Definition TH2.cxx:393
A doubly linked list.
Definition TList.h:38
TObject * At(Int_t idx) const override
Returns the object at position idx. Returns 0 if idx is out of range.
Definition TList.cxx:355
Class that contains all the information of a class.
Definition ClassInfo.h:49
UInt_t GetNumber() const
Definition ClassInfo.h:65
TString fWeightFileExtension
Definition Config.h:125
VariablePlotting & GetVariablePlotting()
Definition Config.h:97
class TMVA::Config::VariablePlotting fVariablePlotting
IONames & GetIONames()
Definition Config.h:98
MsgLogger * fLogger
! message logger
Class that contains all the data information.
Definition DataSetInfo.h:62
Class that contains all the data information.
Definition DataSet.h:58
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is set or not.
Definition Event.cxx:389
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:399
Float_t GetTarget(UInt_t itgt) const
Definition Event.h:102
static void SetIgnoreNegWeightsInTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:408
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
void Init(std::vector< TString > &graphTitles)
This function gets some title and it creates a TGraph for every title.
IPythonInteractive()
standard constructor
~IPythonInteractive()
standard destructor
void ClearGraphs()
This function sets the point number to 0 for all graphs.
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in 2 TGraph case, and it will add new data points to graphs.
Virtual base Class for all MVA method.
Definition MethodBase.h:111
TDirectory * MethodBaseDir() const
returns the ROOT directory where all instances of the corresponding MVA method are stored
virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X")
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
virtual Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr)=0
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void ReadClassesFromXML(void *clsnode)
read number of classes from XML
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
void DeclareBaseOptions()
define the options (their key words) that can be set in the option string here the options valid for ...
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
virtual Double_t GetSignificance() const
compute significance of mean difference
virtual Double_t GetProba(const Event *ev)
const char * GetName() const
Definition MethodBase.h:334
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
const std::vector< TMVA::Event * > & GetEventCollection(Types::ETreeType type)
returns the event collection (i.e.
virtual std::vector< Double_t > GetDataMvaValues(DataSet *data=nullptr, Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the given Data type
void SetupMethod()
setup of methods
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
void AddInfoItem(void *gi, const TString &name, const TString &value) const
xml writing
virtual void AddClassifierOutputProb(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
TString GetTrainingTMVAVersionString() const
calculates the TMVA version string from the training version code on the fly
void Statistics(Types::ETreeType treeType, const TString &theVarName, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &)
calculates rms,mean, xmin, xmax of the event variable this can be either done for the variables as th...
Bool_t GetLine(std::istream &fin, char *buf)
reads one line from the input stream checks for certain keywords and interprets the line if keywords ...
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
virtual Bool_t IsSignalLike()
uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation) for...
virtual ~MethodBase()
destructor
virtual Double_t GetMaximumSignificance(Double_t SignalEvents, Double_t BackgroundEvents, Double_t &optimal_significance_value) const
plot the significance curve for a given number of signal and background events; returns cut for maximum ...
virtual Double_t GetTrainingEfficiency(const TString &)
void SetWeightFileName(TString)
set the weight file name (deprecated)
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
TString GetWeightFileName() const
retrieve weight file name
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
virtual void WriteMonitoringHistosToFile() const
write special monitoring histograms to file (dummy implementation here)
virtual void AddRegressionOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
void InitBase()
default initialization called by all constructors
virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t &stddev, Double_t &stddev90Percent) const
void ReadStateFromXMLString(const char *xmlstr)
for reading from memory
void CreateMVAPdfs()
Create PDFs of the MVA output variables.
TString GetTrainingROOTVersionString() const
calculates the ROOT version string from the training version code on the fly
virtual Double_t GetValueForRoot(Double_t)
returns efficiency as function of cut
void ReadStateFromFile()
Function to read options and weights from file.
void WriteVarsToStream(std::ostream &tf, const TString &prefix="") const
write the list of variables (name, min, max) for a given data transformation method to the stream
void ReadVarsFromStream(std::istream &istr)
Read the variables (name, min, max) for a given data transformation method from the stream.
void ReadSpectatorsFromXML(void *specnode)
read spectator info from XML
void SetTestvarName(const TString &v="")
Definition MethodBase.h:341
void ReadVariablesFromXML(void *varnode)
read variable info from XML
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
virtual std::vector< Float_t > GetMulticlassTrainingEfficiency(std::vector< std::vector< Float_t > > &purity)
void WriteStateToStream(std::ostream &tf) const
general method used in writing the header of the weight files where the used variables,...
virtual Double_t GetRarity(Double_t mvaVal, Types::ESBType reftype=Types::kBackground) const
compute rarity:
virtual void SetTuneParameters(std::map< TString, Double_t > tuneParameters)
set the tuning parameters according to the argument. This is just a dummy.
void ReadStateFromStream(std::istream &tf)
read the header from the weight files of the different MVA methods
void AddVarsXMLTo(void *parent) const
write variable info to XML
void AddTargetsXMLTo(void *parent) const
write target info to XML
void ReadTargetsFromXML(void *tarnode)
read target info from XML
void ProcessBaseOptions()
the option string is decoded, for available options see "DeclareOptions"
void ReadStateFromXML(void *parent)
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
void WriteStateToFile() const
write options and weights to file note that each one text file for the main configuration information...
void AddClassesXMLTo(void *parent) const
write class info to XML
virtual void AddClassifierOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
void AddSpectatorsXMLTo(void *parent) const
write spectator info to XML
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void AddMulticlassOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
void SetSource(const std::string &source)
Definition MsgLogger.h:68
PDF wrapper for histograms; uses user-defined spline interpolation.
Definition PDF.h:63
Double_t GetXmin() const
Definition PDF.h:104
Double_t GetXmax() const
Definition PDF.h:105
Double_t GetVal(Double_t x) const
returns value PDF(x)
Definition PDF.cxx:701
@ kSpline3
Definition PDF.h:70
@ kSpline2
Definition PDF.h:70
Double_t GetIntegral(Double_t xmin, Double_t xmax)
computes PDF integral within given ranges
Definition PDF.cxx:654
Class that is the base-class for a vector of results.
std::vector< Float_t > * GetValueVector()
void SetValue(Float_t value, Int_t ievt, Bool_t type)
set MVA response
Class which takes the results of a multiclass classification.
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
Float_t GetAchievablePur(UInt_t cls)
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
Float_t GetAchievableEff(UInt_t cls)
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier in a multiclass setting.
Class that is the base-class for a vector of results.
Class that is the base-class for a vector of results.
Definition Results.h:57
Bool_t DoesExist(const TString &alias) const
Returns true if there is an object stored in the result for a given alias, false otherwise.
Definition Results.cxx:127
void Store(TObject *obj, const char *alias=nullptr)
Definition Results.cxx:86
TH1 * GetHist(const TString &alias) const
Definition Results.cxx:136
TList * GetStorage() const
Definition Results.h:72
Root finding using Brents algorithm (translated from CERNLIB function RZERO)
Definition RootFinder.h:48
Double_t Root(Double_t refValue)
Root finding using Brents algorithm; taken from CERNLIB function RZERO.
Linear interpolation of TGraph.
Definition TSpline1.h:43
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
Double_t ElapsedSeconds(void)
computes elapsed time in seconds
Definition Timer.cxx:137
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition Timer.cxx:146
void DrawProgressBar(Int_t, const TString &comment="")
draws progress bar in color or B&W caution:
Definition Timer.cxx:202
void ComputeStat(const std::vector< TMVA::Event * > &, std::vector< Float_t > *, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Int_t signalClass, Bool_t norm=kFALSE)
sanity check
Definition Tools.cxx:202
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition Tools.cxx:401
Double_t GetSeparation(TH1 *S, TH1 *B) const
compute "separation" defined as
Definition Tools.cxx:121
Double_t GetMutualInformation(const TH2F &)
Mutual Information method for non-linear correlations estimates in 2D histogram Author: Moritz Backes...
Definition Tools.cxx:589
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
TXMLEngine & xmlengine()
Definition Tools.h:262
Bool_t CheckSplines(const TH1 *, const TSpline *)
check quality of splining by comparing splines and histograms in each bin
Definition Tools.cxx:479
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition Tools.cxx:383
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kSignal
Never change this number - it is elsewhere assumed to be zero !
Definition Types.h:135
@ kBackground
Definition Types.h:136
@ kLikelihood
Definition Types.h:79
@ kHMatrix
Definition Types.h:81
@ kMulticlass
Definition Types.h:129
@ kNoAnalysisType
Definition Types.h:130
@ kClassification
Definition Types.h:127
@ kMaxAnalysisType
Definition Types.h:131
@ kRegression
Definition Types.h:128
@ kTraining
Definition Types.h:143
Linear interpolation class.
Gaussian Transformation of input variables.
Class for type info of MVA input variable.
void ReadFromXML(void *varnode)
read VariableInfo from XML
const TString & GetExpression() const
char GetVarType() const
void ReadFromStream(std::istream &istr)
read VariableInfo from stream
void AddToXML(void *varnode)
write class to XML
void SetExternalLink(void *p)
void * GetExternalLink() const
void BuildTransformationFromVarInfo(const std::vector< TMVA::VariableInfo > &var)
this method is only used when building a normalization transformation from old text files in this cas...
Linear interpolation class.
Linear interpolation class.
virtual void ReadTransformationFromStream(std::istream &istr, const TString &classname="")=0
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:34
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Collectable string class.
Definition TObjString.h:28
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition TObject.cxx:880
Basic string class.
Definition TString.h:139
Ssiz_t Length() const
Definition TString.h:417
void ToLower()
Change string to lower-case.
Definition TString.cxx:1182
Int_t Atoi() const
Return integer value of string.
Definition TString.cxx:1988
Bool_t EndsWith(const char *pat, ECaseCompare cmp=kExact) const
Return true if string ends with the specified string.
Definition TString.cxx:2244
TSubString Strip(EStripType s=kTrailing, char c=' ') const
Return a substring of self stripped at beginning and/or end.
Definition TString.cxx:1163
const char * Data() const
Definition TString.h:376
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:704
@ kLeading
Definition TString.h:276
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition TString.cxx:931
Bool_t IsNull() const
Definition TString.h:414
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition TString.h:651
virtual const char * GetBuildNode() const
Return the build node name.
Definition TSystem.cxx:3907
virtual int mkdir(const char *name, Bool_t recursive=kFALSE)
Make a file system directory.
Definition TSystem.cxx:906
virtual const char * WorkingDirectory()
Return working directory.
Definition TSystem.cxx:871
virtual UserGroup_t * GetUserInfo(Int_t uid)
Returns all user info in the UserGroup_t structure.
Definition TSystem.cxx:1601
void SaveDoc(XMLDocPointer_t xmldoc, const char *filename, Int_t layout=1)
store document content to file if layout<=0, no spaces or newlines will be placed between xmlnode...
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
XMLDocPointer_t NewDoc(const char *version="1.0")
creates new xml document with provided version
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLDocPointer_t ParseString(const char *xmlstring)
parses content of string and tries to produce xml structures
void DocSetRootElement(XMLDocPointer_t xmldoc, XMLNodePointer_t xmlnode)
set main (root) node for document
TLine * line
Double_t x[n]
Definition legend1.C:17
TH1F * h1
Definition legend1.C:5
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Short_t Max(Short_t a, Short_t b)
Returns the largest of a and b.
Definition TMathBase.h:250
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:662
Short_t Min(Short_t a, Short_t b)
Returns the smallest of a and b.
Definition TMathBase.h:198
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:123
TString fUser
Definition TSystem.h:139