Logo ROOT  
Reference Guide
MethodBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
16 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
19 * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20 * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21 * *
22 * Copyright (c) 2005-2011: *
23 * CERN, Switzerland *
24 * U. of Victoria, Canada *
25 * MPI-K Heidelberg, Germany *
26 * U. of Bonn, Germany *
27 * *
28 * Redistribution and use in source and binary forms, with or without *
29 * modification, are permitted according to the terms listed in LICENSE *
30 * (http://tmva.sourceforge.net/LICENSE) *
31 * *
32 **********************************************************************************/
33
34/*! \class TMVA::MethodBase
35\ingroup TMVA
36
37 Virtual base class for all MVA methods
38
39 MethodBase hosts several specific evaluation methods.
40
41 The kind of MVA that provides optimal performance in an analysis strongly
42 depends on the particular application. The evaluation factory provides a
43 number of numerical benchmark results to directly assess the performance
44 of the MVA training on the independent test sample. These are:
45
46 - The _signal efficiency_ at three representative background efficiencies
47 (which is 1 &minus; rejection).
48 - The _significance_ of an MVA estimator, defined by the difference
49 between the MVA mean values for signal and background, divided by the
50 quadratic sum of their root mean squares.
51 - The _separation_ of an MVA _x_, defined by the integral
52 \f[
53 \frac{1}{2} \int \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx
54 \f]
55 where
56 \f$ S(x) \f$ and \f$ B(x) \f$ are the signal and background distributions,
57 respectively. The separation is zero for identical signal and background MVA
58 shapes, and it is one for disjunctive shapes.
59 - The average, \f$ \int x \mu (S(x)) dx \f$, of the signal \f$ \mu_{transform} \f$.
60 The \f$ \mu_{transform} \f$ of an MVA denotes the transformation that yields
61 a uniform background distribution. In this way, the signal distributions
62 \f$ S(x) \f$ can be directly compared among the various MVAs. The stronger
63 \f$ S(x) \f$ peaks towards one, the better is the discrimination of the MVA.
64 The \f$ \mu_{transform} \f$ is
65 [documented here](http://tel.ccsd.cnrs.fr/documents/archives0/00/00/29/91/index_fr.html).
66
67 The MVA standard output also prints the linear correlation coefficients between
68 signal and background, which can be useful to eliminate variables that exhibit too
69 strong correlations.
70*/
71
72#include "TMVA/MethodBase.h"
73
74#include "TMVA/Config.h"
75#include "TMVA/Configurable.h"
76#include "TMVA/DataSetInfo.h"
77#include "TMVA/DataSet.h"
78#include "TMVA/Factory.h"
79#include "TMVA/IMethod.h"
80#include "TMVA/MsgLogger.h"
81#include "TMVA/PDF.h"
82#include "TMVA/Ranking.h"
83#include "TMVA/DataLoader.h"
84#include "TMVA/Tools.h"
85#include "TMVA/Results.h"
89#include "TMVA/RootFinder.h"
90#include "TMVA/Timer.h"
91#include "TMVA/TSpline1.h"
92#include "TMVA/Types.h"
96#include "TMVA/VariableInfo.h"
100#include "TMVA/Version.h"
101
102#include "TROOT.h"
103#include "TSystem.h"
104#include "TObjString.h"
105#include "TQObject.h"
106#include "TSpline.h"
107#include "TMatrix.h"
108#include "TMath.h"
109#include "TH1F.h"
110#include "TH2F.h"
111#include "TFile.h"
112#include "TKey.h"
113#include "TGraph.h"
114#include "Riostream.h"
115#include "TXMLEngine.h"
116
117#include <iomanip>
118#include <iostream>
119#include <fstream>
120#include <sstream>
121#include <cstdlib>
122#include <algorithm>
123#include <limits>
124
125
127
128using std::endl;
129using std::atof;
130
131//const Int_t MethodBase_MaxIterations_ = 200;
133
134//const Int_t NBIN_HIST_PLOT = 100;
135const Int_t NBIN_HIST_HIGH = 10000;
136
137#ifdef _WIN32
138/* Disable warning C4355: 'this' : used in base member initializer list */
139#pragma warning ( disable : 4355 )
140#endif
141
142
143#include "TMultiGraph.h"
144
145////////////////////////////////////////////////////////////////////////////////
146/// standard constructor
147
149{
150 fNumGraphs = 0;
151 fIndex = 0;
152}
153
154////////////////////////////////////////////////////////////////////////////////
155/// standard destructor
157{
158 if (fMultiGraph){
159 delete fMultiGraph;
160 fMultiGraph = nullptr;
161 }
162 return;
163}
164
165////////////////////////////////////////////////////////////////////////////////
166/// This function gets some title and it creates a TGraph for every title.
167/// It also sets up the style for every TGraph. All graphs are added to a single TMultiGraph.
168///
169/// \param[in] graphTitles vector of titles
170
171void TMVA::IPythonInteractive::Init(std::vector<TString>& graphTitles)
172{
173 if (fNumGraphs!=0){
174 std::cerr << kERROR << "IPythonInteractive::Init: already initialized..." << std::endl;
175 return;
176 }
177 Int_t color = 2;
178 for(auto& title : graphTitles){
179 fGraphs.push_back( new TGraph() );
180 fGraphs.back()->SetTitle(title);
181 fGraphs.back()->SetName(title);
182 fGraphs.back()->SetFillColor(color);
183 fGraphs.back()->SetLineColor(color);
184 fGraphs.back()->SetMarkerColor(color);
185 fMultiGraph->Add(fGraphs.back());
186 color += 2;
187 fNumGraphs += 1;
188 }
189 return;
190}
191
192////////////////////////////////////////////////////////////////////////////////
193/// This function sets the point number to 0 for all graphs.
194
196{
197 for(Int_t i=0; i<fNumGraphs; i++){
198 fGraphs[i]->Set(0);
199 }
200}
202////////////////////////////////////////////////////////////////////////////////
203/// This function is used only in 2 TGraph case, and it will add new data points to graphs.
204///
205/// \param[in] x the x coordinate
206/// \param[in] y1 the y coordinate for the first TGraph
207/// \param[in] y2 the y coordinate for the second TGraph
208
210{
211 fGraphs[0]->Set(fIndex+1);
212 fGraphs[1]->Set(fIndex+1);
213 fGraphs[0]->SetPoint(fIndex, x, y1);
214 fGraphs[1]->SetPoint(fIndex, x, y2);
215 fIndex++;
216 return;
217}
218
219////////////////////////////////////////////////////////////////////////////////
220/// This function can add data points to as many TGraphs as we have.
221///
222/// \param[in] dat vector of data points. The dat[0] contains the x coordinate,
223/// dat[1] contains the y coordinate for first TGraph, dat[2] for second, ...
224
225void TMVA::IPythonInteractive::AddPoint(std::vector<Double_t>& dat)
226{
227 for(Int_t i=0; i<fNumGraphs;i++){
228 fGraphs[i]->Set(fIndex+1);
229 fGraphs[i]->SetPoint(fIndex, dat[0], dat[i+1]);
230 }
231 fIndex++;
232 return;
233}
234
235
236////////////////////////////////////////////////////////////////////////////////
237/// standard constructor
238
240 Types::EMVA methodType,
241 const TString& methodTitle,
242 DataSetInfo& dsi,
243 const TString& theOption) :
244 IMethod(),
245 Configurable ( theOption ),
246 fTmpEvent ( 0 ),
247 fRanking ( 0 ),
248 fInputVars ( 0 ),
249 fAnalysisType ( Types::kNoAnalysisType ),
250 fRegressionReturnVal ( 0 ),
251 fMulticlassReturnVal ( 0 ),
252 fDataSetInfo ( dsi ),
253 fSignalReferenceCut ( 0.5 ),
254 fSignalReferenceCutOrientation( 1. ),
255 fVariableTransformType ( Types::kSignal ),
256 fJobName ( jobName ),
257 fMethodName ( methodTitle ),
258 fMethodType ( methodType ),
259 fTestvar ( "" ),
260 fTMVATrainingVersion ( TMVA_VERSION_CODE ),
261 fROOTTrainingVersion ( ROOT_VERSION_CODE ),
262 fConstructedFromWeightFile ( kFALSE ),
263 fBaseDir ( 0 ),
264 fMethodBaseDir ( 0 ),
265 fFile ( 0 ),
266 fSilentFile (kFALSE),
267 fModelPersistence (kTRUE),
268 fWeightFile ( "" ),
269 fEffS ( 0 ),
270 fDefaultPDF ( 0 ),
271 fMVAPdfS ( 0 ),
272 fMVAPdfB ( 0 ),
273 fSplS ( 0 ),
274 fSplB ( 0 ),
275 fSpleffBvsS ( 0 ),
276 fSplTrainS ( 0 ),
277 fSplTrainB ( 0 ),
278 fSplTrainEffBvsS ( 0 ),
279 fVarTransformString ( "None" ),
280 fTransformationPointer ( 0 ),
281 fTransformation ( dsi, methodTitle ),
282 fVerbose ( kFALSE ),
283 fVerbosityLevelString ( "Default" ),
284 fHelp ( kFALSE ),
285 fHasMVAPdfs ( kFALSE ),
286 fIgnoreNegWeightsInTraining( kFALSE ),
287 fSignalClass ( 0 ),
288 fBackgroundClass ( 0 ),
289 fSplRefS ( 0 ),
290 fSplRefB ( 0 ),
291 fSplTrainRefS ( 0 ),
292 fSplTrainRefB ( 0 ),
293 fSetupCompleted (kFALSE)
294{
297
298// // default extension for weight files
299}
300
301////////////////////////////////////////////////////////////////////////////////
302/// constructor used for Testing + Application of the MVA,
303/// only (no training), using given WeightFiles
304
306 DataSetInfo& dsi,
307 const TString& weightFile ) :
308 IMethod(),
309 Configurable(""),
310 fTmpEvent ( 0 ),
311 fRanking ( 0 ),
312 fInputVars ( 0 ),
313 fAnalysisType ( Types::kNoAnalysisType ),
314 fRegressionReturnVal ( 0 ),
315 fMulticlassReturnVal ( 0 ),
316 fDataSetInfo ( dsi ),
317 fSignalReferenceCut ( 0.5 ),
318 fVariableTransformType ( Types::kSignal ),
319 fJobName ( "" ),
320 fMethodName ( "MethodBase" ),
321 fMethodType ( methodType ),
322 fTestvar ( "" ),
323 fTMVATrainingVersion ( 0 ),
324 fROOTTrainingVersion ( 0 ),
325 fConstructedFromWeightFile ( kTRUE ),
326 fBaseDir ( 0 ),
327 fMethodBaseDir ( 0 ),
328 fFile ( 0 ),
329 fSilentFile (kFALSE),
330 fModelPersistence (kTRUE),
331 fWeightFile ( weightFile ),
332 fEffS ( 0 ),
333 fDefaultPDF ( 0 ),
334 fMVAPdfS ( 0 ),
335 fMVAPdfB ( 0 ),
336 fSplS ( 0 ),
337 fSplB ( 0 ),
338 fSpleffBvsS ( 0 ),
339 fSplTrainS ( 0 ),
340 fSplTrainB ( 0 ),
341 fSplTrainEffBvsS ( 0 ),
342 fVarTransformString ( "None" ),
343 fTransformationPointer ( 0 ),
344 fTransformation ( dsi, "" ),
345 fVerbose ( kFALSE ),
346 fVerbosityLevelString ( "Default" ),
347 fHelp ( kFALSE ),
348 fHasMVAPdfs ( kFALSE ),
349 fIgnoreNegWeightsInTraining( kFALSE ),
350 fSignalClass ( 0 ),
351 fBackgroundClass ( 0 ),
352 fSplRefS ( 0 ),
353 fSplRefB ( 0 ),
354 fSplTrainRefS ( 0 ),
355 fSplTrainRefB ( 0 ),
356 fSetupCompleted (kFALSE)
357{
359// // constructor used for Testing + Application of the MVA,
360// // only (no training), using given WeightFiles
361}
362
363////////////////////////////////////////////////////////////////////////////////
364/// destructor
365
367{
368 // destructor
369 if (!fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling destructor of method which got never setup" << Endl;
370
371 // destructor
372 if (fInputVars != 0) { fInputVars->clear(); delete fInputVars; }
373 if (fRanking != 0) delete fRanking;
374
375 // PDFs
376 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
377 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
378 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
379
380 // Splines
381 if (fSplS) { delete fSplS; fSplS = 0; }
382 if (fSplB) { delete fSplB; fSplB = 0; }
383 if (fSpleffBvsS) { delete fSpleffBvsS; fSpleffBvsS = 0; }
384 if (fSplRefS) { delete fSplRefS; fSplRefS = 0; }
385 if (fSplRefB) { delete fSplRefB; fSplRefB = 0; }
386 if (fSplTrainRefS) { delete fSplTrainRefS; fSplTrainRefS = 0; }
387 if (fSplTrainRefB) { delete fSplTrainRefB; fSplTrainRefB = 0; }
388 if (fSplTrainEffBvsS) { delete fSplTrainEffBvsS; fSplTrainEffBvsS = 0; }
389
390 for (Int_t i = 0; i < 2; i++ ) {
391 if (fEventCollections.at(i)) {
392 for (std::vector<Event*>::const_iterator it = fEventCollections.at(i)->begin();
393 it != fEventCollections.at(i)->end(); ++it) {
394 delete (*it);
395 }
396 delete fEventCollections.at(i);
397 fEventCollections.at(i) = 0;
398 }
399 }
400
401 if (fRegressionReturnVal) delete fRegressionReturnVal;
402 if (fMulticlassReturnVal) delete fMulticlassReturnVal;
403}
404
405////////////////////////////////////////////////////////////////////////////////
406/// setup of methods
407
409{
410 // setup of methods
411
412 if (fSetupCompleted) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Calling SetupMethod for the second time" << Endl;
413 InitBase();
414 DeclareBaseOptions();
415 Init();
416 DeclareOptions();
417 fSetupCompleted = kTRUE;
418}
419
420////////////////////////////////////////////////////////////////////////////////
421/// process all options
422/// the "CheckForUnusedOptions" is done in an independent call, since it may be overridden by derived class
423/// (sometimes, eg, fitters are used which can only be implemented during training phase)
424
426{
427 ProcessBaseOptions();
428 ProcessOptions();
429}
430
431////////////////////////////////////////////////////////////////////////////////
432/// check may be overridden by derived class
433/// (sometimes, eg, fitters are used which can only be implemented during training phase)
434
436{
437 CheckForUnusedOptions();
438}
439
440////////////////////////////////////////////////////////////////////////////////
441/// default initialization called by all constructors
442
444{
445 SetConfigDescription( "Configuration options for classifier architecture and tuning" );
446
448 fNbinsMVAoutput = gConfig().fVariablePlotting.fNbinsMVAoutput;
449 fNbinsH = NBIN_HIST_HIGH;
450
451 fSplTrainS = 0;
452 fSplTrainB = 0;
453 fSplTrainEffBvsS = 0;
454 fMeanS = -1;
455 fMeanB = -1;
456 fRmsS = -1;
457 fRmsB = -1;
458 fXmin = DBL_MAX;
459 fXmax = -DBL_MAX;
460 fTxtWeightsOnly = kTRUE;
461 fSplRefS = 0;
462 fSplRefB = 0;
463
464 fTrainTime = -1.;
465 fTestTime = -1.;
466
467 fRanking = 0;
468
469 // temporary until the move to DataSet is complete
470 fInputVars = new std::vector<TString>;
471 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
472 fInputVars->push_back(DataInfo().GetVariableInfo(ivar).GetLabel());
473 }
474 fRegressionReturnVal = 0;
475 fMulticlassReturnVal = 0;
476
477 fEventCollections.resize( 2 );
478 fEventCollections.at(0) = 0;
479 fEventCollections.at(1) = 0;
480
481 // retrieve signal and background class index
482 if (DataInfo().GetClassInfo("Signal") != 0) {
483 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
484 }
485 if (DataInfo().GetClassInfo("Background") != 0) {
486 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
487 }
488
489 SetConfigDescription( "Configuration options for MVA method" );
490 SetConfigName( TString("Method") + GetMethodTypeName() );
491}
492
493////////////////////////////////////////////////////////////////////////////////
494/// define the options (their key words) that can be set in the option string
495/// here the options valid for ALL MVA methods are declared.
496///
497/// know options:
498///
499/// - VariableTransform=None,Decorrelated,PCA to use transformed variables
500/// instead of the original ones
501/// - VariableTransformType=Signal,Background which decorrelation matrix to use
502/// in the method. Only the Likelihood
503/// Method can make proper use of independent
504/// transformations of signal and background
505/// - fNbinsMVAPdf = 50 Number of bins used to create a PDF of MVA
506/// - fNsmoothMVAPdf = 2 Number of times a histogram is smoothed before creating the PDF
507/// - fHasMVAPdfs create PDFs for the MVA outputs
508/// - V for Verbose output (!V) for non verbose
509/// - H for Help message
510
512{
513 DeclareOptionRef( fVerbose, "V", "Verbose output (short form of \"VerbosityLevel\" below - overrides the latter one)" );
514
515 DeclareOptionRef( fVerbosityLevelString="Default", "VerbosityLevel", "Verbosity level" );
516 AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
517 AddPreDefVal( TString("Debug") );
518 AddPreDefVal( TString("Verbose") );
519 AddPreDefVal( TString("Info") );
520 AddPreDefVal( TString("Warning") );
521 AddPreDefVal( TString("Error") );
522 AddPreDefVal( TString("Fatal") );
523
524 // If True (default): write all training results (weights) as text files only;
525 // if False: write also in ROOT format (not available for all methods - will abort if not
526 fTxtWeightsOnly = kTRUE; // OBSOLETE !!!
527 fNormalise = kFALSE; // OBSOLETE !!!
528
529 DeclareOptionRef( fVarTransformString, "VarTransform", "List of variable transformations performed before training, e.g., \"D_Background,P_Signal,G,N_AllClasses\" for: \"Decorrelation, PCA-transformation, Gaussianisation, Normalisation, each for the given class of events ('AllClasses' denotes all events of all classes, if no class indication is given, 'All' is assumed)\"" );
530
531 DeclareOptionRef( fHelp, "H", "Print method-specific help message" );
532
533 DeclareOptionRef( fHasMVAPdfs, "CreateMVAPdfs", "Create PDFs for classifier outputs (signal and background)" );
534
535 DeclareOptionRef( fIgnoreNegWeightsInTraining, "IgnoreNegWeightsInTraining",
536 "Events with negative weights are ignored in the training (but are included for testing and performance evaluation)" );
537}
538
539////////////////////////////////////////////////////////////////////////////////
540/// the option string is decoded, for available options see "DeclareOptions"
541
543{
544 if (HasMVAPdfs()) {
545 // setting the default bin num... maybe should be static ? ==> Please no static (JS)
546 // You can't use the logger in the constructor!!! Log() << kINFO << "Create PDFs" << Endl;
547 // reading every PDF's definition and passing the option string to the next one to be read and marked
548 fDefaultPDF = new PDF( TString(GetName())+"_PDF", GetOptions(), "MVAPdf" );
549 fDefaultPDF->DeclareOptions();
550 fDefaultPDF->ParseOptions();
551 fDefaultPDF->ProcessOptions();
552 fMVAPdfB = new PDF( TString(GetName())+"_PDFBkg", fDefaultPDF->GetOptions(), "MVAPdfBkg", fDefaultPDF );
553 fMVAPdfB->DeclareOptions();
554 fMVAPdfB->ParseOptions();
555 fMVAPdfB->ProcessOptions();
556 fMVAPdfS = new PDF( TString(GetName())+"_PDFSig", fMVAPdfB->GetOptions(), "MVAPdfSig", fDefaultPDF );
557 fMVAPdfS->DeclareOptions();
558 fMVAPdfS->ParseOptions();
559 fMVAPdfS->ProcessOptions();
560
561 // the final marked option string is written back to the original methodbase
562 SetOptions( fMVAPdfS->GetOptions() );
563 }
564
565 TMVA::CreateVariableTransforms( fVarTransformString,
566 DataInfo(),
567 GetTransformationHandler(),
568 Log() );
569
570 if (!HasMVAPdfs()) {
571 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
572 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
573 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
574 }
575
576 if (fVerbose) { // overwrites other settings
577 fVerbosityLevelString = TString("Verbose");
578 Log().SetMinType( kVERBOSE );
579 }
580 else if (fVerbosityLevelString == "Debug" ) Log().SetMinType( kDEBUG );
581 else if (fVerbosityLevelString == "Verbose" ) Log().SetMinType( kVERBOSE );
582 else if (fVerbosityLevelString == "Info" ) Log().SetMinType( kINFO );
583 else if (fVerbosityLevelString == "Warning" ) Log().SetMinType( kWARNING );
584 else if (fVerbosityLevelString == "Error" ) Log().SetMinType( kERROR );
585 else if (fVerbosityLevelString == "Fatal" ) Log().SetMinType( kFATAL );
586 else if (fVerbosityLevelString != "Default" ) {
587 Log() << kFATAL << "<ProcessOptions> Verbosity level type '"
588 << fVerbosityLevelString << "' unknown." << Endl;
589 }
590 Event::SetIgnoreNegWeightsInTraining(fIgnoreNegWeightsInTraining);
591}
592
593////////////////////////////////////////////////////////////////////////////////
594/// options that are used ONLY for the READER to ensure backward compatibility
595/// they are hence without any effect (the reader is only reading the training
596/// options that HAD been used at the training of the .xml weight file at hand
597
599{
600 DeclareOptionRef( fNormalise=kFALSE, "Normalise", "Normalise input variables" ); // don't change the default !!!
601 DeclareOptionRef( fUseDecorr=kFALSE, "D", "Use-decorrelated-variables flag" );
602 DeclareOptionRef( fVariableTransformTypeString="Signal", "VarTransformType",
603 "Use signal or background events to derive for variable transformation (the transformation is applied on both types of, course)" );
604 AddPreDefVal( TString("Signal") );
605 AddPreDefVal( TString("Background") );
606 DeclareOptionRef( fTxtWeightsOnly=kTRUE, "TxtWeightFilesOnly", "If True: write all training results (weights) as text files (False: some are written in ROOT format)" );
607 // Why on earth ?? was this here? Was the verbosity level option meant to 'disappear? Not a good idea i think..
608 // DeclareOptionRef( fVerbosityLevelString="Default", "VerboseLevel", "Verbosity level" );
609 // AddPreDefVal( TString("Default") ); // uses default defined in MsgLogger header
610 // AddPreDefVal( TString("Debug") );
611 // AddPreDefVal( TString("Verbose") );
612 // AddPreDefVal( TString("Info") );
613 // AddPreDefVal( TString("Warning") );
614 // AddPreDefVal( TString("Error") );
615 // AddPreDefVal( TString("Fatal") );
616 DeclareOptionRef( fNbinsMVAPdf = 60, "NbinsMVAPdf", "Number of bins used for the PDFs of classifier outputs" );
617 DeclareOptionRef( fNsmoothMVAPdf = 2, "NsmoothMVAPdf", "Number of smoothing iterations for classifier PDFs" );
618}
619
620
621////////////////////////////////////////////////////////////////////////////////
622/// call the Optimizer with the set of parameters and ranges that
623/// are meant to be tuned.
624
625std::map<TString,Double_t> TMVA::MethodBase::OptimizeTuningParameters(TString /* fomType */ , TString /* fitType */)
626{
627 // this is just a dummy... needs to be implemented for each method
628 // individually (as long as we don't have it automatized via the
629 // configuration string
630
631 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Parameter optimization is not yet implemented for method "
632 << GetName() << Endl;
633 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Currently we need to set hardcoded which parameter is tuned in which ranges"<<Endl;
634
635 std::map<TString,Double_t> tunedParameters;
636 tunedParameters.size(); // just to get rid of "unused" warning
637 return tunedParameters;
638
639}
640
641////////////////////////////////////////////////////////////////////////////////
642/// set the tuning parameters according to the argument
643/// This is just a dummy .. have a look at the MethodBDT how you could
644/// perhaps implement the same thing for the other Classifiers..
645
646void TMVA::MethodBase::SetTuneParameters(std::map<TString,Double_t> /* tuneParameters */)
647{
648}
649
650////////////////////////////////////////////////////////////////////////////////
651
653{
654 Data()->SetCurrentType(Types::kTraining);
655 Event::SetIsTraining(kTRUE); // used to set negative event weights to zero if chosen to do so
656
657 // train the MVA method
658 if (Help()) PrintHelpMessage();
659
660 // all histograms should be created in the method's subdirectory
661 if(!IsSilentFile()) BaseDir()->cd();
662
663 // once calculate all the transformation (e.g. the sequence of Decorr:Gauss:Decorr)
664 // needed for this classifier
665 GetTransformationHandler().CalcTransformations(Data()->GetEventCollection());
666
667 // call training of derived MVA
668 Log() << kDEBUG //<<Form("\tDataset[%s] : ",DataInfo().GetName())
669 << "Begin training" << Endl;
670 Long64_t nEvents = Data()->GetNEvents();
671 Timer traintimer( nEvents, GetName(), kTRUE );
672 Train();
673 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName()
674 << "\tEnd of training " << Endl;
675 SetTrainTime(traintimer.ElapsedSeconds());
676 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
677 << "Elapsed time for training with " << nEvents << " events: "
678 << traintimer.GetElapsedTime() << " " << Endl;
679
680 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
681 << "\tCreate MVA output for ";
682
683 // create PDFs for the signal and background MVA distributions (if required)
684 if (DoMulticlass()) {
685 Log() <<Form("[%s] : ",DataInfo().GetName())<< "Multiclass classification on training sample" << Endl;
686 AddMulticlassOutput(Types::kTraining);
687 }
688 else if (!DoRegression()) {
689
690 Log() <<Form("[%s] : ",DataInfo().GetName())<< "classification on training sample" << Endl;
691 AddClassifierOutput(Types::kTraining);
692 if (HasMVAPdfs()) {
693 CreateMVAPdfs();
694 AddClassifierOutputProb(Types::kTraining);
695 }
696
697 } else {
698
699 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "regression on training sample" << Endl;
700 AddRegressionOutput( Types::kTraining );
701
702 if (HasMVAPdfs() ) {
703 Log() <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create PDFs" << Endl;
704 CreateMVAPdfs();
705 }
706 }
707
708 // write the current MVA state into stream
709 // produced are one text file and one ROOT file
710 if (fModelPersistence ) WriteStateToFile();
711
712 // produce standalone make class (presently only supported for classification)
713 if ((!DoRegression()) && (fModelPersistence)) MakeClass();
714
715 // write additional monitoring histograms to main target file (not the weight file)
716 // again, make sure the histograms go into the method's subdirectory
717 if(!IsSilentFile())
718 {
719 BaseDir()->cd();
720 WriteMonitoringHistosToFile();
721 }
722}
723
724////////////////////////////////////////////////////////////////////////////////
725
727{
728 if (!DoRegression()) Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Trying to use GetRegressionDeviation() with a classification job" << Endl;
729 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
731 bool truncate = false;
732 TH1F* h1 = regRes->QuadraticDeviation( tgtNum , truncate, 1.);
733 stddev = sqrt(h1->GetMean());
734 truncate = true;
735 Double_t yq[1], xq[]={0.9};
736 h1->GetQuantiles(1,yq,xq);
737 TH1F* h2 = regRes->QuadraticDeviation( tgtNum , truncate, yq[0]);
738 stddev90Percent = sqrt(h2->GetMean());
739 delete h1;
740 delete h2;
741}
742
743////////////////////////////////////////////////////////////////////////////////
744/// prepare tree branch with the method's discriminating variable
745
747{
748 Data()->SetCurrentType(type);
749
750 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
751
752 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), type, Types::kRegression);
753
754 Long64_t nEvents = Data()->GetNEvents();
755
756 // use timer
757 Timer timer( nEvents, GetName(), kTRUE );
758 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
759 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
760
761 regRes->Resize( nEvents );
762
763 // Drawing the progress bar every event was causing a huge slowdown in the evaluation time
764 // So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
765
766 Int_t totalProgressDraws = 100; // total number of times to update the progress bar
767 Int_t drawProgressEvery = 1; // draw every nth event such that we have a total of totalProgressDraws
768 if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
769
770 for (Int_t ievt=0; ievt<nEvents; ievt++) {
771
772 Data()->SetCurrentEvent(ievt);
773 std::vector< Float_t > vals = GetRegressionValues();
774 regRes->SetValue( vals, ievt );
775
776 // Only draw the progress bar once in a while, doing this every event causes the evaluation to be ridiculously slow
777 if(ievt % drawProgressEvery == 0 || ievt==nEvents-1) timer.DrawProgressBar( ievt );
778 }
779
780 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
781 << "Elapsed time for evaluation of " << nEvents << " events: "
782 << timer.GetElapsedTime() << " " << Endl;
783
784 // store time used for testing
786 SetTestTime(timer.ElapsedSeconds());
787
788 TString histNamePrefix(GetTestvarName());
789 histNamePrefix += (type==Types::kTraining?"train":"test");
790 regRes->CreateDeviationHistograms( histNamePrefix );
791}
792
793////////////////////////////////////////////////////////////////////////////////
794/// prepare tree branch with the method's discriminating variable
795
797{
798 Data()->SetCurrentType(type);
799
800 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
801
802 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
803 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in AddMulticlassOutput, exiting."<<Endl;
804
805 Long64_t nEvents = Data()->GetNEvents();
806
807 // use timer
808 Timer timer( nEvents, GetName(), kTRUE );
809
810 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Multiclass evaluation of " << GetMethodName() << " on "
811 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
812
813 resMulticlass->Resize( nEvents );
814 for (Int_t ievt=0; ievt<nEvents; ievt++) {
815 Data()->SetCurrentEvent(ievt);
816 std::vector< Float_t > vals = GetMulticlassValues();
817 resMulticlass->SetValue( vals, ievt );
818 timer.DrawProgressBar( ievt );
819 }
820
821 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())
822 << "Elapsed time for evaluation of " << nEvents << " events: "
823 << timer.GetElapsedTime() << " " << Endl;
824
825 // store time used for testing
827 SetTestTime(timer.ElapsedSeconds());
828
829 TString histNamePrefix(GetTestvarName());
830 histNamePrefix += (type==Types::kTraining?"_Train":"_Test");
831
832 resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbinsMVAoutput, fNbinsH );
833 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefix);
834}
835
836////////////////////////////////////////////////////////////////////////////////
837
838void TMVA::MethodBase::NoErrorCalc(Double_t* const err, Double_t* const errUpper) {
839 if (err) *err=-1;
840 if (errUpper) *errUpper=-1;
841}
842
843////////////////////////////////////////////////////////////////////////////////
844
845Double_t TMVA::MethodBase::GetMvaValue( const Event* const ev, Double_t* err, Double_t* errUpper ) {
846 fTmpEvent = ev;
847 Double_t val = GetMvaValue(err, errUpper);
848 fTmpEvent = 0;
849 return val;
850}
851
852////////////////////////////////////////////////////////////////////////////////
853/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
854/// for a quick determination if an event would be selected as signal or background
855
857 return GetMvaValue()*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
858}
859////////////////////////////////////////////////////////////////////////////////
860/// uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation)
861/// for a quick determination if an event with this mva output value would be selected as signal or background
862
864 return mvaVal*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
865}
866
867////////////////////////////////////////////////////////////////////////////////
868/// prepare tree branch with the method's discriminating variable
869
871{
872 Data()->SetCurrentType(type);
873
874 ResultsClassification* clRes =
876
877 Long64_t nEvents = Data()->GetNEvents();
878 clRes->Resize( nEvents );
879
880 // use timer
881 Timer timer( nEvents, GetName(), kTRUE );
882 std::vector<Double_t> mvaValues = GetMvaValues(0, nEvents, true);
883
884 // store time used for testing
886 SetTestTime(timer.ElapsedSeconds());
887
888 // load mva values to results object
889 for (Int_t ievt=0; ievt<nEvents; ievt++) {
890 clRes->SetValue( mvaValues[ievt], ievt );
891 }
892}
893
894////////////////////////////////////////////////////////////////////////////////
895/// get all the MVA values for the events of the current Data type
896std::vector<Double_t> TMVA::MethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
897{
898
899 Long64_t nEvents = Data()->GetNEvents();
900 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
901 if (firstEvt < 0) firstEvt = 0;
902 std::vector<Double_t> values(lastEvt-firstEvt);
903 // log in case of looping on all the events
904 nEvents = values.size();
905
906 // use timer
907 Timer timer( nEvents, GetName(), kTRUE );
908
909 if (logProgress)
910 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
911 << "Evaluation of " << GetMethodName() << " on "
912 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
913 << " sample (" << nEvents << " events)" << Endl;
914
915 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
916 Data()->SetCurrentEvent(ievt);
917 values[ievt] = GetMvaValue();
918
919 // print progress
920 if (logProgress) {
921 Int_t modulo = Int_t(nEvents/100);
922 if (modulo <= 0 ) modulo = 1;
923 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
924 }
925 }
926 if (logProgress) {
927 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
928 << "Elapsed time for evaluation of " << nEvents << " events: "
929 << timer.GetElapsedTime() << " " << Endl;
930 }
931
932 return values;
933}
934
935////////////////////////////////////////////////////////////////////////////////
936/// prepare tree branch with the method's discriminating variable
937
939{
940 Data()->SetCurrentType(type);
941
942 ResultsClassification* mvaProb =
943 (ResultsClassification*)Data()->GetResults(TString("prob_")+GetMethodName(), type, Types::kClassification );
944
945 Long64_t nEvents = Data()->GetNEvents();
946
947 // use timer
948 Timer timer( nEvents, GetName(), kTRUE );
949
950 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
951 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
952
953 mvaProb->Resize( nEvents );
954 for (Int_t ievt=0; ievt<nEvents; ievt++) {
955
956 Data()->SetCurrentEvent(ievt);
957 Float_t proba = ((Float_t)GetProba( GetMvaValue(), 0.5 ));
958 if (proba < 0) break;
959 mvaProb->SetValue( proba, ievt );
960
961 // print progress
962 Int_t modulo = Int_t(nEvents/100);
963 if (modulo <= 0 ) modulo = 1;
964 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
965 }
966
967 Log() << kDEBUG <<Form("Dataset[%s] : ",DataInfo().GetName())
968 << "Elapsed time for evaluation of " << nEvents << " events: "
969 << timer.GetElapsedTime() << " " << Endl;
970}
971
972////////////////////////////////////////////////////////////////////////////////
973/// calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
974///
975/// - bias = average deviation
976/// - dev = average absolute deviation
977/// - rms = rms of deviation
978
980 Double_t& dev, Double_t& devT,
981 Double_t& rms, Double_t& rmsT,
982 Double_t& mInf, Double_t& mInfT,
983 Double_t& corr,
985{
986 Types::ETreeType savedType = Data()->GetCurrentType();
987 Data()->SetCurrentType(type);
988
989 bias = 0; biasT = 0; dev = 0; devT = 0; rms = 0; rmsT = 0;
990 Double_t sumw = 0;
991 Double_t m1 = 0, m2 = 0, s1 = 0, s2 = 0, s12 = 0; // for correlation
992 const Int_t nevt = GetNEvents();
993 Float_t* rV = new Float_t[nevt];
994 Float_t* tV = new Float_t[nevt];
995 Float_t* wV = new Float_t[nevt];
996 Float_t xmin = 1e30, xmax = -1e30;
997 Log() << kINFO << "Calculate regression for all events" << Endl;
998 Timer timer( nevt, GetName(), kTRUE );
999 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1000
1001 const Event* ev = Data()->GetEvent(ievt); // NOTE: need untransformed event here !
1002 Float_t t = ev->GetTarget(0);
1003 Float_t w = ev->GetWeight();
1004 Float_t r = GetRegressionValues()[0];
1005 Float_t d = (r-t);
1006
1007 // find min/max
1010
1011 // store for truncated RMS computation
1012 rV[ievt] = r;
1013 tV[ievt] = t;
1014 wV[ievt] = w;
1015
1016 // compute deviation-squared
1017 sumw += w;
1018 bias += w * d;
1019 dev += w * TMath::Abs(d);
1020 rms += w * d * d;
1021
1022 // compute correlation between target and regression estimate
1023 m1 += t*w; s1 += t*t*w;
1024 m2 += r*w; s2 += r*r*w;
1025 s12 += t*r;
1026 if ((ievt & 0xFF) == 0) timer.DrawProgressBar(ievt);
1027 }
1028 timer.DrawProgressBar(nevt - 1);
1029 Log() << kINFO << "Elapsed time for evaluation of " << nevt << " events: "
1030 << timer.GetElapsedTime() << " " << Endl;
1031
1032 // standard quantities
1033 bias /= sumw;
1034 dev /= sumw;
1035 rms /= sumw;
1036 rms = TMath::Sqrt(rms - bias*bias);
1037
1038 // correlation
1039 m1 /= sumw;
1040 m2 /= sumw;
1041 corr = s12/sumw - m1*m2;
1042 corr /= TMath::Sqrt( (s1/sumw - m1*m1) * (s2/sumw - m2*m2) );
1043
1044 // create histogram required for computation of mutual information
1045 TH2F* hist = new TH2F( "hist", "hist", 150, xmin, xmax, 100, xmin, xmax );
1046 TH2F* histT = new TH2F( "histT", "histT", 150, xmin, xmax, 100, xmin, xmax );
1047
1048 // compute truncated RMS and fill histogram
1049 Double_t devMax = bias + 2*rms;
1050 Double_t devMin = bias - 2*rms;
1051 sumw = 0;
1052 int ic=0;
1053 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1054 Float_t d = (rV[ievt] - tV[ievt]);
1055 hist->Fill( rV[ievt], tV[ievt], wV[ievt] );
1056 if (d >= devMin && d <= devMax) {
1057 sumw += wV[ievt];
1058 biasT += wV[ievt] * d;
1059 devT += wV[ievt] * TMath::Abs(d);
1060 rmsT += wV[ievt] * d * d;
1061 histT->Fill( rV[ievt], tV[ievt], wV[ievt] );
1062 ic++;
1063 }
1064 }
1065 biasT /= sumw;
1066 devT /= sumw;
1067 rmsT /= sumw;
1068 rmsT = TMath::Sqrt(rmsT - biasT*biasT);
1069 mInf = gTools().GetMutualInformation( *hist );
1070 mInfT = gTools().GetMutualInformation( *histT );
1071
1072 delete hist;
1073 delete histT;
1074
1075 delete [] rV;
1076 delete [] tV;
1077 delete [] wV;
1078
1079 Data()->SetCurrentType(savedType);
1080}
1081
1082
1083////////////////////////////////////////////////////////////////////////////////
1084/// test multiclass classification
1085
1087{
1088 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
1089 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in TestMulticlass, exiting."<<Endl;
1090
1091 // GA evaluation of best cut for sig eff * sig pur. Slow, disabled for now.
1092 // Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for test
1093 // data..." << Endl; for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
1094 // resMulticlass->GetBestMultiClassCuts(icls);
1095 // }
1096
1097 // Create histograms for use in TMVA GUI
1098 TString histNamePrefix(GetTestvarName());
1099 TString histNamePrefixTest{histNamePrefix + "_Test"};
1100 TString histNamePrefixTrain{histNamePrefix + "_Train"};
1101
1102 resMulticlass->CreateMulticlassHistos(histNamePrefixTest, fNbinsMVAoutput, fNbinsH);
1103 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTest);
1104
1105 resMulticlass->CreateMulticlassHistos(histNamePrefixTrain, fNbinsMVAoutput, fNbinsH);
1106 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTrain);
1107}
1108
1109
1110////////////////////////////////////////////////////////////////////////////////
1111/// initialization
1112
1114{
1115 Data()->SetCurrentType(Types::kTesting);
1116
1117 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
1118 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
1119
1120 // sanity checks: tree must exist, and theVar must be in tree
1121 if (0==mvaRes && !(GetMethodTypeName().Contains("Cuts"))) {
1122 Log()<<Form("Dataset[%s] : ",DataInfo().GetName()) << "mvaRes " << mvaRes << " GetMethodTypeName " << GetMethodTypeName()
1123 << " contains " << !(GetMethodTypeName().Contains("Cuts")) << Endl;
1124 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<TestInit> Test variable " << GetTestvarName()
1125 << " not found in tree" << Endl;
1126 }
1127
1128 // basic statistics operations are made in base class
1129 gTools().ComputeStat( GetEventCollection(Types::kTesting), mvaRes->GetValueVector(),
1130 fMeanS, fMeanB, fRmsS, fRmsB, fXmin, fXmax, fSignalClass );
1131
1132 // choose reasonable histogram ranges, by removing outliers
1133 Double_t nrms = 10;
1134 fXmin = TMath::Max( TMath::Min( fMeanS - nrms*fRmsS, fMeanB - nrms*fRmsB ), fXmin );
1135 fXmax = TMath::Min( TMath::Max( fMeanS + nrms*fRmsS, fMeanB + nrms*fRmsB ), fXmax );
1136
1137 // determine cut orientation
1138 fCutOrientation = (fMeanS > fMeanB) ? kPositive : kNegative;
1139
1140 // fill 2 types of histograms for the various analyses
1141 // this one is for actual plotting
1142
1143 Double_t sxmax = fXmax+0.00001;
1144
1145 // classifier response distributions for training sample
1146 // MVA plots used for graphics representation (signal)
1147 TString TestvarName;
1148 if(IsSilentFile())
1149 {
1150 TestvarName=Form("[%s]%s",DataInfo().GetName(),GetTestvarName().Data());
1151 }else
1152 {
1153 TestvarName=GetTestvarName();
1154 }
1155 TH1* mva_s = new TH1D( TestvarName + "_S",TestvarName + "_S", fNbinsMVAoutput, fXmin, sxmax );
1156 TH1* mva_b = new TH1D( TestvarName + "_B",TestvarName + "_B", fNbinsMVAoutput, fXmin, sxmax );
1157 mvaRes->Store(mva_s, "MVA_S");
1158 mvaRes->Store(mva_b, "MVA_B");
1159 mva_s->Sumw2();
1160 mva_b->Sumw2();
1161
1162 TH1* proba_s = 0;
1163 TH1* proba_b = 0;
1164 TH1* rarity_s = 0;
1165 TH1* rarity_b = 0;
1166 if (HasMVAPdfs()) {
1167 // P(MVA) plots used for graphics representation
1168 proba_s = new TH1D( TestvarName + "_Proba_S", TestvarName + "_Proba_S", fNbinsMVAoutput, 0.0, 1.0 );
1169 proba_b = new TH1D( TestvarName + "_Proba_B", TestvarName + "_Proba_B", fNbinsMVAoutput, 0.0, 1.0 );
1170 mvaRes->Store(proba_s, "Prob_S");
1171 mvaRes->Store(proba_b, "Prob_B");
1172 proba_s->Sumw2();
1173 proba_b->Sumw2();
1174
1175 // R(MVA) plots used for graphics representation
1176 rarity_s = new TH1D( TestvarName + "_Rarity_S", TestvarName + "_Rarity_S", fNbinsMVAoutput, 0.0, 1.0 );
1177 rarity_b = new TH1D( TestvarName + "_Rarity_B", TestvarName + "_Rarity_B", fNbinsMVAoutput, 0.0, 1.0 );
1178 mvaRes->Store(rarity_s, "Rar_S");
1179 mvaRes->Store(rarity_b, "Rar_B");
1180 rarity_s->Sumw2();
1181 rarity_b->Sumw2();
1182 }
1183
1184 // MVA plots used for efficiency calculations (large number of bins)
1185 TH1* mva_eff_s = new TH1D( TestvarName + "_S_high", TestvarName + "_S_high", fNbinsH, fXmin, sxmax );
1186 TH1* mva_eff_b = new TH1D( TestvarName + "_B_high", TestvarName + "_B_high", fNbinsH, fXmin, sxmax );
1187 mvaRes->Store(mva_eff_s, "MVA_HIGHBIN_S");
1188 mvaRes->Store(mva_eff_b, "MVA_HIGHBIN_B");
1189 mva_eff_s->Sumw2();
1190 mva_eff_b->Sumw2();
1191
1192 // fill the histograms
1193
1194 ResultsClassification* mvaProb = dynamic_cast<ResultsClassification*>
1195 (Data()->GetResults( TString("prob_")+GetMethodName(), Types::kTesting, Types::kMaxAnalysisType ) );
1196
1197 Log() << kHEADER <<Form("[%s] : ",DataInfo().GetName())<< "Loop over test events and fill histograms with classifier response..." << Endl << Endl;
1198 if (mvaProb) Log() << kINFO << "Also filling probability and rarity histograms (on request)..." << Endl;
1199 std::vector<Bool_t>* mvaResTypes = mvaRes->GetValueVectorTypes();
1200
1201 //LM: this is needed to avoid crashes in ROOCCURVE
1202 if ( mvaRes->GetSize() != GetNEvents() ) {
1203 Log() << kFATAL << TString::Format("Inconsistent result size %lld with number of events %u ", mvaRes->GetSize() , GetNEvents() ) << Endl;
1204 assert(mvaRes->GetSize() == GetNEvents());
1205 }
1206
1207 for (Long64_t ievt=0; ievt<GetNEvents(); ievt++) {
1208
1209 const Event* ev = GetEvent(ievt);
1210 Float_t v = (*mvaRes)[ievt][0];
1211 Float_t w = ev->GetWeight();
1212
1213 if (DataInfo().IsSignal(ev)) {
1214 mvaResTypes->push_back(kTRUE);
1215 mva_s ->Fill( v, w );
1216 if (mvaProb) {
1217 proba_s->Fill( (*mvaProb)[ievt][0], w );
1218 rarity_s->Fill( GetRarity( v ), w );
1219 }
1220
1221 mva_eff_s ->Fill( v, w );
1222 }
1223 else {
1224 mvaResTypes->push_back(kFALSE);
1225 mva_b ->Fill( v, w );
1226 if (mvaProb) {
1227 proba_b->Fill( (*mvaProb)[ievt][0], w );
1228 rarity_b->Fill( GetRarity( v ), w );
1229 }
1230 mva_eff_b ->Fill( v, w );
1231 }
1232 }
1233
1234 // uncomment those (and several others if you want unnormalized output
1235 gTools().NormHist( mva_s );
1236 gTools().NormHist( mva_b );
1237 gTools().NormHist( proba_s );
1238 gTools().NormHist( proba_b );
1239 gTools().NormHist( rarity_s );
1240 gTools().NormHist( rarity_b );
1241 gTools().NormHist( mva_eff_s );
1242 gTools().NormHist( mva_eff_b );
1243
1244 // create PDFs from histograms, using default splines, and no additional smoothing
1245 if (fSplS) { delete fSplS; fSplS = 0; }
1246 if (fSplB) { delete fSplB; fSplB = 0; }
1247 fSplS = new PDF( TString(GetName()) + " PDF Sig", mva_s, PDF::kSpline2 );
1248 fSplB = new PDF( TString(GetName()) + " PDF Bkg", mva_b, PDF::kSpline2 );
1249}
1250
1251////////////////////////////////////////////////////////////////////////////////
1252/// general method used in writing the header of the weight files where
1253/// the used variables, variable transformation type etc. is specified
1254
1255void TMVA::MethodBase::WriteStateToStream( std::ostream& tf ) const
1256{
1257 TString prefix = "";
1258 UserGroup_t * userInfo = gSystem->GetUserInfo();
1259
1260 tf << prefix << "#GEN -*-*-*-*-*-*-*-*-*-*-*- general info -*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1261 tf << prefix << "Method : " << GetMethodTypeName() << "::" << GetMethodName() << std::endl;
1262 tf.setf(std::ios::left);
1263 tf << prefix << "TMVA Release : " << std::setw(10) << GetTrainingTMVAVersionString() << " ["
1264 << GetTrainingTMVAVersionCode() << "]" << std::endl;
1265 tf << prefix << "ROOT Release : " << std::setw(10) << GetTrainingROOTVersionString() << " ["
1266 << GetTrainingROOTVersionCode() << "]" << std::endl;
1267 tf << prefix << "Creator : " << userInfo->fUser << std::endl;
1268 tf << prefix << "Date : "; TDatime *d = new TDatime; tf << d->AsString() << std::endl; delete d;
1269 tf << prefix << "Host : " << gSystem->GetBuildNode() << std::endl;
1270 tf << prefix << "Dir : " << gSystem->WorkingDirectory() << std::endl;
1271 tf << prefix << "Training events: " << Data()->GetNTrainingEvents() << std::endl;
1272
1273 TString analysisType(((const_cast<TMVA::MethodBase*>(this)->GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification"));
1274
1275 tf << prefix << "Analysis type : " << "[" << ((GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification") << "]" << std::endl;
1276 tf << prefix << std::endl;
1277
1278 delete userInfo;
1279
1280 // First write all options
1281 tf << prefix << std::endl << prefix << "#OPT -*-*-*-*-*-*-*-*-*-*-*-*- options -*-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1282 WriteOptionsToStream( tf, prefix );
1283 tf << prefix << std::endl;
1284
1285 // Second write variable info
1286 tf << prefix << std::endl << prefix << "#VAR -*-*-*-*-*-*-*-*-*-*-*-* variables *-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1287 WriteVarsToStream( tf, prefix );
1288 tf << prefix << std::endl;
1289}
1290
1291////////////////////////////////////////////////////////////////////////////////
1292/// xml writing
1293
1294void TMVA::MethodBase::AddInfoItem( void* gi, const TString& name, const TString& value) const
1295{
1296 void* it = gTools().AddChild(gi,"Info");
1297 gTools().AddAttr(it,"name", name);
1298 gTools().AddAttr(it,"value", value);
1299}
1300
1301////////////////////////////////////////////////////////////////////////////////
1302
1304 if (analysisType == Types::kRegression) {
1305 AddRegressionOutput( type );
1306 } else if (analysisType == Types::kMulticlass) {
1307 AddMulticlassOutput( type );
1308 } else {
1309 AddClassifierOutput( type );
1310 if (HasMVAPdfs())
1311 AddClassifierOutputProb( type );
1312 }
1313}
1314
1315////////////////////////////////////////////////////////////////////////////////
1316/// general method used in writing the header of the weight files where
1317/// the used variables, variable transformation type etc. is specified
1318
void TMVA::MethodBase::WriteStateToXML( void* parent ) const
{
   // Serialize the full method state (general info, options, variables,
   // spectators, classes, targets, transformations, MVA PDFs and weights)
   // as children of the given XML node. No-op if parent is null.
   if (!parent) return;

   UserGroup_t* userInfo = gSystem->GetUserInfo();

   // general info: versions, creator, date, host, training statistics
   void* gi = gTools().AddChild(parent, "GeneralInfo");
   AddInfoItem( gi, "TMVA Release", GetTrainingTMVAVersionString() + " [" + gTools().StringFromInt(GetTrainingTMVAVersionCode()) + "]" );
   AddInfoItem( gi, "ROOT Release", GetTrainingROOTVersionString() + " [" + gTools().StringFromInt(GetTrainingROOTVersionCode()) + "]");
   AddInfoItem( gi, "Creator", userInfo->fUser);
   TDatime dt; AddInfoItem( gi, "Date", dt.AsString());
   AddInfoItem( gi, "Host", gSystem->GetBuildNode() );
   AddInfoItem( gi, "Dir", gSystem->WorkingDirectory());
   AddInfoItem( gi, "Training events", gTools().StringFromInt(Data()->GetNTrainingEvents()));
   AddInfoItem( gi, "TrainingTime", gTools().StringFromDouble(const_cast<TMVA::MethodBase*>(this)->GetTrainTime()));

   // GetAnalysisType() is non-const in this class, hence the const_cast
   Types::EAnalysisType aType = const_cast<TMVA::MethodBase*>(this)->GetAnalysisType();
   TString analysisType((aType==Types::kRegression) ? "Regression" :
                        (aType==Types::kMulticlass ? "Multiclass" : "Classification"));
   AddInfoItem( gi, "AnalysisType", analysisType );
   delete userInfo;

   // write options
   AddOptionsXMLTo( parent );

   // write variable info
   AddVarsXMLTo( parent );

   // write spectator info
   if (fModelPersistence)
      AddSpectatorsXMLTo( parent );

   // write class info if in multiclass mode
   AddClassesXMLTo(parent);

   // write target info if in regression mode
   if (DoRegression()) AddTargetsXMLTo(parent);

   // write transformations
   GetTransformationHandler(false).AddXMLTo( parent );

   // write MVA variable distributions (the MVAPdfs node is written even if
   // both PDF pointers are null, so readers always find the node)
   void* pdfs = gTools().AddChild(parent, "MVAPdfs");
   if (fMVAPdfS) fMVAPdfS->AddXMLTo(pdfs);
   if (fMVAPdfB) fMVAPdfB->AddXMLTo(pdfs);

   // write weights
   AddWeightsXMLTo( parent );
}
1368
1369////////////////////////////////////////////////////////////////////////////////
1370/// write reference MVA distributions (and other information)
1371/// to a ROOT type weight file
1372
1374{
1375 Bool_t addDirStatus = TH1::AddDirectoryStatus();
1376 TH1::AddDirectory( 0 ); // this avoids the binding of the hists in PDF to the current ROOT file
1377 fMVAPdfS = (TMVA::PDF*)rf.Get( "MVA_PDF_Signal" );
1378 fMVAPdfB = (TMVA::PDF*)rf.Get( "MVA_PDF_Background" );
1379
1380 TH1::AddDirectory( addDirStatus );
1381
1382 ReadWeightsFromStream( rf );
1383
1384 SetTestvarName();
1385}
1386
1387////////////////////////////////////////////////////////////////////////////////
1388/// write options and weights to file
1389/// note that each one text file for the main configuration information
1390/// and one ROOT file for ROOT objects are created
1391
1393{
1394 // ---- create the text file
1395 TString tfname( GetWeightFileName() );
1396
1397 // writing xml file
1398 TString xmlfname( tfname ); xmlfname.ReplaceAll( ".txt", ".xml" );
1399 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1400 << "Creating xml weight file: "
1401 << gTools().Color("lightblue") << xmlfname << gTools().Color("reset") << Endl;
1402 void* doc = gTools().xmlengine().NewDoc();
1403 void* rootnode = gTools().AddChild(0,"MethodSetup", "", true);
1404 gTools().xmlengine().DocSetRootElement(doc,rootnode);
1405 gTools().AddAttr(rootnode,"Method", GetMethodTypeName() + "::" + GetMethodName());
1406 WriteStateToXML(rootnode);
1407 gTools().xmlengine().SaveDoc(doc,xmlfname);
1408 gTools().xmlengine().FreeDoc(doc);
1409}
1410
1411////////////////////////////////////////////////////////////////////////////////
1412/// Function to write options and weights to file
1413
1415{
1416 // get the filename
1417
1418 TString tfname(GetWeightFileName());
1419
1420 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
1421 << "Reading weight file: "
1422 << gTools().Color("lightblue") << tfname << gTools().Color("reset") << Endl;
1423
1424 if (tfname.EndsWith(".xml") ) {
1425 void* doc = gTools().xmlengine().ParseFile(tfname,gTools().xmlenginebuffersize()); // the default buffer size in TXMLEngine::ParseFile is 100k. Starting with ROOT 5.29 one can set the buffer size, see: http://savannah.cern.ch/bugs/?78864. This might be necessary for large XML files
1426 if (!doc) {
1427 Log() << kFATAL << "Error parsing XML file " << tfname << Endl;
1428 }
1429 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1430 ReadStateFromXML(rootnode);
1431 gTools().xmlengine().FreeDoc(doc);
1432 }
1433 else {
1434 std::filebuf fb;
1435 fb.open(tfname.Data(),std::ios::in);
1436 if (!fb.is_open()) { // file not found --> Error
1437 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ReadStateFromFile> "
1438 << "Unable to open input weight file: " << tfname << Endl;
1439 }
1440 std::istream fin(&fb);
1441 ReadStateFromStream(fin);
1442 fb.close();
1443 }
1444 if (!fTxtWeightsOnly) {
1445 // ---- read the ROOT file
1446 TString rfname( tfname ); rfname.ReplaceAll( ".txt", ".root" );
1447 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Reading root weight file: "
1448 << gTools().Color("lightblue") << rfname << gTools().Color("reset") << Endl;
1449 TFile* rfile = TFile::Open( rfname, "READ" );
1450 ReadStateFromStream( *rfile );
1451 rfile->Close();
1452 }
1453}
1454////////////////////////////////////////////////////////////////////////////////
1455/// for reading from memory
1456
1458 void* doc = gTools().xmlengine().ParseString(xmlstr);
1459 void* rootnode = gTools().xmlengine().DocGetRootElement(doc); // node "MethodSetup"
1460 ReadStateFromXML(rootnode);
1461 gTools().xmlengine().FreeDoc(doc);
1462
1463 return;
1464}
1465
1466////////////////////////////////////////////////////////////////////////////////
1467
1469{
1470
1471 TString fullMethodName;
1472 gTools().ReadAttr( methodNode, "Method", fullMethodName );
1473
1474 fMethodName = fullMethodName(fullMethodName.Index("::")+2,fullMethodName.Length());
1475
1476 // update logger
1477 Log().SetSource( GetName() );
1478 Log() << kDEBUG//<<Form("Dataset[%s] : ",DataInfo().GetName())
1479 << "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1480
1481 // after the method name is read, the testvar can be set
1482 SetTestvarName();
1483
1484 TString nodeName("");
1485 void* ch = gTools().GetChild(methodNode);
1486 while (ch!=0) {
1487 nodeName = TString( gTools().GetName(ch) );
1488
1489 if (nodeName=="GeneralInfo") {
1490 // read analysis type
1491
1492 TString name(""),val("");
1493 void* antypeNode = gTools().GetChild(ch);
1494 while (antypeNode) {
1495 gTools().ReadAttr( antypeNode, "name", name );
1496
1497 if (name == "TrainingTime")
1498 gTools().ReadAttr( antypeNode, "value", fTrainTime );
1499
1500 if (name == "AnalysisType") {
1501 gTools().ReadAttr( antypeNode, "value", val );
1502 val.ToLower();
1503 if (val == "regression" ) SetAnalysisType( Types::kRegression );
1504 else if (val == "classification" ) SetAnalysisType( Types::kClassification );
1505 else if (val == "multiclass" ) SetAnalysisType( Types::kMulticlass );
1506 else Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Analysis type " << val << " is not known." << Endl;
1507 }
1508
1509 if (name == "TMVA Release" || name == "TMVA") {
1510 TString s;
1511 gTools().ReadAttr( antypeNode, "value", s);
1512 fTMVATrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1513 Log() << kDEBUG <<Form("[%s] : ",DataInfo().GetName()) << "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
1514 }
1515
1516 if (name == "ROOT Release" || name == "ROOT") {
1517 TString s;
1518 gTools().ReadAttr( antypeNode, "value", s);
1519 fROOTTrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
1520 Log() << kDEBUG //<<Form("Dataset[%s] : ",DataInfo().GetName())
1521 << "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
1522 }
1523 antypeNode = gTools().GetNextChild(antypeNode);
1524 }
1525 }
1526 else if (nodeName=="Options") {
1527 ReadOptionsFromXML(ch);
1528 ParseOptions();
1529
1530 }
1531 else if (nodeName=="Variables") {
1532 ReadVariablesFromXML(ch);
1533 }
1534 else if (nodeName=="Spectators") {
1535 ReadSpectatorsFromXML(ch);
1536 }
1537 else if (nodeName=="Classes") {
1538 if (DataInfo().GetNClasses()==0) ReadClassesFromXML(ch);
1539 }
1540 else if (nodeName=="Targets") {
1541 if (DataInfo().GetNTargets()==0 && DoRegression()) ReadTargetsFromXML(ch);
1542 }
1543 else if (nodeName=="Transformations") {
1544 GetTransformationHandler().ReadFromXML(ch);
1545 }
1546 else if (nodeName=="MVAPdfs") {
1547 TString pdfname;
1548 if (fMVAPdfS) { delete fMVAPdfS; fMVAPdfS=0; }
1549 if (fMVAPdfB) { delete fMVAPdfB; fMVAPdfB=0; }
1550 void* pdfnode = gTools().GetChild(ch);
1551 if (pdfnode) {
1552 gTools().ReadAttr(pdfnode, "Name", pdfname);
1553 fMVAPdfS = new PDF(pdfname);
1554 fMVAPdfS->ReadXML(pdfnode);
1555 pdfnode = gTools().GetNextChild(pdfnode);
1556 gTools().ReadAttr(pdfnode, "Name", pdfname);
1557 fMVAPdfB = new PDF(pdfname);
1558 fMVAPdfB->ReadXML(pdfnode);
1559 }
1560 }
1561 else if (nodeName=="Weights") {
1562 ReadWeightsFromXML(ch);
1563 }
1564 else {
1565 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Unparsed XML node: '" << nodeName << "'" << Endl;
1566 }
1567 ch = gTools().GetNextChild(ch);
1568
1569 }
1570
1571 // update transformation handler
1572 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1573}
1574
1575////////////////////////////////////////////////////////////////////////////////
1576/// read the header from the weight files of the different MVA methods
1577
1579{
1580 char buf[512];
1581
1582 // when reading from stream, we assume the files are produced with TMVA<=397
1583 SetAnalysisType(Types::kClassification);
1584
1585
1586 // first read the method name
1587 GetLine(fin,buf);
1588 while (!TString(buf).BeginsWith("Method")) GetLine(fin,buf);
1589 TString namestr(buf);
1590
1591 TString methodType = namestr(0,namestr.Index("::"));
1592 methodType = methodType(methodType.Last(' '),methodType.Length());
1593 methodType = methodType.Strip(TString::kLeading);
1594
1595 TString methodName = namestr(namestr.Index("::")+2,namestr.Length());
1596 methodName = methodName.Strip(TString::kLeading);
1597 if (methodName == "") methodName = methodType;
1598 fMethodName = methodName;
1599
1600 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;
1601
1602 // update logger
1603 Log().SetSource( GetName() );
1604
1605 // now the question is whether to read the variables first or the options (well, of course the order
1606 // of writing them needs to agree)
1607 //
1608 // the option "Decorrelation" is needed to decide if the variables we
1609 // read are decorrelated or not
1610 //
1611 // the variables are needed by some methods (TMLP) to build the NN
1612 // which is done in ProcessOptions so for the time being we first Read and Parse the options then
1613 // we read the variables, and then we process the options
1614
1615 // now read all options
1616 GetLine(fin,buf);
1617 while (!TString(buf).BeginsWith("#OPT")) GetLine(fin,buf);
1618 ReadOptionsFromStream(fin);
1619 ParseOptions();
1620
1621 // Now read variable info
1622 fin.getline(buf,512);
1623 while (!TString(buf).BeginsWith("#VAR")) fin.getline(buf,512);
1624 ReadVarsFromStream(fin);
1625
1626 // now we process the options (of the derived class)
1627 ProcessOptions();
1628
1629 if (IsNormalised()) {
1631 GetTransformationHandler().AddTransformation( new VariableNormalizeTransform(DataInfo()), -1 );
1632 norm->BuildTransformationFromVarInfo( DataInfo().GetVariableInfos() );
1633 }
1634 VariableTransformBase *varTrafo(0), *varTrafo2(0);
1635 if ( fVarTransformString == "None") {
1636 if (fUseDecorr)
1637 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1638 } else if ( fVarTransformString == "Decorrelate" ) {
1639 varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1640 } else if ( fVarTransformString == "PCA" ) {
1641 varTrafo = GetTransformationHandler().AddTransformation( new VariablePCATransform(DataInfo()), -1 );
1642 } else if ( fVarTransformString == "Uniform" ) {
1643 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo(),"Uniform"), -1 );
1644 } else if ( fVarTransformString == "Gauss" ) {
1645 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1646 } else if ( fVarTransformString == "GaussDecorr" ) {
1647 varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
1648 varTrafo2 = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
1649 } else {
1650 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<ProcessOptions> Variable transform '"
1651 << fVarTransformString << "' unknown." << Endl;
1652 }
1653 // Now read decorrelation matrix if available
1654 if (GetTransformationHandler().GetTransformationList().GetSize() > 0) {
1655 fin.getline(buf,512);
1656 while (!TString(buf).BeginsWith("#MAT")) fin.getline(buf,512);
1657 if (varTrafo) {
1658 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1659 varTrafo->ReadTransformationFromStream(fin, trafo );
1660 }
1661 if (varTrafo2) {
1662 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1663 varTrafo2->ReadTransformationFromStream(fin, trafo );
1664 }
1665 }
1666
1667
1668 if (HasMVAPdfs()) {
1669 // Now read the MVA PDFs
1670 fin.getline(buf,512);
1671 while (!TString(buf).BeginsWith("#MVAPDFS")) fin.getline(buf,512);
1672 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
1673 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
1674 fMVAPdfS = new PDF(TString(GetName()) + " MVA PDF Sig");
1675 fMVAPdfB = new PDF(TString(GetName()) + " MVA PDF Bkg");
1676 fMVAPdfS->SetReadingVersion( GetTrainingTMVAVersionCode() );
1677 fMVAPdfB->SetReadingVersion( GetTrainingTMVAVersionCode() );
1678
1679 fin >> *fMVAPdfS;
1680 fin >> *fMVAPdfB;
1681 }
1682
1683 // Now read weights
1684 fin.getline(buf,512);
1685 while (!TString(buf).BeginsWith("#WGT")) fin.getline(buf,512);
1686 fin.getline(buf,512);
1687 ReadWeightsFromStream( fin );;
1688
1689 // update transformation handler
1690 if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
1691
1692}
1693
1694////////////////////////////////////////////////////////////////////////////////
1695/// write the list of variables (name, min, max) for a given data
1696/// transformation method to the stream
1697
1698void TMVA::MethodBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const
1699{
1700 o << prefix << "NVar " << DataInfo().GetNVariables() << std::endl;
1701 std::vector<VariableInfo>::const_iterator varIt = DataInfo().GetVariableInfos().begin();
1702 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1703 o << prefix << "NSpec " << DataInfo().GetNSpectators() << std::endl;
1704 varIt = DataInfo().GetSpectatorInfos().begin();
1705 for (; varIt!=DataInfo().GetSpectatorInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1706}
1707
1708////////////////////////////////////////////////////////////////////////////////
1709/// Read the variables (name, min, max) for a given data
1710/// transformation method from the stream. In the stream we only
1711/// expect the limits which will be set
1712
1714{
1715 TString dummy;
1716 UInt_t readNVar;
1717 istr >> dummy >> readNVar;
1718
1719 if (readNVar!=DataInfo().GetNVariables()) {
1720 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1721 << " while there are " << readNVar << " variables declared in the file"
1722 << Endl;
1723 }
1724
1725 // we want to make sure all variables are read in the order they are defined
1726 VariableInfo varInfo;
1727 std::vector<VariableInfo>::iterator varIt = DataInfo().GetVariableInfos().begin();
1728 int varIdx = 0;
1729 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt, ++varIdx) {
1730 varInfo.ReadFromStream(istr);
1731 if (varIt->GetExpression() == varInfo.GetExpression()) {
1732 varInfo.SetExternalLink((*varIt).GetExternalLink());
1733 (*varIt) = varInfo;
1734 }
1735 else {
1736 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVarsFromStream>" << Endl;
1737 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1738 Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
1739 Log() << kINFO << "the correct working of the method):" << Endl;
1740 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
1741 Log() << kINFO << " var #" << varIdx <<" declared in file : " << varInfo.GetExpression() << Endl;
1742 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1743 }
1744 }
1745}
1746
1747////////////////////////////////////////////////////////////////////////////////
1748/// write variable info to XML
1749
1750void TMVA::MethodBase::AddVarsXMLTo( void* parent ) const
1751{
1752 void* vars = gTools().AddChild(parent, "Variables");
1753 gTools().AddAttr( vars, "NVar", gTools().StringFromInt(DataInfo().GetNVariables()) );
1754
1755 for (UInt_t idx=0; idx<DataInfo().GetVariableInfos().size(); idx++) {
1756 VariableInfo& vi = DataInfo().GetVariableInfos()[idx];
1757 void* var = gTools().AddChild( vars, "Variable" );
1758 gTools().AddAttr( var, "VarIndex", idx );
1759 vi.AddToXML( var );
1760 }
1761}
1762
1763////////////////////////////////////////////////////////////////////////////////
1764/// write spectator info to XML
1765
1767{
1768 void* specs = gTools().AddChild(parent, "Spectators");
1769
1770 UInt_t writeIdx=0;
1771 for (UInt_t idx=0; idx<DataInfo().GetSpectatorInfos().size(); idx++) {
1772
1773 VariableInfo& vi = DataInfo().GetSpectatorInfos()[idx];
1774
1775 // we do not want to write spectators that are category-cuts,
1776 // except if the method is the category method and the spectators belong to it
1777 if (vi.GetVarType()=='C') continue;
1778
1779 void* spec = gTools().AddChild( specs, "Spectator" );
1780 gTools().AddAttr( spec, "SpecIndex", writeIdx++ );
1781 vi.AddToXML( spec );
1782 }
1783 gTools().AddAttr( specs, "NSpec", gTools().StringFromInt(writeIdx) );
1784}
1785
1786////////////////////////////////////////////////////////////////////////////////
1787/// write class info to XML
1788
1789void TMVA::MethodBase::AddClassesXMLTo( void* parent ) const
1790{
1791 UInt_t nClasses=DataInfo().GetNClasses();
1792
1793 void* classes = gTools().AddChild(parent, "Classes");
1794 gTools().AddAttr( classes, "NClass", nClasses );
1795
1796 for (UInt_t iCls=0; iCls<nClasses; ++iCls) {
1797 ClassInfo *classInfo=DataInfo().GetClassInfo (iCls);
1798 TString className =classInfo->GetName();
1799 UInt_t classNumber=classInfo->GetNumber();
1800
1801 void* classNode=gTools().AddChild(classes, "Class");
1802 gTools().AddAttr( classNode, "Name", className );
1803 gTools().AddAttr( classNode, "Index", classNumber );
1804 }
1805}
1806////////////////////////////////////////////////////////////////////////////////
1807/// write target info to XML
1808
1809void TMVA::MethodBase::AddTargetsXMLTo( void* parent ) const
1810{
1811 void* targets = gTools().AddChild(parent, "Targets");
1812 gTools().AddAttr( targets, "NTrgt", gTools().StringFromInt(DataInfo().GetNTargets()) );
1813
1814 for (UInt_t idx=0; idx<DataInfo().GetTargetInfos().size(); idx++) {
1815 VariableInfo& vi = DataInfo().GetTargetInfos()[idx];
1816 void* tar = gTools().AddChild( targets, "Target" );
1817 gTools().AddAttr( tar, "TargetIndex", idx );
1818 vi.AddToXML( tar );
1819 }
1820}
1821
1822////////////////////////////////////////////////////////////////////////////////
1823/// read variable info from XML
1824
1826{
1827 UInt_t readNVar;
1828 gTools().ReadAttr( varnode, "NVar", readNVar);
1829
1830 if (readNVar!=DataInfo().GetNVariables()) {
1831 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
1832 << " while there are " << readNVar << " variables declared in the file"
1833 << Endl;
1834 }
1835
1836 // we want to make sure all variables are read in the order they are defined
1837 VariableInfo readVarInfo, existingVarInfo;
1838 int varIdx = 0;
1839 void* ch = gTools().GetChild(varnode);
1840 while (ch) {
1841 gTools().ReadAttr( ch, "VarIndex", varIdx);
1842 existingVarInfo = DataInfo().GetVariableInfos()[varIdx];
1843 readVarInfo.ReadFromXML(ch);
1844
1845 if (existingVarInfo.GetExpression() == readVarInfo.GetExpression()) {
1846 readVarInfo.SetExternalLink(existingVarInfo.GetExternalLink());
1847 existingVarInfo = readVarInfo;
1848 }
1849 else {
1850 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadVariablesFromXML>" << Endl;
1851 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
1852 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1853 Log() << kINFO << "correct working of the method):" << Endl;
1854 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << existingVarInfo.GetExpression() << Endl;
1855 Log() << kINFO << " var #" << varIdx <<" declared in file : " << readVarInfo.GetExpression() << Endl;
1856 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1857 }
1858 ch = gTools().GetNextChild(ch);
1859 }
1860}
1861
1862////////////////////////////////////////////////////////////////////////////////
1863/// read spectator info from XML
1864
1866{
1867 UInt_t readNSpec;
1868 gTools().ReadAttr( specnode, "NSpec", readNSpec);
1869
1870 if (readNSpec!=DataInfo().GetNSpectators(kFALSE)) {
1871 Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName()) << "You declared "<< DataInfo().GetNSpectators(kFALSE) << " spectators in the Reader"
1872 << " while there are " << readNSpec << " spectators declared in the file"
1873 << Endl;
1874 }
1875
1876 // we want to make sure all variables are read in the order they are defined
1877 VariableInfo readSpecInfo, existingSpecInfo;
1878 int specIdx = 0;
1879 void* ch = gTools().GetChild(specnode);
1880 while (ch) {
1881 gTools().ReadAttr( ch, "SpecIndex", specIdx);
1882 existingSpecInfo = DataInfo().GetSpectatorInfos()[specIdx];
1883 readSpecInfo.ReadFromXML(ch);
1884
1885 if (existingSpecInfo.GetExpression() == readSpecInfo.GetExpression()) {
1886 readSpecInfo.SetExternalLink(existingSpecInfo.GetExternalLink());
1887 existingSpecInfo = readSpecInfo;
1888 }
1889 else {
1890 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "ERROR in <ReadSpectatorsFromXML>" << Endl;
1891 Log() << kINFO << "The definition (or the order) of the spectators found in the input file is" << Endl;
1892 Log() << kINFO << "not the same as the one declared in the Reader (which is necessary for the" << Endl;
1893 Log() << kINFO << "correct working of the method):" << Endl;
1894 Log() << kINFO << " spec #" << specIdx <<" declared in Reader: " << existingSpecInfo.GetExpression() << Endl;
1895 Log() << kINFO << " spec #" << specIdx <<" declared in file : " << readSpecInfo.GetExpression() << Endl;
1896 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1897 }
1898 ch = gTools().GetNextChild(ch);
1899 }
1900}
1901
1902////////////////////////////////////////////////////////////////////////////////
1903/// read number of classes from XML
1904
1906{
1907 UInt_t readNCls;
1908 // coverity[tainted_data_argument]
1909 gTools().ReadAttr( clsnode, "NClass", readNCls);
1910
1911 TString className="";
1912 UInt_t classIndex=0;
1913 void* ch = gTools().GetChild(clsnode);
1914 if (!ch) {
1915 for (UInt_t icls = 0; icls<readNCls;++icls) {
1916 TString classname = Form("class%i",icls);
1917 DataInfo().AddClass(classname);
1918
1919 }
1920 }
1921 else{
1922 while (ch) {
1923 gTools().ReadAttr( ch, "Index", classIndex);
1924 gTools().ReadAttr( ch, "Name", className );
1925 DataInfo().AddClass(className);
1926
1927 ch = gTools().GetNextChild(ch);
1928 }
1929 }
1930
1931 // retrieve signal and background class index
1932 if (DataInfo().GetClassInfo("Signal") != 0) {
1933 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
1934 }
1935 else
1936 fSignalClass=0;
1937 if (DataInfo().GetClassInfo("Background") != 0) {
1938 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
1939 }
1940 else
1941 fBackgroundClass=1;
1942}
1943
1944////////////////////////////////////////////////////////////////////////////////
1945/// read target info from XML
1946
1948{
1949 UInt_t readNTar;
1950 gTools().ReadAttr( tarnode, "NTrgt", readNTar);
1951
1952 int tarIdx = 0;
1953 TString expression;
1954 void* ch = gTools().GetChild(tarnode);
1955 while (ch) {
1956 gTools().ReadAttr( ch, "TargetIndex", tarIdx);
1957 gTools().ReadAttr( ch, "Expression", expression);
1958 DataInfo().AddTarget(expression,"","",0,0);
1959
1960 ch = gTools().GetNextChild(ch);
1961 }
1962}
1963
1964////////////////////////////////////////////////////////////////////////////////
1965/// returns the ROOT directory where info/histograms etc of the
1966/// corresponding MVA method instance are stored
1967
1969{
1970 if (fBaseDir != 0) return fBaseDir;
1971 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodName() << " not set yet --> check if already there.." <<Endl;
1972
1973 if (IsSilentFile()) {
1974 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
1975 << "MethodBase::BaseDir() - No directory exists when running a Method without output file. Enable the "
1976 "output when creating the factory"
1977 << Endl;
1978 }
1979
1980 TDirectory* methodDir = MethodBaseDir();
1981 if (methodDir==0)
1982 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MethodBase::BaseDir() - MethodBaseDir() return a NULL pointer!" << Endl;
1983
1984 TString defaultDir = GetMethodName();
1985 TDirectory *sdir = methodDir->GetDirectory(defaultDir.Data());
1986 if(!sdir)
1987 {
1988 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " does not exist yet--> created it" <<Endl;
1989 sdir = methodDir->mkdir(defaultDir);
1990 sdir->cd();
1991 // write weight file name into target file
1992 if (fModelPersistence) {
1993 TObjString wfilePath( gSystem->WorkingDirectory() );
1994 TObjString wfileName( GetWeightFileName() );
1995 wfilePath.Write( "TrainingPath" );
1996 wfileName.Write( "WeightFileName" );
1997 }
1998 }
1999
2000 Log()<<kDEBUG<<Form("Dataset[%s] : ",DataInfo().GetName())<<" Base Directory for " << GetMethodTypeName() << " existed, return it.." <<Endl;
2001 return sdir;
2002}
2003
2004////////////////////////////////////////////////////////////////////////////////
2005/// returns the ROOT directory where all instances of the
2006/// corresponding MVA method are stored
2007
2009{
2010 if (fMethodBaseDir != 0) {
2011 return fMethodBaseDir;
2012 }
2013
2014 const char *datasetName = DataInfo().GetName();
2015
2016 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodTypeName()
2017 << " not set yet --> check if already there.." << Endl;
2018
2019 TDirectory *factoryBaseDir = GetFile();
2020 if (!factoryBaseDir) return nullptr;
2021 fMethodBaseDir = factoryBaseDir->GetDirectory(datasetName);
2022 if (!fMethodBaseDir) {
2023 fMethodBaseDir = factoryBaseDir->mkdir(datasetName, Form("Base directory for dataset %s", datasetName));
2024 if (!fMethodBaseDir) {
2025 Log() << kFATAL << "Can not create dir " << datasetName;
2026 }
2027 }
2028 TString methodTypeDir = Form("Method_%s", GetMethodTypeName().Data());
2029 fMethodBaseDir = fMethodBaseDir->GetDirectory(methodTypeDir.Data());
2030
2031 if (!fMethodBaseDir) {
2032 TDirectory *datasetDir = factoryBaseDir->GetDirectory(datasetName);
2033 TString methodTypeDirHelpStr = Form("Directory for all %s methods", GetMethodTypeName().Data());
2034 fMethodBaseDir = datasetDir->mkdir(methodTypeDir.Data(), methodTypeDirHelpStr);
2035 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName) << " Base Directory for " << GetMethodName()
2036 << " does not exist yet--> created it" << Endl;
2037 }
2038
2039 Log() << kDEBUG << Form("Dataset[%s] : ", datasetName)
2040 << "Return from MethodBaseDir() after creating base directory " << Endl;
2041 return fMethodBaseDir;
2042}
2043
2044////////////////////////////////////////////////////////////////////////////////
2045/// set directory of weight file
2046
2048{
2049 fFileDir = fileDir;
2050 gSystem->mkdir( fFileDir, kTRUE );
2051}
2052
2053////////////////////////////////////////////////////////////////////////////////
2054/// set the weight file name (depreciated)
2055
2057{
2058 fWeightFile = theWeightFile;
2059}
2060
2061////////////////////////////////////////////////////////////////////////////////
2062/// retrieve weight file name
2063
2065{
2066 if (fWeightFile!="") return fWeightFile;
2067
2068 // the default consists of
2069 // directory/jobname_methodname_suffix.extension.{root/txt}
2070 TString suffix = "";
2071 TString wFileDir(GetWeightFileDir());
2072 TString wFileName = GetJobName() + "_" + GetMethodName() +
2073 suffix + "." + gConfig().GetIONames().fWeightFileExtension + ".xml";
2074 if (wFileDir.IsNull() ) return wFileName;
2075 // add weight file directory of it is not null
2076 return ( wFileDir + (wFileDir[wFileDir.Length()-1]=='/' ? "" : "/")
2077 + wFileName );
2078}
2079////////////////////////////////////////////////////////////////////////////////
2080/// writes all MVA evaluation histograms to file
2081
2083{
2084 BaseDir()->cd();
2085
2086
2087 // write MVA PDFs to file - if exist
2088 if (0 != fMVAPdfS) {
2089 fMVAPdfS->GetOriginalHist()->Write();
2090 fMVAPdfS->GetSmoothedHist()->Write();
2091 fMVAPdfS->GetPDFHist()->Write();
2092 }
2093 if (0 != fMVAPdfB) {
2094 fMVAPdfB->GetOriginalHist()->Write();
2095 fMVAPdfB->GetSmoothedHist()->Write();
2096 fMVAPdfB->GetPDFHist()->Write();
2097 }
2098
2099 // write result-histograms
2100 Results* results = Data()->GetResults( GetMethodName(), treetype, Types::kMaxAnalysisType );
2101 if (!results)
2102 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<WriteEvaluationHistosToFile> Unknown result: "
2103 << GetMethodName() << (treetype==Types::kTraining?"/kTraining":"/kTesting")
2104 << "/kMaxAnalysisType" << Endl;
2105 results->GetStorage()->Write();
2106 if (treetype==Types::kTesting) {
2107 // skipping plotting of variables if too many (default is 200)
2108 if ((int) DataInfo().GetNVariables()< gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables)
2109 GetTransformationHandler().PlotVariables (GetEventCollection( Types::kTesting ), BaseDir() );
2110 else
2111 Log() << kINFO << TString::Format("Dataset[%s] : ",DataInfo().GetName())
2112 << " variable plots are not produces ! The number of variables is " << DataInfo().GetNVariables()
2113 << " , it is larger than " << gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables << Endl;
2114 }
2115}
2116
2117////////////////////////////////////////////////////////////////////////////////
2118/// write special monitoring histograms to file
2119/// dummy implementation here -----------------
2120
2122{
2123}
2124
2125////////////////////////////////////////////////////////////////////////////////
2126/// reads one line from the input stream
2127/// checks for certain keywords and interprets
2128/// the line if keywords are found
2129
2130Bool_t TMVA::MethodBase::GetLine(std::istream& fin, char* buf )
2131{
2132 fin.getline(buf,512);
2133 TString line(buf);
2134 if (line.BeginsWith("TMVA Release")) {
2135 Ssiz_t start = line.First('[')+1;
2136 Ssiz_t length = line.Index("]",start)-start;
2137 TString code = line(start,length);
2138 std::stringstream s(code.Data());
2139 s >> fTMVATrainingVersion;
2140 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
2141 }
2142 if (line.BeginsWith("ROOT Release")) {
2143 Ssiz_t start = line.First('[')+1;
2144 Ssiz_t length = line.Index("]",start)-start;
2145 TString code = line(start,length);
2146 std::stringstream s(code.Data());
2147 s >> fROOTTrainingVersion;
2148 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
2149 }
2150 if (line.BeginsWith("Analysis type")) {
2151 Ssiz_t start = line.First('[')+1;
2152 Ssiz_t length = line.Index("]",start)-start;
2153 TString code = line(start,length);
2154 std::stringstream s(code.Data());
2155 std::string analysisType;
2156 s >> analysisType;
2157 if (analysisType == "regression" || analysisType == "Regression") SetAnalysisType( Types::kRegression );
2158 else if (analysisType == "classification" || analysisType == "Classification") SetAnalysisType( Types::kClassification );
2159 else if (analysisType == "multiclass" || analysisType == "Multiclass") SetAnalysisType( Types::kMulticlass );
2160 else Log() << kFATAL << "Analysis type " << analysisType << " from weight-file not known!" << std::endl;
2161
2162 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Method was trained for "
2163 << (GetAnalysisType() == Types::kRegression ? "Regression" :
2164 (GetAnalysisType() == Types::kMulticlass ? "Multiclass" : "Classification")) << Endl;
2165 }
2166
2167 return true;
2168}
2169
2170////////////////////////////////////////////////////////////////////////////////
2171/// Create PDFs of the MVA output variables
2172
2174{
2175 Data()->SetCurrentType(Types::kTraining);
2176
2177 // the PDF's are stored as results ONLY if the corresponding "results" are booked,
2178 // otherwise they will be only used 'online'
2179 ResultsClassification * mvaRes = dynamic_cast<ResultsClassification*>
2180 ( Data()->GetResults(GetMethodName(), Types::kTraining, Types::kClassification) );
2181
2182 if (mvaRes==0 || mvaRes->GetSize()==0) {
2183 Log() << kERROR<<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CreateMVAPdfs> No result of classifier testing available" << Endl;
2184 }
2185
2186 Double_t minVal = *std::min_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2187 Double_t maxVal = *std::max_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2188
2189 // create histograms that serve as basis to create the MVA Pdfs
2190 TH1* histMVAPdfS = new TH1D( GetMethodTypeName() + "_tr_S", GetMethodTypeName() + "_tr_S",
2191 fMVAPdfS->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2192 TH1* histMVAPdfB = new TH1D( GetMethodTypeName() + "_tr_B", GetMethodTypeName() + "_tr_B",
2193 fMVAPdfB->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2194
2195
2196 // compute sum of weights properly
2197 histMVAPdfS->Sumw2();
2198 histMVAPdfB->Sumw2();
2199
2200 // fill histograms
2201 for (UInt_t ievt=0; ievt<mvaRes->GetSize(); ievt++) {
2202 Double_t theVal = mvaRes->GetValueVector()->at(ievt);
2203 Double_t theWeight = Data()->GetEvent(ievt)->GetWeight();
2204
2205 if (DataInfo().IsSignal(Data()->GetEvent(ievt))) histMVAPdfS->Fill( theVal, theWeight );
2206 else histMVAPdfB->Fill( theVal, theWeight );
2207 }
2208
2209 gTools().NormHist( histMVAPdfS );
2210 gTools().NormHist( histMVAPdfB );
2211
2212 // momentary hack for ROOT problem
2213 if(!IsSilentFile())
2214 {
2215 histMVAPdfS->Write();
2216 histMVAPdfB->Write();
2217 }
2218 // create PDFs
2219 fMVAPdfS->BuildPDF ( histMVAPdfS );
2220 fMVAPdfB->BuildPDF ( histMVAPdfB );
2221 fMVAPdfS->ValidatePDF( histMVAPdfS );
2222 fMVAPdfB->ValidatePDF( histMVAPdfB );
2223
2224 if (DataInfo().GetNClasses() == 2) { // TODO: this is an ugly hack.. adapt this to new framework
2225 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())
2226 << Form( "<CreateMVAPdfs> Separation from histogram (PDF): %1.3f (%1.3f)",
2227 GetSeparation( histMVAPdfS, histMVAPdfB ), GetSeparation( fMVAPdfS, fMVAPdfB ) )
2228 << Endl;
2229 }
2230
2231 delete histMVAPdfS;
2232 delete histMVAPdfB;
2233}
2234
   // the simple one, automatically calculates the mvaVal and uses the
   // SAME sig/bkg ratio as given in the training sample (typically 50/50
   // .. (NormMode=EqualNumEvents) but can be different)
   if (!fMVAPdfS || !fMVAPdfB) {
      Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "<GetProba> MVA PDFs for Signal and Background don't exist yet, we'll create them on demand" << Endl;
      CreateMVAPdfs();
   }
   // signal fraction taken from the training sample weights
   Double_t sigFraction = DataInfo().GetTrainingSumSignalWeights() / (DataInfo().GetTrainingSumSignalWeights() + DataInfo().GetTrainingSumBackgrWeights() );
   Double_t mvaVal = GetMvaValue(ev);

   // delegate to the (mvaVal, sigFraction) overload below
   return GetProba(mvaVal,sigFraction);

}
2249////////////////////////////////////////////////////////////////////////////////
2250/// compute likelihood ratio
2251
2253{
2254 if (!fMVAPdfS || !fMVAPdfB) {
2255 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetProba> MVA PDFs for Signal and Background don't exist" << Endl;
2256 return -1.0;
2257 }
2258 Double_t p_s = fMVAPdfS->GetVal( mvaVal );
2259 Double_t p_b = fMVAPdfB->GetVal( mvaVal );
2260
2261 Double_t denom = p_s*ap_sig + p_b*(1 - ap_sig);
2262
2263 return (denom > 0) ? (p_s*ap_sig) / denom : -1;
2264}
2265
2266////////////////////////////////////////////////////////////////////////////////
2267/// compute rarity:
2268/// \f[
2269/// R(x) = \int_{[-\infty..x]} { PDF(x') dx' }
2270/// \f]
2271/// where PDF(x) is the PDF of the classifier's signal or background distribution
2272
2274{
2275 if ((reftype == Types::kSignal && !fMVAPdfS) || (reftype == Types::kBackground && !fMVAPdfB)) {
2276 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetRarity> Required MVA PDF for Signal or Background does not exist: "
2277 << "select option \"CreateMVAPdfs\"" << Endl;
2278 return 0.0;
2279 }
2280
2281 PDF* thePdf = ((reftype == Types::kSignal) ? fMVAPdfS : fMVAPdfB);
2282
2283 return thePdf->GetIntegral( thePdf->GetXmin(), mvaVal );
2284}
2285
2286////////////////////////////////////////////////////////////////////////////////
2287/// fill background efficiency (resp. rejection) versus signal efficiency plots
2288/// returns signal efficiency at background efficiency indicated in theString
2289
2291{
2292 Data()->SetCurrentType(type);
2293 Results* results = Data()->GetResults( GetMethodName(), type, Types::kClassification );
2294 std::vector<Float_t>* mvaRes = dynamic_cast<ResultsClassification*>(results)->GetValueVector();
2295
2296 // parse input string for required background efficiency
2297 TList* list = gTools().ParseFormatLine( theString );
2298
2299 // sanity check
2300 Bool_t computeArea = kFALSE;
2301 if (!list || list->GetSize() < 2) computeArea = kTRUE; // the area is computed
2302 else if (list->GetSize() > 2) {
2303 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Wrong number of arguments"
2304 << " in string: " << theString
2305 << " | required format, e.g., Efficiency:0.05, or empty string" << Endl;
2306 delete list;
2307 return -1;
2308 }
2309
2310 // sanity check
2311 if ( results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2312 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2313 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Binning mismatch between signal and background histos" << Endl;
2314 delete list;
2315 return -1.0;
2316 }
2317
2318 // create histograms
2319
2320 // first, get efficiency histograms for signal and background
2321 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2322 Double_t xmin = effhist->GetXaxis()->GetXmin();
2323 Double_t xmax = effhist->GetXaxis()->GetXmax();
2324
2325 TTHREAD_TLS(Double_t) nevtS;
2326
2327 // first round ? --> create histograms
2328 if (results->DoesExist("MVA_EFF_S")==0) {
2329
2330 // for efficiency plot
2331 TH1* eff_s = new TH1D( GetTestvarName() + "_effS", GetTestvarName() + " (signal)", fNbinsH, xmin, xmax );
2332 TH1* eff_b = new TH1D( GetTestvarName() + "_effB", GetTestvarName() + " (background)", fNbinsH, xmin, xmax );
2333 results->Store(eff_s, "MVA_EFF_S");
2334 results->Store(eff_b, "MVA_EFF_B");
2335
2336 // sign if cut
2337 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2338
2339 // this method is unbinned
2340 nevtS = 0;
2341 for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2342
2343 // read the tree
2344 Bool_t isSignal = DataInfo().IsSignal(GetEvent(ievt));
2345 Float_t theWeight = GetEvent(ievt)->GetWeight();
2346 Float_t theVal = (*mvaRes)[ievt];
2347
2348 // select histogram depending on if sig or bgd
2349 TH1* theHist = isSignal ? eff_s : eff_b;
2350
2351 // count signal and background events in tree
2352 if (isSignal) nevtS+=theWeight;
2353
2354 TAxis* axis = theHist->GetXaxis();
2355 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2356 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2357 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2358 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2359 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2360
2361 if (sign > 0)
2362 for (Int_t ibin=1; ibin<=maxbin; ibin++) theHist->AddBinContent( ibin , theWeight);
2363 else if (sign < 0)
2364 for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theHist->AddBinContent( ibin , theWeight );
2365 else
2366 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetEfficiency> Mismatch in sign" << Endl;
2367 }
2368
2369 // renormalise maximum to <=1
2370 // eff_s->Scale( 1.0/TMath::Max(1.,eff_s->GetMaximum()) );
2371 // eff_b->Scale( 1.0/TMath::Max(1.,eff_b->GetMaximum()) );
2372
2375
2376 // background efficiency versus signal efficiency
2377 TH1* eff_BvsS = new TH1D( GetTestvarName() + "_effBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2378 results->Store(eff_BvsS, "MVA_EFF_BvsS");
2379 eff_BvsS->SetXTitle( "Signal eff" );
2380 eff_BvsS->SetYTitle( "Backgr eff" );
2381
2382 // background rejection (=1-eff.) versus signal efficiency
2383 TH1* rej_BvsS = new TH1D( GetTestvarName() + "_rejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2384 results->Store(rej_BvsS);
2385 rej_BvsS->SetXTitle( "Signal eff" );
2386 rej_BvsS->SetYTitle( "Backgr rejection (1-eff)" );
2387
2388 // inverse background eff (1/eff.) versus signal efficiency
2389 TH1* inveff_BvsS = new TH1D( GetTestvarName() + "_invBeffvsSeff",
2390 GetTestvarName(), fNbins, 0, 1 );
2391 results->Store(inveff_BvsS);
2392 inveff_BvsS->SetXTitle( "Signal eff" );
2393 inveff_BvsS->SetYTitle( "Inverse backgr. eff (1/eff)" );
2394
2395 // use root finder
2396 // spline background efficiency plot
2397 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2399 fSplRefS = new TSpline1( "spline2_signal", new TGraph( eff_s ) );
2400 fSplRefB = new TSpline1( "spline2_background", new TGraph( eff_b ) );
2401
2402 // verify spline sanity
2403 gTools().CheckSplines( eff_s, fSplRefS );
2404 gTools().CheckSplines( eff_b, fSplRefB );
2405 }
2406
2407 // make the background-vs-signal efficiency plot
2408
2409 // create root finder
2410 RootFinder rootFinder( this, fXmin, fXmax );
2411
2412 Double_t effB = 0;
2413 fEffS = eff_s; // to be set for the root finder
2414 for (Int_t bini=1; bini<=fNbins; bini++) {
2415
2416 // find cut value corresponding to a given signal efficiency
2417 Double_t effS = eff_BvsS->GetBinCenter( bini );
2418 Double_t cut = rootFinder.Root( effS );
2419
2420 // retrieve background efficiency for given cut
2421 if (Use_Splines_for_Eff_) effB = fSplRefB->Eval( cut );
2422 else effB = eff_b->GetBinContent( eff_b->FindBin( cut ) );
2423
2424 // and fill histograms
2425 eff_BvsS->SetBinContent( bini, effB );
2426 rej_BvsS->SetBinContent( bini, 1.0-effB );
2428 inveff_BvsS->SetBinContent( bini, 1.0/effB );
2429 }
2430
2431 // create splines for histogram
2432 fSpleffBvsS = new TSpline1( "effBvsS", new TGraph( eff_BvsS ) );
2433
2434 // search for overlap point where, when cutting on it,
2435 // one would obtain: eff_S = rej_B = 1 - eff_B
2436 Double_t effS = 0., rejB, effS_ = 0., rejB_ = 0.;
2437 Int_t nbins_ = 5000;
2438 for (Int_t bini=1; bini<=nbins_; bini++) {
2439
2440 // get corresponding signal and background efficiencies
2441 effS = (bini - 0.5)/Float_t(nbins_);
2442 rejB = 1.0 - fSpleffBvsS->Eval( effS );
2443
2444 // find signal efficiency that corresponds to required background efficiency
2445 if ((effS - rejB)*(effS_ - rejB_) < 0) break;
2446 effS_ = effS;
2447 rejB_ = rejB;
2448 }
2449
2450 // find cut that corresponds to signal efficiency and update signal-like criterion
2451 Double_t cut = rootFinder.Root( 0.5*(effS + effS_) );
2452 SetSignalReferenceCut( cut );
2453 fEffS = 0;
2454 }
2455
2456 // must exist...
2457 if (0 == fSpleffBvsS) {
2458 delete list;
2459 return 0.0;
2460 }
2461
2462 // now find signal efficiency that corresponds to required background efficiency
2463 Double_t effS = 0, effB = 0, effS_ = 0, effB_ = 0;
2464 Int_t nbins_ = 1000;
2465
2466 if (computeArea) {
2467
2468 // compute area of rej-vs-eff plot
2469 Double_t integral = 0;
2470 for (Int_t bini=1; bini<=nbins_; bini++) {
2471
2472 // get corresponding signal and background efficiencies
2473 effS = (bini - 0.5)/Float_t(nbins_);
2474 effB = fSpleffBvsS->Eval( effS );
2475 integral += (1.0 - effB);
2476 }
2477 integral /= nbins_;
2478
2479 delete list;
2480 return integral;
2481 }
2482 else {
2483
2484 // that will be the value of the efficiency retured (does not affect
2485 // the efficiency-vs-bkg plot which is done anyway.
2486 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2487
2488 // find precise efficiency value
2489 for (Int_t bini=1; bini<=nbins_; bini++) {
2490
2491 // get corresponding signal and background efficiencies
2492 effS = (bini - 0.5)/Float_t(nbins_);
2493 effB = fSpleffBvsS->Eval( effS );
2494
2495 // find signal efficiency that corresponds to required background efficiency
2496 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2497 effS_ = effS;
2498 effB_ = effB;
2499 }
2500
2501 // take mean between bin above and bin below
2502 effS = 0.5*(effS + effS_);
2503
2504 effSerr = 0;
2505 if (nevtS > 0) effSerr = TMath::Sqrt( effS*(1.0 - effS)/nevtS );
2506
2507 delete list;
2508 return effS;
2509 }
2510
2511 return -1;
2512}
2513
2514////////////////////////////////////////////////////////////////////////////////
2515
2517{
2518 Data()->SetCurrentType(Types::kTraining);
2519
2520 Results* results = Data()->GetResults(GetMethodName(), Types::kTesting, Types::kNoAnalysisType);
2521
2522 // fill background efficiency (resp. rejection) versus signal efficiency plots
2523 // returns signal efficiency at background efficiency indicated in theString
2524
2525 // parse input string for required background efficiency
2526 TList* list = gTools().ParseFormatLine( theString );
2527 // sanity check
2528
2529 if (list->GetSize() != 2) {
2530 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Wrong number of arguments"
2531 << " in string: " << theString
2532 << " | required format, e.g., Efficiency:0.05" << Endl;
2533 delete list;
2534 return -1;
2535 }
2536 // that will be the value of the efficiency retured (does not affect
2537 // the efficiency-vs-bkg plot which is done anyway.
2538 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2539
2540 delete list;
2541
2542 // sanity check
2543 if (results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
2544 results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
2545 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetTrainingEfficiency> Binning mismatch between signal and background histos"
2546 << Endl;
2547 return -1.0;
2548 }
2549
2550 // create histogram
2551
2552 // first, get efficiency histograms for signal and background
2553 TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
2554 Double_t xmin = effhist->GetXaxis()->GetXmin();
2555 Double_t xmax = effhist->GetXaxis()->GetXmax();
2556
2557 // first round ? --> create and fill histograms
2558 if (results->DoesExist("MVA_TRAIN_S")==0) {
2559
2560 // classifier response distributions for test sample
2561 Double_t sxmax = fXmax+0.00001;
2562
2563 // MVA plots on the training sample (check for overtraining)
2564 TH1* mva_s_tr = new TH1D( GetTestvarName() + "_Train_S",GetTestvarName() + "_Train_S", fNbinsMVAoutput, fXmin, sxmax );
2565 TH1* mva_b_tr = new TH1D( GetTestvarName() + "_Train_B",GetTestvarName() + "_Train_B", fNbinsMVAoutput, fXmin, sxmax );
2566 results->Store(mva_s_tr, "MVA_TRAIN_S");
2567 results->Store(mva_b_tr, "MVA_TRAIN_B");
2568 mva_s_tr->Sumw2();
2569 mva_b_tr->Sumw2();
2570
2571 // Training efficiency plots
2572 TH1* mva_eff_tr_s = new TH1D( GetTestvarName() + "_trainingEffS", GetTestvarName() + " (signal)",
2573 fNbinsH, xmin, xmax );
2574 TH1* mva_eff_tr_b = new TH1D( GetTestvarName() + "_trainingEffB", GetTestvarName() + " (background)",
2575 fNbinsH, xmin, xmax );
2576 results->Store(mva_eff_tr_s, "MVA_TRAINEFF_S");
2577 results->Store(mva_eff_tr_b, "MVA_TRAINEFF_B");
2578
2579 // sign if cut
2580 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2581
2582 std::vector<Double_t> mvaValues = GetMvaValues(0,Data()->GetNEvents());
2583 assert( (Long64_t) mvaValues.size() == Data()->GetNEvents());
2584
2585 // this method is unbinned
2586 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2587
2588 Data()->SetCurrentEvent(ievt);
2589 const Event* ev = GetEvent();
2590
2591 Double_t theVal = mvaValues[ievt];
2592 Double_t theWeight = ev->GetWeight();
2593
2594 TH1* theEffHist = DataInfo().IsSignal(ev) ? mva_eff_tr_s : mva_eff_tr_b;
2595 TH1* theClsHist = DataInfo().IsSignal(ev) ? mva_s_tr : mva_b_tr;
2596
2597 theClsHist->Fill( theVal, theWeight );
2598
2599 TAxis* axis = theEffHist->GetXaxis();
2600 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2601 if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't count
2602 if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't count
2603 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2604 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2605
2606 if (sign > 0) for (Int_t ibin=1; ibin<=maxbin; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2607 else for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2608 }
2609
2610 // normalise output distributions
2611 // uncomment those (and several others if you want unnormalized output
2612 gTools().NormHist( mva_s_tr );
2613 gTools().NormHist( mva_b_tr );
2614
2615 // renormalise to maximum
2616 mva_eff_tr_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_s->GetMaximum()) );
2617 mva_eff_tr_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_b->GetMaximum()) );
2618
2619 // Training background efficiency versus signal efficiency
2620 TH1* eff_bvss = new TH1D( GetTestvarName() + "_trainingEffBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2621 // Training background rejection (=1-eff.) versus signal efficiency
2622 TH1* rej_bvss = new TH1D( GetTestvarName() + "_trainingRejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
2623 results->Store(eff_bvss, "EFF_BVSS_TR");
2624 results->Store(rej_bvss, "REJ_BVSS_TR");
2625
2626 // use root finder
2627 // spline background efficiency plot
2628 // note that there is a bin shift when going from a TH1D object to a TGraph :-(
2630 if (fSplTrainRefS) delete fSplTrainRefS;
2631 if (fSplTrainRefB) delete fSplTrainRefB;
2632 fSplTrainRefS = new TSpline1( "spline2_signal", new TGraph( mva_eff_tr_s ) );
2633 fSplTrainRefB = new TSpline1( "spline2_background", new TGraph( mva_eff_tr_b ) );
2634
2635 // verify spline sanity
2636 gTools().CheckSplines( mva_eff_tr_s, fSplTrainRefS );
2637 gTools().CheckSplines( mva_eff_tr_b, fSplTrainRefB );
2638 }
2639
2640 // make the background-vs-signal efficiency plot
2641
2642 // create root finder
2643 RootFinder rootFinder(this, fXmin, fXmax );
2644
2645 Double_t effB = 0;
2646 fEffS = results->GetHist("MVA_TRAINEFF_S");
2647 for (Int_t bini=1; bini<=fNbins; bini++) {
2648
2649 // find cut value corresponding to a given signal efficiency
2650 Double_t effS = eff_bvss->GetBinCenter( bini );
2651
2652 Double_t cut = rootFinder.Root( effS );
2653
2654 // retrieve background efficiency for given cut
2655 if (Use_Splines_for_Eff_) effB = fSplTrainRefB->Eval( cut );
2656 else effB = mva_eff_tr_b->GetBinContent( mva_eff_tr_b->FindBin( cut ) );
2657
2658 // and fill histograms
2659 eff_bvss->SetBinContent( bini, effB );
2660 rej_bvss->SetBinContent( bini, 1.0-effB );
2661 }
2662 fEffS = 0;
2663
2664 // create splines for histogram
2665 fSplTrainEffBvsS = new TSpline1( "effBvsS", new TGraph( eff_bvss ) );
2666 }
2667
2668 // must exist...
2669 if (0 == fSplTrainEffBvsS) return 0.0;
2670
2671 // now find signal efficiency that corresponds to required background efficiency
2672 Double_t effS = 0., effB, effS_ = 0., effB_ = 0.;
2673 Int_t nbins_ = 1000;
2674 for (Int_t bini=1; bini<=nbins_; bini++) {
2675
2676 // get corresponding signal and background efficiencies
2677 effS = (bini - 0.5)/Float_t(nbins_);
2678 effB = fSplTrainEffBvsS->Eval( effS );
2679
2680 // find signal efficiency that corresponds to required background efficiency
2681 if ((effB - effBref)*(effB_ - effBref) <= 0) break;
2682 effS_ = effS;
2683 effB_ = effB;
2684 }
2685
2686 return 0.5*(effS + effS_); // the mean between bin above and bin below
2687}
2688
2689////////////////////////////////////////////////////////////////////////////////
2690
2691std::vector<Float_t> TMVA::MethodBase::GetMulticlassEfficiency(std::vector<std::vector<Float_t> >& purity)
2692{
2693 Data()->SetCurrentType(Types::kTesting);
2694 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
2695 if (!resMulticlass) Log() << kFATAL<<Form("Dataset[%s] : ",DataInfo().GetName())<< "unable to create pointer in GetMulticlassEfficiency, exiting."<<Endl;
2696
2697 purity.push_back(resMulticlass->GetAchievablePur());
2698 return resMulticlass->GetAchievableEff();
2699}
2700
2701////////////////////////////////////////////////////////////////////////////////
2702
2703std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity)
2704{
2705 Data()->SetCurrentType(Types::kTraining);
2706 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTraining, Types::kMulticlass));
2707 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in GetMulticlassTrainingEfficiency, exiting."<<Endl;
2708
2709 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Determine optimal multiclass cuts for training data..." << Endl;
2710 for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
2711 resMulticlass->GetBestMultiClassCuts(icls);
2712 }
2713
2714 purity.push_back(resMulticlass->GetAchievablePur());
2715 return resMulticlass->GetAchievableEff();
2716}
2717
2718////////////////////////////////////////////////////////////////////////////////
2719/// Construct a confusion matrix for a multiclass classifier. The confusion
2720 /// matrix compares, in turn, each class against all other classes in a pair-wise
2721/// fashion. In rows with index \f$ k_r = 0 ... K \f$, \f$ k_r \f$ is
2722/// considered signal for the sake of comparison and for each column
2723/// \f$ k_c = 0 ... K \f$ the corresponding class is considered background.
2724///
2725/// Note that the diagonal elements will be returned as NaN since this will
2726/// compare a class against itself.
2727///
2728/// \see TMVA::ResultsMulticlass::GetConfusionMatrix
2729///
2730/// \param[in] effB The background efficiency for which to evaluate.
2731/// \param[in] type The data set on which to evaluate (training, testing ...).
2732///
2733/// \return A matrix containing signal efficiencies for the given background
2734/// efficiency. The diagonal elements are NaN since this measure is
2735/// meaningless (comparing a class against itself).
2736///
2737
2739{
2740 if (GetAnalysisType() != Types::kMulticlass) {
2741 Log() << kFATAL << "Cannot get confusion matrix for non-multiclass analysis." << std::endl;
2742 return TMatrixD(0, 0);
2743 }
2744
2745 Data()->SetCurrentType(type);
2746 ResultsMulticlass *resMulticlass =
2747 dynamic_cast<ResultsMulticlass *>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
2748
2749 if (resMulticlass == nullptr) {
2750 Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
2751 << "unable to create pointer in GetMulticlassEfficiency, exiting." << Endl;
2752 return TMatrixD(0, 0);
2753 }
2754
2755 return resMulticlass->GetConfusionMatrix(effB);
2756}
2757
2758////////////////////////////////////////////////////////////////////////////////
2759/// compute significance of mean difference
2760/// \f[
2761 /// significance = \frac{|<S> - <B>|}{\sqrt{RMS_S^2 + RMS_B^2}}
2762/// \f]
2763
2765{
2766 Double_t rms = sqrt( fRmsS*fRmsS + fRmsB*fRmsB );
2767
2768 return (rms > 0) ? TMath::Abs(fMeanS - fMeanB)/rms : 0;
2769}
2770
2771////////////////////////////////////////////////////////////////////////////////
2772/// compute "separation" defined as
2773/// \f[
2774/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2775/// \f]
2776
2778{
2779 return gTools().GetSeparation( histoS, histoB );
2780}
2781
2782////////////////////////////////////////////////////////////////////////////////
2783/// compute "separation" defined as
2784/// \f[
2785/// <s2> = \frac{1}{2} \int_{-\infty}^{+\infty} { \frac{(S(x) - B(x))^2}{(S(x) + B(x))} dx }
2786/// \f]
2787
2789{
2790 // note, if zero pointers given, use internal pdf
2791 // sanity check first
2792 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2793 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2794 if (!pdfS) pdfS = fSplS;
2795 if (!pdfB) pdfB = fSplB;
2796
2797 if (!fSplS || !fSplB) {
2798 Log()<<kDEBUG<<Form("[%s] : ",DataInfo().GetName())<< "could not calculate the separation, distributions"
2799 << " fSplS or fSplB are not yet filled" << Endl;
2800 return 0;
2801 }else{
2802 return gTools().GetSeparation( *pdfS, *pdfB );
2803 }
2804}
2805
2806////////////////////////////////////////////////////////////////////////////////
2807/// calculate the area (integral) under the ROC curve as a
2808/// overall quality measure of the classification
2809
2811{
2812 // note, if zero pointers given, use internal pdf
2813 // sanity check first
2814 if ((!histS && histB) || (histS && !histB))
2815 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetROCIntegral(TH1D*, TH1D*)> Mismatch in hists" << Endl;
2816
2817 if (histS==0 || histB==0) return 0.;
2818
2819 TMVA::PDF *pdfS = new TMVA::PDF( " PDF Sig", histS, TMVA::PDF::kSpline3 );
2820 TMVA::PDF *pdfB = new TMVA::PDF( " PDF Bkg", histB, TMVA::PDF::kSpline3 );
2821
2822
2823 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2824 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2825
2826 Double_t integral = 0;
2827 UInt_t nsteps = 1000;
2828 Double_t step = (xmax-xmin)/Double_t(nsteps);
2829 Double_t cut = xmin;
2830 for (UInt_t i=0; i<nsteps; i++) {
2831 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2832 cut+=step;
2833 }
2834 delete pdfS;
2835 delete pdfB;
2836 return integral*step;
2837}
2838
2839
2840////////////////////////////////////////////////////////////////////////////////
2841/// calculate the area (integral) under the ROC curve as a
2842/// overall quality measure of the classification
2843
2845{
2846 // note, if zero pointers given, use internal pdf
2847 // sanity check first
2848 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2849 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetSeparation> Mismatch in pdfs" << Endl;
2850 if (!pdfS) pdfS = fSplS;
2851 if (!pdfB) pdfB = fSplB;
2852
2853 if (pdfS==0 || pdfB==0) return 0.;
2854
2855 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2856 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2857
2858 Double_t integral = 0;
2859 UInt_t nsteps = 1000;
2860 Double_t step = (xmax-xmin)/Double_t(nsteps);
2861 Double_t cut = xmin;
2862 for (UInt_t i=0; i<nsteps; i++) {
2863 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2864 cut+=step;
2865 }
2866 return integral*step;
2867}
2868
2869////////////////////////////////////////////////////////////////////////////////
2870 /// plot significance, \f$ \frac{S}{\sqrt{S + B}} \f$, curve for given number
2871/// of signal and background events; returns cut for maximum significance
2872/// also returned via reference is the maximum significance
2873
2875 Double_t BackgroundEvents,
2876 Double_t& max_significance_value ) const
2877{
2878 Results* results = Data()->GetResults( GetMethodName(), Types::kTesting, Types::kMaxAnalysisType );
2879
2880 Double_t max_significance(0);
2881 Double_t effS(0),effB(0),significance(0);
2882 TH1D *temp_histogram = new TH1D("temp", "temp", fNbinsH, fXmin, fXmax );
2883
2884 if (SignalEvents <= 0 || BackgroundEvents <= 0) {
2885 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<GetMaximumSignificance> "
2886 << "Number of signal or background events is <= 0 ==> abort"
2887 << Endl;
2888 }
2889
2890 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Using ratio SignalEvents/BackgroundEvents = "
2891 << SignalEvents/BackgroundEvents << Endl;
2892
2893 TH1* eff_s = results->GetHist("MVA_EFF_S");
2894 TH1* eff_b = results->GetHist("MVA_EFF_B");
2895
2896 if ( (eff_s==0) || (eff_b==0) ) {
2897 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Efficiency histograms empty !" << Endl;
2898 Log() << kWARNING <<Form("Dataset[%s] : ",DataInfo().GetName())<< "no maximum cut found, return 0" << Endl;
2899 return 0;
2900 }
2901
2902 for (Int_t bin=1; bin<=fNbinsH; bin++) {
2903 effS = eff_s->GetBinContent( bin );
2904 effB = eff_b->GetBinContent( bin );
2905
2906 // put significance into a histogram
2907 significance = sqrt(SignalEvents)*( effS )/sqrt( effS + ( BackgroundEvents / SignalEvents) * effB );
2908
2909 temp_histogram->SetBinContent(bin,significance);
2910 }
2911
2912 // find maximum in histogram
2913 max_significance = temp_histogram->GetBinCenter( temp_histogram->GetMaximumBin() );
2914 max_significance_value = temp_histogram->GetBinContent( temp_histogram->GetMaximumBin() );
2915
2916 // delete
2917 delete temp_histogram;
2918
2919 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Optimal cut at : " << max_significance << Endl;
2920 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName()) << "Maximum significance: " << max_significance_value << Endl;
2921
2922 return max_significance;
2923}
2924
2925////////////////////////////////////////////////////////////////////////////////
2926/// calculates rms,mean, xmin, xmax of the event variable
2927/// this can be either done for the variables as they are or for
2928/// normalised variables (in the range of 0-1) if "norm" is set to kTRUE
2929
2931 Double_t& meanS, Double_t& meanB,
2932 Double_t& rmsS, Double_t& rmsB,
2934 {
// Computes the weighted mean and RMS of the named input variable, separately
// for signal and background, plus the overall xmin/xmax, on the requested
// tree type; the previous tree type is restored before returning.
2935 Types::ETreeType previousTreeType = Data()->GetCurrentType();
2936 Data()->SetCurrentType(treeType);
2937 
2938 Long64_t entries = Data()->GetNEvents();
2939 
2940 // sanity check
2941 if (entries <=0)
2942 Log() << kFATAL <<Form("Dataset[%s] : ",DataInfo().GetName())<< "<CalculateEstimator> Wrong tree type: " << treeType << Endl;
2943 
2944 // index of the wanted variable
2945 UInt_t varIndex = DataInfo().FindVarIndex( theVarName );
2946 
2947 // first fill signal and background in arrays before analysis
2948 xmin = +DBL_MAX;
2949 xmax = -DBL_MAX;
// NOTE(review): nEventsS/nEventsB are initialised to -1, incremented exactly
// once after the loop and never read — dead code, candidate for removal
2950 Long64_t nEventsS = -1;
2951 Long64_t nEventsB = -1;
2952 
2953 // take into account event weights
2954 meanS = 0;
2955 meanB = 0;
2956 rmsS = 0;
2957 rmsB = 0;
2958 Double_t sumwS = 0, sumwB = 0;
2959 
2960 // loop over all training events
2961 for (Int_t ievt = 0; ievt < entries; ievt++) {
2962 
2963 const Event* ev = GetEvent(ievt);
2964 
2965 Double_t theVar = ev->GetValue(varIndex);
2966 Double_t weight = ev->GetWeight();
2967 
// accumulate weighted sums for the class the event belongs to
2968 if (DataInfo().IsSignal(ev)) {
2969 sumwS += weight;
2970 meanS += weight*theVar;
2971 rmsS += weight*theVar*theVar;
2972 }
2973 else {
2974 sumwB += weight;
2975 meanB += weight*theVar;
2976 rmsB += weight*theVar*theVar;
2977 }
2978 xmin = TMath::Min( xmin, theVar );
2979 xmax = TMath::Max( xmax, theVar );
2980 }
2981 ++nEventsS;
2982 ++nEventsB;
2983 
// NOTE(review): no guard against sumwS/sumwB == 0 (e.g. a sample with only
// one class) — the divisions below would then yield inf/nan; confirm callers
// guarantee both classes are populated
2984 meanS = meanS/sumwS;
2985 meanB = meanB/sumwB;
2986 rmsS = TMath::Sqrt( rmsS/sumwS - meanS*meanS );
2987 rmsB = TMath::Sqrt( rmsB/sumwB - meanB*meanB );
2988 
2989 Data()->SetCurrentType(previousTreeType);
2990 }
2991
2992////////////////////////////////////////////////////////////////////////////////
2993/// create reader class for method (classification only at present)
2994
2995 void TMVA::MethodBase::MakeClass( const TString& theClassFileName ) const
2996 {
// Writes a standalone C++ source file containing a Read<MethodName> class
// that re-implements this classifier's response without ROOT/TMVA
// dependencies. An empty theClassFileName selects the default
// <weight-file-dir>/<job>_<method>.class.C.
2997 // the default file name consists of the weight-file directory, job name and method name
2998 TString classFileName = "";
2999 if (theClassFileName == "")
3000 classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class.C";
3001 else
3002 classFileName = theClassFileName;
3003 
3004 TString className = TString("Read") + GetMethodName();
3005 
// NOTE(review): tfname is never used below — candidate for removal
3006 TString tfname( classFileName );
3007 Log() << kINFO //<<Form("Dataset[%s] : ",DataInfo().GetName())
3008 << "Creating standalone class: "
3009 << gTools().Color("lightblue") << classFileName << gTools().Color("reset") << Endl;
3010 
3011 std::ofstream fout( classFileName );
3012 if (!fout.good()) { // file could not be opened --> Error
3013 Log() << kFATAL << "<MakeClass> Unable to open file: " << classFileName << Endl;
3014 }
3015 
3016 // now create the class
3017 // preamble
3018 fout << "// Class: " << className << std::endl;
3019 fout << "// Automatically generated by MethodBase::MakeClass" << std::endl << "//" << std::endl;
3020 
3021 // print general information and configuration state
3022 fout << std::endl;
3023 fout << "/* configuration options =====================================================" << std::endl << std::endl;
3024 WriteStateToStream( fout );
3025 fout << std::endl;
3026 fout << "============================================================================ */" << std::endl;
3027 
3028 // generate the class
3029 fout << "" << std::endl;
3030 fout << "#include <array>" << std::endl;
3031 fout << "#include <vector>" << std::endl;
3032 fout << "#include <cmath>" << std::endl;
3033 fout << "#include <string>" << std::endl;
3034 fout << "#include <iostream>" << std::endl;
3035 fout << "" << std::endl;
3036 // now if the classifier needs to write some additional classes for its response implementation
3037 // this code goes here: (at least the header declarations need to come before the main class
3038 this->MakeClassSpecificHeader( fout, className );
3039 
// emit the abstract IClassifierReader interface shared by all generated readers
3040 fout << "#ifndef IClassifierReader__def" << std::endl;
3041 fout << "#define IClassifierReader__def" << std::endl;
3042 fout << std::endl;
3043 fout << "class IClassifierReader {" << std::endl;
3044 fout << std::endl;
3045 fout << " public:" << std::endl;
3046 fout << std::endl;
3047 fout << " // constructor" << std::endl;
3048 fout << " IClassifierReader() : fStatusIsClean( true ) {}" << std::endl;
3049 fout << " virtual ~IClassifierReader() {}" << std::endl;
3050 fout << std::endl;
3051 fout << " // return classifier response" << std::endl;
3052 if(GetAnalysisType() == Types::kMulticlass) {
3053 fout << " virtual std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3054 } else {
3055 fout << " virtual double GetMvaValue( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3056 }
3057 fout << std::endl;
3058 fout << " // returns classifier status" << std::endl;
3059 fout << " bool IsStatusClean() const { return fStatusIsClean; }" << std::endl;
3060 fout << std::endl;
3061 fout << " protected:" << std::endl;
3062 fout << std::endl;
3063 fout << " bool fStatusIsClean;" << std::endl;
3064 fout << "};" << std::endl;
3065 fout << std::endl;
3066 fout << "#endif" << std::endl;
3067 fout << std::endl;
// emit the concrete reader class: constructor with input-variable validation
3068 fout << "class " << className << " : public IClassifierReader {" << std::endl;
3069 fout << std::endl;
3070 fout << " public:" << std::endl;
3071 fout << std::endl;
3072 fout << " // constructor" << std::endl;
3073 fout << " " << className << "( std::vector<std::string>& theInputVars )" << std::endl;
3074 fout << " : IClassifierReader()," << std::endl;
3075 fout << " fClassName( \"" << className << "\" )," << std::endl;
3076 fout << " fNvars( " << GetNvar() << " )" << std::endl;
3077 fout << " {" << std::endl;
3078 fout << " // the training input variables" << std::endl;
3079 fout << " const char* inputVars[] = { ";
3080 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3081 fout << "\"" << GetOriginalVarName(ivar) << "\"";
3082 if (ivar<GetNvar()-1) fout << ", ";
3083 }
3084 fout << " };" << std::endl;
3085 fout << std::endl;
3086 fout << " // sanity checks" << std::endl;
3087 fout << " if (theInputVars.size() <= 0) {" << std::endl;
3088 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": empty input vector\" << std::endl;" << std::endl;
3089 fout << " fStatusIsClean = false;" << std::endl;
3090 fout << " }" << std::endl;
3091 fout << std::endl;
3092 fout << " if (theInputVars.size() != fNvars) {" << std::endl;
3093 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in number of input values: \"" << std::endl;
3094 fout << " << theInputVars.size() << \" != \" << fNvars << std::endl;" << std::endl;
3095 fout << " fStatusIsClean = false;" << std::endl;
3096 fout << " }" << std::endl;
3097 fout << std::endl;
3098 fout << " // validate input variables" << std::endl;
3099 fout << " for (size_t ivar = 0; ivar < theInputVars.size(); ivar++) {" << std::endl;
3100 fout << " if (theInputVars[ivar] != inputVars[ivar]) {" << std::endl;
3101 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in input variable names\" << std::endl" << std::endl;
3102 fout << " << \" for variable [\" << ivar << \"]: \" << theInputVars[ivar].c_str() << \" != \" << inputVars[ivar] << std::endl;" << std::endl;
3103 fout << " fStatusIsClean = false;" << std::endl;
3104 fout << " }" << std::endl;
3105 fout << " }" << std::endl;
3106 fout << std::endl;
3107 fout << " // initialize min and max vectors (for normalisation)" << std::endl;
3108 for (UInt_t ivar = 0; ivar < GetNvar(); ivar++) {
3109 fout << " fVmin[" << ivar << "] = " << std::setprecision(15) << GetXmin( ivar ) << ";" << std::endl;
3110 fout << " fVmax[" << ivar << "] = " << std::setprecision(15) << GetXmax( ivar ) << ";" << std::endl;
3111 }
3112 fout << std::endl;
3113 fout << " // initialize input variable types" << std::endl;
3114 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3115 fout << " fType[" << ivar << "] = \'" << DataInfo().GetVariableInfo(ivar).GetVarType() << "\';" << std::endl;
3116 }
3117 fout << std::endl;
3118 fout << " // initialize constants" << std::endl;
3119 fout << " Initialize();" << std::endl;
3120 fout << std::endl;
3121 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
3122 fout << " // initialize transformation" << std::endl;
3123 fout << " InitTransform();" << std::endl;
3124 }
3125 fout << " }" << std::endl;
3126 fout << std::endl;
3127 fout << " // destructor" << std::endl;
3128 fout << " virtual ~" << className << "() {" << std::endl;
3129 fout << " Clear(); // method-specific" << std::endl;
3130 fout << " }" << std::endl;
3131 fout << std::endl;
3132 fout << " // the classifier response" << std::endl;
3133 fout << " // \"inputValues\" is a vector of input values in the same order as the" << std::endl;
3134 fout << " // variables given to the constructor" << std::endl;
3135 if(GetAnalysisType() == Types::kMulticlass) {
3136 fout << " std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const override;" << std::endl;
3137 } else {
3138 fout << " double GetMvaValue( const std::vector<double>& inputValues ) const override;" << std::endl;
3139 }
3140 fout << std::endl;
3141 fout << " private:" << std::endl;
3142 fout << std::endl;
3143 fout << " // method-specific destructor" << std::endl;
3144 fout << " void Clear();" << std::endl;
3145 fout << std::endl;
3146 if (GetTransformationHandler().GetTransformationList().GetSize()!=0) {
3147 fout << " // input variable transformation" << std::endl;
3148 GetTransformationHandler().MakeFunction(fout, className,1);
3149 fout << " void InitTransform();" << std::endl;
3150 fout << " void Transform( std::vector<double> & iv, int sigOrBgd ) const;" << std::endl;
3151 fout << std::endl;
3152 }
3153 fout << " // common member variables" << std::endl;
3154 fout << " const char* fClassName;" << std::endl;
3155 fout << std::endl;
3156 fout << " const size_t fNvars;" << std::endl;
3157 fout << " size_t GetNvar() const { return fNvars; }" << std::endl;
3158 fout << " char GetType( int ivar ) const { return fType[ivar]; }" << std::endl;
3159 fout << std::endl;
3160 fout << " // normalisation of input variables" << std::endl;
3161 fout << " double fVmin[" << GetNvar() << "];" << std::endl;
3162 fout << " double fVmax[" << GetNvar() << "];" << std::endl;
3163 fout << " double NormVariable( double x, double xmin, double xmax ) const {" << std::endl;
3164 fout << " // normalise to output range: [-1, 1]" << std::endl;
3165 fout << " return 2*(x - xmin)/(xmax - xmin) - 1.0;" << std::endl;
3166 fout << " }" << std::endl;
3167 fout << std::endl;
3168 fout << " // type of input variable: 'F' or 'I'" << std::endl;
3169 fout << " char fType[" << GetNvar() << "];" << std::endl;
3170 fout << std::endl;
3171 fout << " // initialize internal variables" << std::endl;
3172 fout << " void Initialize();" << std::endl;
3173 if(GetAnalysisType() == Types::kMulticlass) {
3174 fout << " std::vector<double> GetMulticlassValues__( const std::vector<double>& inputValues ) const;" << std::endl;
3175 } else {
3176 fout << " double GetMvaValue__( const std::vector<double>& inputValues ) const;" << std::endl;
3177 }
3178 fout << "" << std::endl;
3179 fout << " // private members (method specific)" << std::endl;
3180 
3181 // call the classifier specific output (the classifier must close the class !)
3182 MakeClassSpecific( fout, className );
3183 
// emit the public response method, which normalises/transforms the inputs
// as configured before delegating to the method-specific implementation
3184 if(GetAnalysisType() == Types::kMulticlass) {
3185 fout << "inline std::vector<double> " << className << "::GetMulticlassValues( const std::vector<double>& inputValues ) const" << std::endl;
3186 } else {
3187 fout << "inline double " << className << "::GetMvaValue( const std::vector<double>& inputValues ) const" << std::endl;
3188 }
3189 fout << "{" << std::endl;
3190 fout << " // classifier response value" << std::endl;
3191 if(GetAnalysisType() == Types::kMulticlass) {
3192 fout << " std::vector<double> retval;" << std::endl;
3193 } else {
3194 fout << " double retval = 0;" << std::endl;
3195 }
3196 fout << std::endl;
3197 fout << " // classifier response, sanity check first" << std::endl;
3198 fout << " if (!IsStatusClean()) {" << std::endl;
3199 fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": cannot return classifier response\"" << std::endl;
3200 fout << " << \" because status is dirty\" << std::endl;" << std::endl;
3201 fout << " }" << std::endl;
3202 fout << " else {" << std::endl;
3203 if (IsNormalised()) {
3204 fout << " // normalise variables" << std::endl;
3205 fout << " std::vector<double> iV;" << std::endl;
3206 fout << " iV.reserve(inputValues.size());" << std::endl;
3207 fout << " int ivar = 0;" << std::endl;
3208 fout << " for (std::vector<double>::const_iterator varIt = inputValues.begin();" << std::endl;
3209 fout << " varIt != inputValues.end(); varIt++, ivar++) {" << std::endl;
3210 fout << " iV.push_back(NormVariable( *varIt, fVmin[ivar], fVmax[ivar] ));" << std::endl;
3211 fout << " }" << std::endl;
3212 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3213 GetMethodType() != Types::kHMatrix) {
3214 fout << " Transform( iV, -1 );" << std::endl;
3215 }
3216 
3217 if(GetAnalysisType() == Types::kMulticlass) {
3218 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3219 } else {
3220 fout << " retval = GetMvaValue__( iV );" << std::endl;
3221 }
3222 } else {
3223 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3224 GetMethodType() != Types::kHMatrix) {
3225 fout << " std::vector<double> iV(inputValues);" << std::endl;
3226 fout << " Transform( iV, -1 );" << std::endl;
3227 if(GetAnalysisType() == Types::kMulticlass) {
3228 fout << " retval = GetMulticlassValues__( iV );" << std::endl;
3229 } else {
3230 fout << " retval = GetMvaValue__( iV );" << std::endl;
3231 }
3232 } else {
3233 if(GetAnalysisType() == Types::kMulticlass) {
3234 fout << " retval = GetMulticlassValues__( inputValues );" << std::endl;
3235 } else {
3236 fout << " retval = GetMvaValue__( inputValues );" << std::endl;
3237 }
3238 }
3239 }
3240 fout << " }" << std::endl;
3241 fout << std::endl;
3242 fout << " return retval;" << std::endl;
3243 fout << "}" << std::endl;
3244 
3245 // create output for transformation - if any
3246 if (GetTransformationHandler().GetTransformationList().GetSize()!=0)
3247 GetTransformationHandler().MakeFunction(fout, className,2);
3248 
3249 // close the file
3250 fout.close();
3251}
3252
3253////////////////////////////////////////////////////////////////////////////////
3254/// prints out method-specific help method
3255
3257{
3258 // if options are written to reference file, also append help info
3259 std::streambuf* cout_sbuf = std::cout.rdbuf(); // save original sbuf
3260 std::ofstream* o = 0;
3261 if (gConfig().WriteOptionsReference()) {
3262 Log() << kINFO << "Print Help message for class " << GetName() << " into file: " << GetReferenceFile() << Endl;
3263 o = new std::ofstream( GetReferenceFile(), std::ios::app );
3264 if (!o->good()) { // file could not be opened --> Error
3265 Log() << kFATAL << "<PrintHelpMessage> Unable to append to output file: " << GetReferenceFile() << Endl;
3266 }
3267 std::cout.rdbuf( o->rdbuf() ); // redirect 'std::cout' to file
3268 }
3269
3270 // "|--------------------------------------------------------------|"
3271 if (!o) {
3272 Log() << kINFO << Endl;
3273 Log() << gTools().Color("bold")
3274 << "================================================================"
3275 << gTools().Color( "reset" )
3276 << Endl;
3277 Log() << gTools().Color("bold")
3278 << "H e l p f o r M V A m e t h o d [ " << GetName() << " ] :"
3279 << gTools().Color( "reset" )
3280 << Endl;
3281 }
3282 else {
3283 Log() << "Help for MVA method [ " << GetName() << " ] :" << Endl;
3284 }
3285
3286 // print method-specific help message
3287 GetHelpMessage();
3288
3289 if (!o) {
3290 Log() << Endl;
3291 Log() << "<Suppress this message by specifying \"!H\" in the booking option>" << Endl;
3292 Log() << gTools().Color("bold")
3293 << "================================================================"
3294 << gTools().Color( "reset" )
3295 << Endl;
3296 Log() << Endl;
3297 }
3298 else {
3299 // indicate END
3300 Log() << "# End of Message___" << Endl;
3301 }
3302
3303 std::cout.rdbuf( cout_sbuf ); // restore the original stream buffer
3304 if (o) o->close();
3305}
3306
3307// ----------------------- r o o t f i n d i n g ----------------------------
3308
3309////////////////////////////////////////////////////////////////////////////////
3310/// returns efficiency as function of cut
3311
3313{
3314 Double_t retval=0;
3315
3316 // retrieve the class object
3318 retval = fSplRefS->Eval( theCut );
3319 }
3320 else retval = fEffS->GetBinContent( fEffS->FindBin( theCut ) );
3321
3322 // caution: here we take some "forbidden" action to hide a problem:
3323 // in some cases, in particular for likelihood, the binned efficiency distributions
3324 // do not equal 1, at xmin, and 0 at xmax; of course, in principle we have the
3325 // unbinned information available in the trees, but the unbinned minimization is
3326 // too slow, and we don't need to do a precision measurement here. Hence, we force
3327 // this property.
3328 Double_t eps = 1.0e-5;
3329 if (theCut-fXmin < eps) retval = (GetCutOrientation() == kPositive) ? 1.0 : 0.0;
3330 else if (fXmax-theCut < eps) retval = (GetCutOrientation() == kPositive) ? 0.0 : 1.0;
3331
3332 return retval;
3333}
3334
3335////////////////////////////////////////////////////////////////////////////////
3336/// returns the event collection (i.e. the dataset) TRANSFORMED using the
3337/// classifiers specific Variable Transformation (e.g. Decorr or Decorr:Gauss:Decorr)
3338
3340{
3341 // if there's no variable transformation for this classifier, just hand back the
3342 // event collection of the data set
3343 if (GetTransformationHandler().GetTransformationList().GetEntries() <= 0) {
3344 return (Data()->GetEventCollection(type));
3345 }
3346
3347 // otherwise, transform ALL the events and hand back the vector of the pointers to the
3348 // transformed events. If the pointer is already != 0, i.e. the whole thing has been
3349 // done before, I don't need to do it again, but just "hand over" the pointer to those events.
3350 Int_t idx = Data()->TreeIndex(type); //index indicating Training,Testing,... events/datasets
3351 if (fEventCollections.at(idx) == 0) {
3352 fEventCollections.at(idx) = &(Data()->GetEventCollection(type));
3353 fEventCollections.at(idx) = GetTransformationHandler().CalcTransformations(*(fEventCollections.at(idx)),kTRUE);
3354 }
3355 return *(fEventCollections.at(idx));
3356}
3357
3358////////////////////////////////////////////////////////////////////////////////
3359/// calculates the TMVA version string from the training version code on the fly
3360
3362{
3363 UInt_t a = GetTrainingTMVAVersionCode() & 0xff0000; a>>=16;
3364 UInt_t b = GetTrainingTMVAVersionCode() & 0x00ff00; b>>=8;
3365 UInt_t c = GetTrainingTMVAVersionCode() & 0x0000ff;
3366
3367 return TString(Form("%i.%i.%i",a,b,c));
3368}
3369
3370////////////////////////////////////////////////////////////////////////////////
3371/// calculates the ROOT version string from the training version code on the fly
3372
3374{
3375 UInt_t a = GetTrainingROOTVersionCode() & 0xff0000; a>>=16;
3376 UInt_t b = GetTrainingROOTVersionCode() & 0x00ff00; b>>=8;
3377 UInt_t c = GetTrainingROOTVersionCode() & 0x0000ff;
3378
3379 return TString(Form("%i.%02i/%02i",a,b,c));
3380}
3381
3382////////////////////////////////////////////////////////////////////////////////
3383
3385 ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
3386 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
3387
3388 if (mvaRes != NULL) {
3389 TH1D *mva_s = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_S"));
3390 TH1D *mva_b = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_B"));
3391 TH1D *mva_s_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_S"));
3392 TH1D *mva_b_tr = dynamic_cast<TH1D*> (mvaRes->GetHist("MVA_TRAIN_B"));
3393
3394 if ( !mva_s || !mva_b || !mva_s_tr || !mva_b_tr) return -1;
3395
3396 if (SorB == 's' || SorB == 'S')
3397 return mva_s->KolmogorovTest( mva_s_tr, opt.Data() );
3398 else
3399 return mva_b->KolmogorovTest( mva_b_tr, opt.Data() );
3400 }
3401 return -1;
3402}
const Bool_t Use_Splines_for_Eff_
Definition: MethodBase.cxx:132
const Int_t NBIN_HIST_HIGH
Definition: MethodBase.cxx:135
ROOT::R::TRInterface & r
Definition: Object.C:4
#define d(i)
Definition: RSha256.hxx:102
#define b(i)
Definition: RSha256.hxx:100
#define c(i)
Definition: RSha256.hxx:101
#define s1(x)
Definition: RSha256.hxx:91
#define ROOT_VERSION_CODE
Definition: RVersion.h:21
static RooMathCoreReg dummy
int Int_t
Definition: RtypesCore.h:43
char Char_t
Definition: RtypesCore.h:31
const Bool_t kFALSE
Definition: RtypesCore.h:90
bool Bool_t
Definition: RtypesCore.h:61
double Double_t
Definition: RtypesCore.h:57
long long Long64_t
Definition: RtypesCore.h:71
float Float_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:89
#define ClassImp(name)
Definition: Rtypes.h:361
char name[80]
Definition: TGX11.cxx:109
int type
Definition: TGX11.cxx:120
float xmin
Definition: THbookFile.cxx:93
float xmax
Definition: THbookFile.cxx:93
double sqrt(double)
TMatrixT< Double_t > TMatrixD
Definition: TMatrixDfwd.h:22
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:556
#define TMVA_VERSION_CODE
Definition: Version.h:47
Class to manage histogram axis.
Definition: TAxis.h:30
Double_t GetXmax() const
Definition: TAxis.h:134
Double_t GetXmin() const
Definition: TAxis.h:133
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
Definition: TCollection.h:182
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write all objects in this collection.
This class stores the date and time with a precision of one second in an unsigned 32 bit word (950130...
Definition: TDatime.h:37
const char * AsString() const
Return the date & time as a string (ctime() format).
Definition: TDatime.cxx:101
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
Describe directory structure in memory.
Definition: TDirectory.h:40
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
Definition: TDirectory.cxx:401
virtual TDirectory * mkdir(const char *name, const char *title="", Bool_t returnExistingDirectory=kFALSE)
Create a sub-directory "a" or a hierarchy of sub-directories "a/b/c/...".
virtual Bool_t cd(const char *path=nullptr)
Change current directory to "this" directory.
Definition: TDirectory.cxx:498
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3942
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:873
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
1-D histogram with a double per channel (see TH1 documentation)}
Definition: TH1.h:614
1-D histogram with a float per channel (see TH1 documentation)}
Definition: TH1.h:571
The TH1 histogram class.
Definition: TH1.h:56
virtual Double_t GetBinCenter(Int_t bin) const
Return bin center for 1D histogram.
Definition: TH1.cxx:8597
virtual Int_t GetQuantiles(Int_t nprobSum, Double_t *q, const Double_t *probSum=0)
Compute Quantiles for this histogram Quantile x_q of a probability distribution Function F is defined...
Definition: TH1.cxx:4459
virtual void AddBinContent(Int_t bin)
Increment bin content by 1.
Definition: TH1.cxx:1201
virtual Double_t GetMean(Int_t axis=1) const
For axis = 1,2 or 3 returns the mean value of the histogram along X,Y or Z axis.
Definition: TH1.cxx:7086
virtual void SetXTitle(const char *title)
Definition: TH1.h:409
static void AddDirectory(Bool_t add=kTRUE)
Sets the flag controlling the automatic add of histograms in memory.
Definition: TH1.cxx:1226
TAxis * GetXaxis()
Get the behaviour adopted by the object about the statoverflows. See EStatOverflows for more informat...
Definition: TH1.h:316
virtual Double_t GetMaximum(Double_t maxval=FLT_MAX) const
Return maximum value smaller than maxval of bins in the range, unless the value has been overridden b...
Definition: TH1.cxx:8006
virtual Int_t GetNbinsX() const
Definition: TH1.h:292
virtual Int_t Fill(Double_t x)
Increment bin with abscissa X by 1.
Definition: TH1.cxx:3275
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content see convention for numbering bins in TH1::GetBin In case the bin number is greater th...
Definition: TH1.cxx:8678
virtual Int_t GetMaximumBin() const
Return location of bin with maximum value in the range.
Definition: TH1.cxx:8036
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition: TH1.cxx:4907
virtual void SetYTitle(const char *title)
Definition: TH1.h:410
virtual void Scale(Double_t c1=1, Option_t *option="")
Multiply this histogram by a constant c1.
Definition: TH1.cxx:6246
virtual Int_t FindBin(Double_t x, Double_t y=0, Double_t z=0)
Return Global bin number corresponding to x,y,z.
Definition: TH1.cxx:3596
virtual Double_t KolmogorovTest(const TH1 *h2, Option_t *option="") const
Statistical test of compatibility in shape between this histogram and h2, using Kolmogorov test.
Definition: TH1.cxx:7684
virtual void Sumw2(Bool_t flag=kTRUE)
Create structure to store sum of squares of weights.
Definition: TH1.cxx:8476
static Bool_t AddDirectoryStatus()
Static function: cannot be inlined on Windows/NT.
Definition: TH1.cxx:706
2-D histogram with a float per channel (see TH1 documentation)}
Definition: TH2.h:251
Int_t Fill(Double_t)
Invalid Fill method.
Definition: TH2.cxx:294
A doubly linked list.
Definition: TList.h:44
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:356
Class that contains all the information of a class.
Definition: ClassInfo.h:49
UInt_t GetNumber() const
Definition: ClassInfo.h:65
TString fWeightFileExtension
Definition: Config.h:125
VariablePlotting & GetVariablePlotting()
Definition: Config.h:99
class TMVA::Config::VariablePlotting fVariablePlotting
IONames & GetIONames()
Definition: Config.h:100
MsgLogger * fLogger
Definition: Configurable.h:128
Class that contains all the data information.
Definition: DataSetInfo.h:60
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:381
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:391
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
static void SetIgnoreNegWeightsInTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:400
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
void Init(std::vector< TString > &graphTitles)
This function gets some title and it creates a TGraph for every title.
Definition: MethodBase.cxx:171
IPythonInteractive()
standard constructor
Definition: MethodBase.cxx:148
~IPythonInteractive()
standard destructor
Definition: MethodBase.cxx:156
void ClearGraphs()
This function sets the point number to 0 for all graphs.
Definition: MethodBase.cxx:195
void AddPoint(Double_t x, Double_t y1, Double_t y2)
This function is used only in 2 TGraph case, and it will add new data points to graphs.
Definition: MethodBase.cxx:209
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
TDirectory * MethodBaseDir() const
returns the ROOT directory where all instances of the corresponding MVA method are stored
virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt="X")
MethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
standard constructor
Definition: MethodBase.cxx:239
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
void ReadClassesFromXML(void *clsnode)
read number of classes from XML
void SetWeightFileDir(TString fileDir)
set directory of weight file
void WriteStateToXML(void *parent) const
general method used in writing the header of the weight files where the used variables,...
void DeclareBaseOptions()
define the options (their key words) that can be set in the option string here the options valid for ...
Definition: MethodBase.cxx:511
virtual void TestRegression(Double_t &bias, Double_t &biasT, Double_t &dev, Double_t &devT, Double_t &rms, Double_t &rmsT, Double_t &mInf, Double_t &mInfT, Double_t &corr, Types::ETreeType type)
calculate <sum-of-deviation-squared> of regression output versus "true" value from test sample
Definition: MethodBase.cxx:979
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:598
virtual Double_t GetSignificance() const
compute significance of mean difference
virtual Double_t GetProba(const Event *ev)
const char * GetName() const
Definition: MethodBase.h:333
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
Construct a confusion matrix for a multiclass classifier.
void PrintHelpMessage() const
prints out method-specific help method
virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype)
writes all MVA evaluation histograms to file
virtual void TestMulticlass()
test multiclass classification
const std::vector< TMVA::Event * > & GetEventCollection(Types::ETreeType type)
returns the event collection (i.e.
void SetupMethod()
setup of methods
Definition: MethodBase.cxx:408
TDirectory * BaseDir() const
returns the ROOT directory where info/histograms etc of the corresponding MVA method instance are sto...
virtual std::vector< Float_t > GetMulticlassEfficiency(std::vector< std::vector< Float_t > > &purity)
void AddInfoItem(void *gi, const TString &name, const TString &value) const
xml writing
virtual void AddClassifierOutputProb(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:938
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
TString GetTrainingTMVAVersionString() const
calculates the TMVA version string from the training version code on the fly
void Statistics(Types::ETreeType treeType, const TString &theVarName, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &)
calculates rms,mean, xmin, xmax of the event variable this can be either done for the variables as th...
Bool_t GetLine(std::istream &fin, char *buf)
reads one line from the input stream checks for certain keywords and interprets the line if keywords ...
void ProcessSetup()
process all options the "CheckForUnusedOptions" is done in an independent call, since it may be overr...
Definition: MethodBase.cxx:425
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
Definition: MethodBase.cxx:896
virtual Bool_t IsSignalLike()
uses a pre-set cut on the MVA output (SetSignalReferenceCut and SetSignalReferenceCutOrientation) for...
Definition: MethodBase.cxx:856
virtual ~MethodBase()
destructor
Definition: MethodBase.cxx:366
virtual Double_t GetMaximumSignificance(Double_t SignalEvents, Double_t BackgroundEvents, Double_t &optimal_significance_value) const
plot significance, , curve for given number of signal and background events; returns cut for maximum ...
virtual Double_t GetTrainingEfficiency(const TString &)
void SetWeightFileName(TString)
set the weight file name (depreciated)
virtual void MakeClass(const TString &classFileName=TString("")) const
create reader class for method (classification only at present)
TString GetWeightFileName() const
retrieve weight file name
virtual void TestClassification()
initialization
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
virtual void WriteMonitoringHistosToFile() const
write special monitoring histograms to file dummy implementation here --------------—
virtual void AddRegressionOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:746
void InitBase()
default initialization called by all constructors
Definition: MethodBase.cxx:443
virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t &stddev, Double_t &stddev90Percent) const
Definition: MethodBase.cxx:726
void ReadStateFromXMLString(const char *xmlstr)
for reading from memory
void CreateMVAPdfs()
Create PDFs of the MVA output variables.
TString GetTrainingROOTVersionString() const
calculates the ROOT version string from the training version code on the fly
virtual Double_t GetValueForRoot(Double_t)
returns efficiency as function of cut
void ReadStateFromFile()
Function to write options and weights to file.
void WriteVarsToStream(std::ostream &tf, const TString &prefix="") const
write the list of variables (name, min, max) for a given data transformation method to the stream
void ReadVarsFromStream(std::istream &istr)
Read the variables (name, min, max) for a given data transformation method from the stream.
void ReadSpectatorsFromXML(void *specnode)
read spectator info from XML
virtual Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)=0
void SetTestvarName(const TString &v="")
Definition: MethodBase.h:340
void ReadVariablesFromXML(void *varnode)
read variable info from XML
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:625
virtual std::vector< Float_t > GetMulticlassTrainingEfficiency(std::vector< std::vector< Float_t > > &purity)
void WriteStateToStream(std::ostream &tf) const
general method used in writing the header of the weight files where the used variables,...
virtual Double_t GetRarity(Double_t mvaVal, Types::ESBType reftype=Types::kBackground) const
compute rarity:
virtual void SetTuneParameters(std::map< TString, Double_t > tuneParameters)
set the tuning parameters according to the argument This is just a dummy .
Definition: MethodBase.cxx:646
void ReadStateFromStream(std::istream &tf)
read the header from the weight files of the different MVA methods
void AddVarsXMLTo(void *parent) const
write variable info to XML
void AddTargetsXMLTo(void *parent) const
write target info to XML
void ReadTargetsFromXML(void *tarnode)
read target info from XML
void ProcessBaseOptions()
the option string is decoded, for available options see "DeclareOptions"
Definition: MethodBase.cxx:542
void ReadStateFromXML(void *parent)
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:838
void WriteStateToFile() const
write options and weights to file note that each one text file for the main configuration information...
void AddClassesXMLTo(void *parent) const
write class info to XML
virtual void AddClassifierOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:870
void AddSpectatorsXMLTo(void *parent) const
write spectator info to XML
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const
calculate the area (integral) under the ROC curve as a overall quality measure of the classification
virtual void AddMulticlassOutput(Types::ETreeType type)
prepare tree branch with the method's discriminating variable
Definition: MethodBase.cxx:796
virtual void CheckSetup()
check may be overridden by derived class (sometimes, eg, fitters are used which can only be implement...
Definition: MethodBase.cxx:435
void SetSource(const std::string &source)
Definition: MsgLogger.h:70
PDF wrapper for histograms; uses user-defined spline interpolation.
Definition: PDF.h:63
Double_t GetXmin() const
Definition: PDF.h:104
Double_t GetXmax() const
Definition: PDF.h:105
Double_t GetVal(Double_t x) const
returns value PDF(x)
Definition: PDF.cxx:700
@ kSpline3
Definition: PDF.h:70
@ kSpline2
Definition: PDF.h:70
Double_t GetIntegral(Double_t xmin, Double_t xmax)
computes PDF integral within given ranges
Definition: PDF.cxx:653
Class that is the base-class for a vector of result.
std::vector< Bool_t > * GetValueVectorTypes()
void SetValue(Float_t value, Int_t ievt)
set MVA response
std::vector< Float_t > * GetValueVector()
Class which takes the results of a multiclass classification.
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
Float_t GetAchievablePur(UInt_t cls)
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
Float_t GetAchievableEff(UInt_t cls)
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier a multiclass setting.
Class that is the base-class for a vector of result.
Class that is the base-class for a vector of result.
Definition: Results.h:57
Bool_t DoesExist(const TString &alias) const
Returns true if there is an object stored in the result for a given alias, false otherwise.
Definition: Results.cxx:127
TH1 * GetHist(const TString &alias) const
Definition: Results.cxx:136
TList * GetStorage() const
Definition: Results.h:73
void Store(TObject *obj, const char *alias=0)
Definition: Results.cxx:86
Root finding using Brents algorithm (translated from CERNLIB function RZERO)
Definition: RootFinder.h:48
Double_t Root(Double_t refValue)
Root finding using Brents algorithm; taken from CERNLIB function RZERO.
Definition: RootFinder.cxx:72
Linear interpolation of TGraph.
Definition: TSpline1.h:43
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
Double_t ElapsedSeconds(void)
computes elapsed time in seconds
Definition: Timer.cxx:138
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:147
void DrawProgressBar(Int_t, const TString &comment="")
draws progress bar in color or B&W caution:
Definition: Timer.cxx:203
void ComputeStat(const std::vector< TMVA::Event * > &, std::vector< Float_t > *, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Double_t &, Int_t signalClass, Bool_t norm=kFALSE)
sanity check
Definition: Tools.cxx:213
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition: Tools.cxx:412
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1173
Double_t GetSeparation(TH1 *S, TH1 *B) const
compute "separation" defined as
Definition: Tools.cxx:132
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1135
Double_t GetMutualInformation(const TH2F &)
Mutual Information method for non-linear correlations estimates in 2D histogram Author: Moritz Backes...
Definition: Tools.cxx:600
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:839
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1161
TXMLEngine & xmlengine()
Definition: Tools.h:268
Bool_t CheckSplines(const TH1 *, const TSpline *)
check quality of splining by comparing splines and histograms in each bin
Definition: Tools.cxx:490
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:394
Singleton class for Global types used by TMVA.
Definition: Types.h:73
@ kSignal
Definition: Types.h:136
@ kBackground
Definition: Types.h:137
@ kLikelihood
Definition: Types.h:81
@ kHMatrix
Definition: Types.h:83
EAnalysisType
Definition: Types.h:127
@ kMulticlass
Definition: Types.h:130
@ kNoAnalysisType
Definition: Types.h:131
@ kClassification
Definition: Types.h:128
@ kMaxAnalysisType
Definition: Types.h:132
@ kRegression
Definition: Types.h:129
@ kTraining
Definition: Types.h:144
@ kTesting
Definition: Types.h:145
Linear interpolation class.
Gaussian Transformation of input variables.
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
void ReadFromXML(void *varnode)
read VariableInfo from stream
const TString & GetExpression() const
Definition: VariableInfo.h:57
char GetVarType() const
Definition: VariableInfo.h:61
void ReadFromStream(std::istream &istr)
read VariableInfo from stream
void AddToXML(void *varnode)
write class to XML
void SetExternalLink(void *p)
Definition: VariableInfo.h:75
void * GetExternalLink() const
Definition: VariableInfo.h:83
void BuildTransformationFromVarInfo(const std::vector< TMVA::VariableInfo > &var)
this method is only used when building a normalization transformation from old text files in this cas...
Linear interpolation class.
Linear interpolation class.
virtual void ReadTransformationFromStream(std::istream &istr, const TString &classname="")=0
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:36
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
Collectable string class.
Definition: TObjString.h:28
virtual Int_t Write(const char *name=0, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Definition: TObject.cxx:796
Basic string class.
Definition: TString.h:131
Ssiz_t Length() const
Definition: TString.h:405
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1125
Int_t Atoi() const
Return integer value of string.
Definition: TString.cxx:1921
Bool_t EndsWith(const char *pat, ECaseCompare cmp=kExact) const
Return true if string ends with the specified string.
Definition: TString.cxx:2177
TSubString Strip(EStripType s=kTrailing, char c=' ') const
Return a substring of self stripped at beginning and/or end.
Definition: TString.cxx:1106
const char * Data() const
Definition: TString.h:364
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:687
@ kLeading
Definition: TString.h:262
Ssiz_t Last(char c) const
Find last occurrence of a character c.
Definition: TString.cxx:892
Bool_t IsNull() const
Definition: TString.h:402
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2311
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:634
virtual const char * GetBuildNode() const
Return the build node name.
Definition: TSystem.cxx:3875
virtual int mkdir(const char *name, Bool_t recursive=kFALSE)
Make a file system directory.
Definition: TSystem.cxx:902
virtual const char * WorkingDirectory()
Return working directory.
Definition: TSystem.cxx:867
virtual UserGroup_t * GetUserInfo(Int_t uid)
Returns all user info in the UserGroup_t structure.
Definition: TSystem.cxx:1594
void SaveDoc(XMLDocPointer_t xmldoc, const char *filename, Int_t layout=1)
store document content to file if layout<=0, no any spaces or newlines will be placed between xmlnode...
void FreeDoc(XMLDocPointer_t xmldoc)
frees allocated document data and deletes document itself
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
XMLDocPointer_t NewDoc(const char *version="1.0")
creates new xml document with provided version
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLDocPointer_t ParseString(const char *xmlstring)
parses content of string and tries to produce xml structures
void DocSetRootElement(XMLDocPointer_t xmldoc, XMLNodePointer_t xmlnode)
set main (root) node for document
TLine * line
Double_t x[n]
Definition: legend1.C:17
TH1F * h1
Definition: legend1.C:5
bool BeginsWith(const std::string &theString, const std::string &theSubstring)
void Init(TClassEdit::TInterpreterLookupHelper *helper)
Definition: TClassEdit.cxx:154
static constexpr double s
static constexpr double m2
void GetMethodName(TString &name, TKey *mkey)
Definition: tmvaglob.cxx:335
Config & gConfig()
Tools & gTools()
void CreateVariableTransforms(const TString &trafoDefinition, TMVA::DataSetInfo &dataInfo, TMVA::TransformationHandler &transformationHandler, TMVA::MsgLogger &log)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:212
Double_t Log(Double_t x)
Definition: TMath.h:750
Double_t Sqrt(Double_t x)
Definition: TMath.h:681
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:180
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
TString fUser
Definition: TSystem.h:140
auto * a
Definition: textangle.C:12
REAL epsilon
Definition: triangle.c:617