Logo ROOT   6.14/05
Reference Guide
CrossValidation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou
3 // Modified: Kim Albertsson 2017
4 
5 /*************************************************************************
6  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
7  * All rights reserved. *
8  * *
9  * For the licensing terms see $ROOTSYS/LICENSE. *
10  * For the list of contributors see $ROOTSYS/README/CREDITS. *
11  *************************************************************************/
12 
13 #include "TMVA/CrossValidation.h"
14 
15 #include "TMVA/ClassifierFactory.h"
16 #include "TMVA/Config.h"
17 #include "TMVA/CvSplit.h"
18 #include "TMVA/DataSet.h"
19 #include "TMVA/Event.h"
20 #include "TMVA/MethodBase.h"
22 #include "TMVA/MsgLogger.h"
24 #include "TMVA/ResultsMulticlass.h"
25 #include "TMVA/ROCCurve.h"
26 #include "TMVA/tmvaglob.h"
27 #include "TMVA/Types.h"
28 
29 #include "TSystem.h"
30 #include "TAxis.h"
31 #include "TCanvas.h"
32 #include "TGraph.h"
33 #include "TMath.h"
34 
35 #include <iostream>
36 #include <memory>
37 
38 //_______________________________________________________________________
// CrossValidationResult(UInt_t numFolds): the signature line (listing line 39)
// is missing from this rendering.
// Creates an empty TMultiGraph for the per-fold ROC curves and sizes every
// per-fold metric vector to `numFolds` entries so Fill() can index them by fold.
// fROCs is a std::map and is not pre-sized; entries appear as folds are filled.
40 :fROCCurves(new TMultiGraph())
41 {
42  fSigs.resize(numFolds);
43  fSeps.resize(numFolds);
44  fEff01s.resize(numFolds);
45  fEff10s.resize(numFolds);
46  fEff30s.resize(numFolds);
47  fEffAreas.resize(numFolds);
48  fTrainEff01s.resize(numFolds);
49  fTrainEff10s.resize(numFolds);
50  fTrainEff30s.resize(numFolds);
51 }
52 
53 //_______________________________________________________________________
// Presumably the copy constructor: the signature line (listing line 54) is
// missing from this rendering. It copies the per-fold metrics from `obj`.
// NOTE(review): fROCCurves is a std::shared_ptr<TMultiGraph>, so the copy shares
// the same multigraph as `obj`, and Fill() on either object adds curves to both.
// Clone the multigraph if independent copies are intended.
// NOTE(review): listing lines 65-67 are not visible here. Presumably they copy
// fTrainEff01s/fTrainEff10s/fTrainEff30s (confirm); otherwise those vectors stay
// default-constructed (empty) in the copy.
55 {
56  fROCs=obj.fROCs;
57  fROCCurves = obj.fROCCurves;
58 
59  fSigs = obj.fSigs;
60  fSeps = obj.fSeps;
61  fEff01s = obj.fEff01s;
62  fEff10s = obj.fEff10s;
63  fEff30s = obj.fEff30s;
64  fEffAreas = obj.fEffAreas;
68 }
69 
70 //_______________________________________________________________________
// void Fill(CrossValidationFoldResult const &fr): the signature line (listing
// line 71) is missing from this rendering.
// Stores the metrics of one fold, indexed by fr.fFold, and appends a clone of
// the fold's ROC graph to the combined multigraph.
// NOTE(review): the vectors are sized to numFolds in the constructor and are
// indexed with operator[] (no bounds check), so fr.fFold >= numFolds is undefined
// behaviour. Filling the same fold twice overwrites its metrics but adds a second
// ROC curve.
72 {
73  UInt_t iFold = fr.fFold;
74 
75  fROCs[iFold] = fr.fROCIntegral;
76  fROCCurves->Add(static_cast<TGraph *>(fr.fROC.Clone()));
77 
78  fSigs[iFold] = fr.fSig;
79  fSeps[iFold] = fr.fSep;
80  fEff01s[iFold] = fr.fEff01;
81  fEff10s[iFold] = fr.fEff10;
82  fEff30s[iFold] = fr.fEff30;
83  fEffAreas[iFold] = fr.fEffArea;
84  fTrainEff01s[iFold] = fr.fTrainEff01;
85  fTrainEff10s[iFold] = fr.fTrainEff10;
86  fTrainEff30s[iFold] = fr.fTrainEff30;
87 }
88 
89 //_______________________________________________________________________
// TMultiGraph *GetROCCurves(Bool_t fLegend=kTRUE): the signature line (listing
// line 90) is missing from this rendering.
// Returns a non-owning pointer to the multigraph that holds one ROC curve per
// filled fold; the shared_ptr fROCCurves keeps ownership.
// NOTE(review): the declared fLegend parameter is not used by this body.
91 {
92  return fROCCurves.get();
93 }
94 
95 //_______________________________________________________________________
// Mean of the per-fold ROC integrals stored in fROCs. The signature line (listing
// line 96) is missing from this rendering; it is called from a const member, so
// it is presumably const as well.
// NOTE(review): if no fold has been filled, fROCs is empty and this evaluates
// 0/0 (NaN).
97 {
98  Float_t avg=0;
99  for(auto &roc:fROCs) avg+=roc.second;
100  return avg/fROCs.size();
101 }
102 
103 //_______________________________________________________________________
// Float_t GetROCStandardDeviation() const: the signature line (listing line 104)
// is missing from this rendering.
// Sample standard deviation of the per-fold ROC integrals around GetROCAverage().
// NOTE(review): with exactly one filled fold the denominator is 0 and the result
// is 0/0 (NaN); guard fROCs.size() < 2 if a single fold is possible.
// TMath::Power(x, 2) could simply be x*x.
105 {
106  // NOTE: Sample standard deviation with Bessel's correction (N-1 denominator).
107  Float_t std=0;
108  Float_t avg=GetROCAverage();
109  for(auto &roc:fROCs) std+=TMath::Power(roc.second-avg, 2);
110  return TMath::Sqrt(std/float(fROCs.size()-1.0));
111 }
112 
113 //_______________________________________________________________________
// Presumably the Print() member. Its signature line (listing line 114) and
// listing lines 116-117 and 128 are missing from this rendering. They presumably
// toggle the output state; the page references Config::SetSilent and
// MsgLogger::EnableOutput, but confirm.
// Logs a header, the ROC integral of every fold, and the mean and sample standard
// deviation of those integrals.
// NOTE(review): the per-fold line ends with std::endl instead of Endl. MsgLogger
// derives from std::ostringstream, so std::endl only appends '\n' to the buffer.
// The per-fold lines are therefore presumably emitted together with the "-----"
// separator at the next Endl rather than as separate messages; use Endl if that
// is unintended.
// NOTE(review): item.first is UInt_t (fROCs is std::map<UInt_t, Float_t>) but is
// formatted with %i; %u is the matching specifier.
115 {
118 
119  MsgLogger fLogger("CrossValidation");
120  fLogger << kHEADER << " ==== Results ====" << Endl;
121  for(auto &item:fROCs)
122  fLogger << kINFO << Form("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
123 
124  fLogger << kINFO << "------------------------" << Endl;
125  fLogger << kINFO << Form("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
126  fLogger << kINFO << Form("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
127 
129 }
130 
131 //_______________________________________________________________________
// TCanvas *Draw(const TString name="CrossValidation") const: the signature line
// (listing line 132) is missing from this rendering.
// Draws all per-fold ROC curves with option "AL" (axes + lines) on a new canvas
// named `name` and labels the axes. It then builds a legend whose upper-right
// corner is scaled up by 1% per fold so that more entries fit.
// The canvas is allocated with new and returned to the caller. Presumably ROOT's
// global canvas list also tracks it; confirm before deleting it manually.
133 {
134  TCanvas *c=new TCanvas(name.Data());
135  fROCCurves->Draw("AL");
136  fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
137  fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
138  Float_t adjust=1+fROCs.size()*0.01;
139  c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
140  c->SetTitle("Cross Validation ROC Curves");
141  c->Draw();
142  return c;
143 }
144 
145 /**
146 * \class TMVA::CrossValidation
147 * \ingroup TMVA
148 * \brief
149 
150 Class to perform cross validation, splitting the dataloader into folds.<br>
151 Typical usage, where `ce` points to a TMVA::CrossValidation instance:
152 
153 
154 ~~~{.cpp}
155 ce->BookMethod(dataloader, options);
156 ce->Evaluate();
157 ~~~
158 
159 Cross-validation dynamically generates, for each fold, a new training and test set
160 from the `K` folds. These `K` folds are generated by splitting the input training
161 set. The input test set is currently ignored.
162 
163 This means that when you specify your DataSet you should include all events
164 in your training set. One way of doing this would be the following:
165 
166 ~~~{.cpp}
167 dataloader->AddTree( signalTree, "cls1" );
168 dataloader->AddTree( background, "cls2" );
169 dataloader->PrepareTrainingAndTestTree( "", "", "nTest_cls1=1:nTest_cls2=1" );
170 ~~~
171 
172 ## Split Expression
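173 The `SplitExpr` option gives the expression used to assign events to folds (see TMVA::CvSplitKFolds in TMVA/CvSplit.h). If it is empty, events are assigned to folds at random.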
174 
175 */
176 
177 ////////////////////////////////////////////////////////////////////////////////
178 /// Main constructor. The first signature line (listing line 180) is missing from
/// this rendering. The delegating constructor and the initialisers suggest
/// (TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options).
/// Sets defaults (e.g. 2 folds, 1 worker process, AnalysisType "auto", ROC enabled),
/// then declares and parses the options.
/// NOTE(review): fVerboseLevel(kINFO) is overwritten by TString("Info") in
/// InitOptions() before it is ever read. Listing line 204 is also missing here.
/// NOTE(review): the Envelope base receives nullptr instead of outputFile; the file
/// is kept only in fOutputFile.
179 
181  TString options)
182  : TMVA::Envelope(jobName, dataloader, nullptr, options),
183  fAnalysisType(Types::kMaxAnalysisType),
184  fAnalysisTypeStr("auto"),
185  fCorrelations(kFALSE),
186  fCvFactoryOptions(""),
187  fDrawProgressBar(kFALSE),
188  fFoldFileOutput(kFALSE),
189  fFoldStatus(kFALSE),
190  fJobName(jobName),
191  fNumFolds(2),
192  fNumWorkerProcs(1),
193  fOutputFactoryOptions(""),
194  fOutputFile(outputFile),
195  fSilent(kFALSE),
196  fSplitExprString(""),
197  fROC(kTRUE),
198  fTransformations(""),
199  fVerbose(kFALSE),
200  fVerboseLevel(kINFO)
201 {
202  InitOptions();
203  ParseOptions();
205 }
206 
207 ////////////////////////////////////////////////////////////////////////////////
208 /// Constructor without an output file. It delegates to the main constructor with
/// outputFile = nullptr, so the output Factory is created without a TFile and
/// FoldFileOutput cannot be used. The signature (listing line 210) is missing from
/// this rendering; it is presumably
/// CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options).
209 
211  : CrossValidation(jobName, dataloader, nullptr, options)
212 {
213 }
214 
215 ////////////////////////////////////////////////////////////////////////////////
216 ///
217 
219 
220 ////////////////////////////////////////////////////////////////////////////////
221 /// Declares every option understood by CrossValidation, with its allowed values.
/// The signature (listing line 223) is missing from this rendering; this is
/// presumably InitOptions(), which the constructor calls.
///
/// NOTE(review): the local `analysisType` (listing line 244) is never used.
/// NOTE(review): the NumWorkerProcs help text says 1 means no parallelisation, but
/// Evaluate() replaces 1 with TMVA::gConfig().GetNumWorkers().
/// NOTE(review): the FoldFileOutput help text promises the combined file name with a
/// _foldX suffix, while ProcessFold() writes "<methodTitle>_fold<N>.root" next to the
/// output file. The help text also misspells "specified".
222 
224 {
225  // Forwarding of Factory options
226  DeclareOptionRef(fSilent, "Silent",
227  "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
228  "class object (default: False)");
229  DeclareOptionRef(fVerbose, "V", "Verbose flag");
230  DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
231  AddPreDefVal(TString("Debug"));
232  AddPreDefVal(TString("Verbose"));
233  AddPreDefVal(TString("Info"));
234 
235  DeclareOptionRef(fTransformations, "Transformations",
236  "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
237  "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
238  "transformations");
239 
240  DeclareOptionRef(fDrawProgressBar, "DrawProgressBar", "Boolean to show draw progress bar");
241  DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation in output");
242  DeclareOptionRef(fROC, "ROC", "Boolean to show ROC in output");
243 
244  TString analysisType("Auto");
245  DeclareOptionRef(fAnalysisTypeStr, "AnalysisType",
246  "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
247  AddPreDefVal(TString("Classification"));
248  AddPreDefVal(TString("Regression"));
249  AddPreDefVal(TString("Multiclass"));
250  AddPreDefVal(TString("Auto"));
251 
252  // Options specific to CE
253  DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
254  DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
255  DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
256  "Determines how many processes to use for evaluation. 1 means no"
257  " parallelisation. 2 means use 2 processes. 0 means figure out the"
258  " number automatically based on the number of cpus available. Default"
259  " 1.");
260 
261  DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
262  "If given a TMVA output file will be generated for each fold. Filename will be the same as "
263  "specifed for the combined output with a _foldX suffix. (default: false)");
264 
265  DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
266  "Combines output from contained methods. If None, no combination is performed. (default None)");
267  AddPreDefVal(TString("None"));
268  AddPreDefVal(TString("Avg"));
269 }
270 
271 ////////////////////////////////////////////////////////////////////////////////
272 /// Parses the declared options and sets up the factories and the fold splitter.
/// The signature (listing line 274) is missing from this rendering; this is
/// presumably ParseOptions(), which the constructor calls.
///
/// Builds two Factory option strings:
///  - fCvFactoryOptions, for fFoldFactory, which trains the individual folds. It is
///    always Silent, with correlations, ROC, colour and progress bar disabled.
///  - fOutputFactoryOptions, for fFactory, which holds the combined
///    MethodCrossValidation. It follows the user's V/Correlations/ROC/Silent/
///    DrawProgressBar settings.
/// fFactory writes to fOutputFile when one was given. Finally, the K-fold splitter
/// is created from NumFolds and SplitExpr.
///
/// NOTE(review): the bodies of the AnalysisType if/else-if chain (odd listing
/// lines 279-287) are missing from this rendering. Presumably they lower-case
/// fAnalysisTypeStr and set fAnalysisType; confirm.
273 
275 {
276  this->Envelope::ParseOptions();
277 
278  // Factory options
280  if (fAnalysisTypeStr == "classification")
282  else if (fAnalysisTypeStr == "regression")
284  else if (fAnalysisTypeStr == "multiclass")
286  else if (fAnalysisTypeStr == "auto")
288 
289  if (fVerbose) {
290  fCvFactoryOptions += "V:";
291  fOutputFactoryOptions += "V:";
292  } else {
293  fCvFactoryOptions += "!V:";
294  fOutputFactoryOptions += "!V:";
295  }
296 
297  fCvFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
298  fOutputFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
299 
300  fCvFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
301  fOutputFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
302 
303  if (not fDrawProgressBar) {
304  fOutputFactoryOptions += "!DrawProgressBar:";
305  }
306 
307  if (fTransformations != "") {
308  fCvFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
309  fOutputFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
310  }
311 
312  if (fCorrelations) {
313  // fCvFactoryOptions += "Correlations:";
314  fOutputFactoryOptions += "Correlations:";
315  } else {
316  // fCvFactoryOptions += "!Correlations:";
317  fOutputFactoryOptions += "!Correlations:";
318  }
319 
320  if (fROC) {
321  // fCvFactoryOptions += "ROC:";
322  fOutputFactoryOptions += "ROC:";
323  } else {
324  // fCvFactoryOptions += "!ROC:";
325  fOutputFactoryOptions += "!ROC:";
326  }
327 
328  if (fSilent) {
329  // fCvFactoryOptions += Form("Silent:");
330  fOutputFactoryOptions += Form("Silent:");
331  }
332 
333  fCvFactoryOptions += "!Correlations:!ROC:!Color:!DrawProgressBar:Silent";
334 
335  // CE specific options
336  if (fFoldFileOutput and fOutputFile == nullptr) {
337  Log() << kFATAL << "No output file given, cannot generate per fold output." << Endl;
338  }
339 
340  // Initialisations
341 
342  fFoldFactory = std::unique_ptr<TMVA::Factory>(new TMVA::Factory(fJobName, fCvFactoryOptions));
343 
344  // The output factory (fFactory) should always have !ModelPersistence set since we use a custom code path for this.
345  // In this case we create a special method (MethodCrossValidation) that can only be used by
346  // CrossValidation and the Reader.
  // NOTE(review): no visible line adds "!ModelPersistence" to fOutputFactoryOptions;
  // confirm that it is set elsewhere, or add it.
347  if (fOutputFile == nullptr) {
348  fFactory = std::unique_ptr<TMVA::Factory>(new TMVA::Factory(fJobName, fOutputFactoryOptions));
349  } else {
350  fFactory = std::unique_ptr<TMVA::Factory>(new TMVA::Factory(fJobName, fOutputFile, fOutputFactoryOptions));
351  }
352 
353  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
354 }
355 
356 //_______________________________________________________________________
// void SetNumFolds(UInt_t i): the signature line (listing line 357) is missing
// from this rendering.
// Changes the number of folds. If `i` differs from the current value, a new
// CvSplitKFolds is created and the data loader is split again immediately.
// fFoldStatus is then set so that Evaluate() does not split a second time.
// NOTE(review): `i` is not validated here (e.g. 0 or 1 folds).
358 {
359  if (i != fNumFolds) {
360  fNumFolds = i;
361  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
362  fDataLoader->MakeKFoldDataSet(*fSplit.get());
363  fFoldStatus = kTRUE;
364  }
365 }
366 
367 ////////////////////////////////////////////////////////////////////////////////
368 /// Changes the expression used to assign events to folds.
/// The signature (listing line 370) is missing from this rendering; it is
/// presumably void SetSplitExpr(TString splitExpr).
/// If `splitExpr` differs from the current expression, a new CvSplitKFolds is
/// created and the data loader is split again immediately. fFoldStatus is then set
/// so that Evaluate() does not split a second time. Per the comment in Evaluate(),
/// an empty expression means random fold assignment.
369 
371 {
372  if (splitExpr != fSplitExprString) {
373  fSplitExprString = splitExpr;
374  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
375  fDataLoader->MakeKFoldDataSet(*fSplit.get());
376  fFoldStatus = kTRUE;
377  }
378 }
379 
380 ////////////////////////////////////////////////////////////////////////////////
381 /// Trains and evaluates one booked method on a single fold:
382 /// - Prepares train and test data sets
383 /// - Trains method
384 /// - Evaluates on test set
385 /// - Returns the evaluation as a CrossValidationFoldResult
386 ///
387 /// @param iFold fold to evaluate
/// @param iMethod index of the booked method in fMethods
/// @return the per-fold metrics (ROC integral and curve, significance, separation,
///         efficiencies)
///
/// The method is booked on fFoldFactory under the title "<methodTitle>_fold<iFold+1>".
/// Afterwards, all results and booked methods are removed from the fold factory so
/// that the next fold starts clean.
///
/// NOTE(review): listing lines 390, 417, 419, 427 and 439 are missing from this
/// rendering.
///  - Line 390 is the signature, presumably
///    CrossValidationFoldResult ProcessFold(UInt_t iFold, UInt_t iMethod).
///  - Lines 427 and 439 presumably open the two conditional blocks that are closed
///    at listing lines 450-451.
388 ///
389 
391 {
392  TString methodName = fMethods[iMethod].GetValue<TString>("MethodName");
393  TString methodTitle = fMethods[iMethod].GetValue<TString>("MethodTitle");
394  TString methodOptions = fMethods[iMethod].GetValue<TString>("MethodOptions");
395 
396  Log() << kDEBUG << "Fold (" << methodTitle << "): " << iFold << Endl;
397 
398  // Get specific fold of dataset and setup method
399  TString foldTitle = methodTitle;
400  foldTitle += "_fold";
401  foldTitle += iFold + 1;
402 
403  // Only used if fFoldFileOutput == true
404  TFile *foldOutputFile = nullptr;
405 
406  if (fFoldFileOutput and fOutputFile != nullptr) {
407  TString path = std::string("") + gSystem->DirName(fOutputFile->GetName()) + "/" + foldTitle + ".root";
  // NOTE(review): leftover debug print to std::cout; consider Log() << kDEBUG.
408  std::cout << "PATH: " << path << std::endl;
409  foldOutputFile = TFile::Open(path, "RECREATE");
  // NOTE(review): TFile::Open can return nullptr. The result is not checked before
  // it is handed to the Factory below and closed at the end of this function.
410  fFoldFactory = std::unique_ptr<TMVA::Factory>(new TMVA::Factory(fJobName, foldOutputFile, fCvFactoryOptions));
411  }
412 
413  fDataLoader->PrepareFoldDataSet(*fSplit.get(), iFold, TMVA::Types::kTraining);
414  MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodName, foldTitle, methodOptions);
415 
416  // Train method (train method and eval train set)
418  smethod->TrainMethod();
420 
421  fFoldFactory->TestAllMethods();
422  fFoldFactory->EvaluateAllMethods();
423 
424  TMVA::CrossValidationFoldResult result(iFold);
425 
426  // Results for aggregation (ROC integral, efficiencies etc.)
428  result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
429 
430  TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
  // NOTE(review): gr is dereferenced without a null check (presumably GetROCCurve
  // can return nullptr when no ROC is available; confirm).
431  gr->SetLineColor(iFold + 1);
432  gr->SetLineWidth(2);
433  gr->SetTitle(foldTitle.Data());
434  result.fROC = *gr;
435 
436  result.fSig = smethod->GetSignificance();
437  result.fSep = smethod->GetSeparation();
438 
440  Double_t err;
441  result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
442  result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
443  result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
444  result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
445  result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
446  result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
447  result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
448  } else if (fAnalysisType == Types::kMulticlass) {
449  // Nothing here for now
450  }
451  }
452 
453  // Per-fold file output
  // NOTE(review): only fFoldFileOutput is tested here, so a failed TFile::Open
  // (nullptr) would crash. The TFile object is closed but never deleted, and
  // fFoldFactory, which was given this file, stays alive until the next fold
  // replaces it.
454  if (fFoldFileOutput) {
455  foldOutputFile->Close();
456  }
457 
458  // Clean-up for this fold
459  {
460  smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
461  smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
462  }
463 
464  fFoldFactory->DeleteAllMethods();
465  fFoldFactory->fMethodsMap.clear();
466 
467  return result;
468 }
469 
470 ////////////////////////////////////////////////////////////////////////////////
471 /// Does training, test set evaluation and performance evaluation using
472 /// cross-validation.
473 ///
/// Steps visible below:
///  1. Split the data loader into K folds (only once; see fFoldStatus).
///  2. For every booked method, train and evaluate each fold via ProcessFold(),
///     either serially or in a ROOT::TProcessExecutor. Aggregate the fold results
///     and append them to fResults.
///  3. Book a MethodCrossValidation wrapper per method on the output Factory and
///     hand it the splitter's event-to-fold mapping.
///  4. Recombine the folds, train the wrappers, then test and evaluate them
///     through the output Factory.
///
/// NOTE(review): listing lines 475, 485, 494, 558 and 560 are missing from this
/// rendering. Line 475 is the signature, presumably void Evaluate(). `result` is
/// used below but not declared in the visible lines; presumably listing line 485
/// declares it (e.g. a CrossValidationResult sized by fNumFolds; confirm).
474 
476 {
477  // Generate K folds on given dataset
478  if (!fFoldStatus) {
479  fDataLoader->MakeKFoldDataSet(*fSplit.get());
480  fFoldStatus = kTRUE;
481  }
482 
483  fResults.reserve(fMethods.size());
484  for (UInt_t iMethod = 0; iMethod < fMethods.size(); iMethod++) {
486 
487  TString methodTypeName = fMethods[iMethod].GetValue<TString>("MethodName");
488  TString methodTitle = fMethods[iMethod].GetValue<TString>("MethodTitle");
489 
490  if (methodTypeName == "") {
491  Log() << kFATAL << "No method booked for cross-validation" << Endl;
492  }
493 
495  Log() << kINFO << "Evaluate method: " << methodTitle << Endl;
496 
497  // Process K folds
498  auto nWorkers = fNumWorkerProcs;
499  if (nWorkers == 1) {
500  // Fall back to global config
501  nWorkers = TMVA::gConfig().GetNumWorkers();
502  }
  // NOTE(review): an explicit NumWorkerProcs=1 cannot be told apart from the default
  // and is overridden by TMVA::gConfig().GetNumWorkers(), although the option text
  // says "1 means no parallelisation". NumWorkerProcs=0 is passed to
  // TProcessExecutor unchanged (presumably meaning "use all cores"; confirm).
503  if (nWorkers == 1) {
504  for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
505  auto fold_result = ProcessFold(iFold, iMethod);
506  result.Fill(fold_result);
507  }
508  } else {
509  ROOT::TProcessExecutor workers(nWorkers);
510  std::vector<CrossValidationFoldResult> result_vector;
511 
512  auto workItem = [this, iMethod](UInt_t iFold) {
513  return ProcessFold(iFold, iMethod);
514  };
515 
516  result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
517 
518  for (auto && fold_result : result_vector) {
519  result.Fill(fold_result);
520  }
521  }
522 
523  fResults.push_back(result);
524 
525  // Serialise the cross evaluated method
  // NOTE(review): fNumFolds is formatted with %i. If it is UInt_t, as
  // SetNumFolds(UInt_t) suggests, %u is the matching specifier.
526  TString options =
527  Form("SplitExpr=%s:NumFolds=%i"
528  ":EncapsulatedMethodName=%s"
529  ":EncapsulatedMethodTypeName=%s"
530  ":OutputEnsembling=%s",
531  fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());
532 
533  fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);
534 
535  // Feed EventToFold mapping used when random fold assignments are used
536  // (when splitExpr="").
537  IMethod *method_interface = fFactory->GetMethod(fDataLoader.get()->GetName(), methodTitle);
538  MethodCrossValidation *method = dynamic_cast<MethodCrossValidation *>(method_interface);
  // NOTE(review): the dynamic_cast result is dereferenced without a null check.
539 
540  method->fEventToFoldMapping = fSplit.get()->fEventToFoldMapping;
541  }
542 
543  // Recombination of data (making sure there is data in training and testing trees).
544  fDataLoader->RecombineKFoldDataSet(*fSplit.get());
545 
546  // "Eval" on training set
547  for (UInt_t iMethod = 0; iMethod < fMethods.size(); iMethod++) {
548  TString methodTypeName = fMethods[iMethod].GetValue<TString>("MethodName");
549  TString methodTitle = fMethods[iMethod].GetValue<TString>("MethodTitle");
550 
551  IMethod *method_interface = fFactory->GetMethod(fDataLoader.get()->GetName(), methodTitle);
552  MethodCrossValidation *method = dynamic_cast<MethodCrossValidation *>(method_interface);
  // NOTE(review): same unchecked dynamic_cast as above; methodTypeName is unused in
  // this loop.
553 
554  if (fOutputFile) {
555  fFactory->WriteDataInformation(method->fDataSetInfo);
556  }
557 
559  method->TrainMethod();
561  }
562 
563  // Eval on Testing set
564  fFactory->TestAllMethods();
565 
566  // Calc statistics
567  fFactory->EvaluateAllMethods();
568 
569  Log() << kINFO << "Evaluation done." << Endl;
570 }
571 
572 //_______________________________________________________________________
573 const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
574 {
575  if (fResults.size() == 0)
576  Log() << kFATAL << "No cross-validation results available" << Endl;
577  return fResults;
578 }
std::vector< Double_t > fSigs
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition: TAttLine.h:43
TFile * fOutputFile
How to combine output of individual folds.
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Singleton class for Global types used by TMVA.
Definition: Types.h:73
Bool_t fFoldStatus
If true: generate output file for each fold.
std::unique_ptr< CvSplitKFolds > fSplit
float Float_t
Definition: RtypesCore.h:53
void SetTitle(const char *title="")
Set canvas title.
Definition: TCanvas.cxx:1956
std::map< UInt_t, Float_t > fROCs
Config & gConfig()
MsgLogger & Log() const
Definition: Configurable.h:122
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:47
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
Class to save the results of cross validation; the metric for classification is the ROC integral, and you can ...
Basic string class.
Definition: TString.h:131
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1100
virtual const char * DirName(const char *pathname)
Return the directory name in pathname.
Definition: TSystem.cxx:1004
bool Bool_t
Definition: RtypesCore.h:59
virtual void SetTitle(const char *title="")
Set graph title.
Definition: TGraph.cxx:2216
STL namespace.
Types::EAnalysisType fAnalysisType
UInt_t GetNumWorkers() const
Definition: Config.h:78
std::vector< Double_t > fEff10s
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Definition: TMath.h:734
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3976
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:392
std::vector< CrossValidationResult > fResults
DataSet * Data() const
Definition: MethodBase.h:400
void SetNumFolds(UInt_t i)
Abstract base class for all high level ml algorithms, you can book ml methods like BDT...
Definition: Envelope.h:43
std::vector< Double_t > fTrainEff01s
UInt_t fNumWorkerProcs
Number of folds to prepare.
std::vector< Double_t > fTrainEff10s
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
void Fill(CrossValidationFoldResult const &fr)
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition: TAttLine.h:40
virtual void ParseOptions()
Method to parse the internal option string.
Definition: Envelope.cxx:187
This class provides a simple interface to execute the same task multiple times in parallel...
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
std::vector< Double_t > fEff01s
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
Float_t GetROCStandardDeviation() const
const std::vector< CrossValidationResult > & GetResults() const
This is the main MVA steering class.
Definition: Factory.h:81
std::vector< Double_t > fTrainEff30s
virtual Double_t GetSignificance() const
compute significance of mean difference
TGraphErrors * gr
Definition: legend1.C:25
void Evaluate()
Does training, test set evaluation and performance evaluation of using cross-evalution.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Definition: DataSet.cxx:343
const Bool_t kFALSE
Definition: RtypesCore.h:88
void ParseOptions()
Method to parse the internal option string.
std::unique_ptr< Factory > fFoldFactory
The Canvas class.
Definition: TCanvas.h:31
double Double_t
Definition: RtypesCore.h:55
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
Class to perform cross validation, splitting the dataloader into folds.
CrossValidationFoldResult ProcessFold(UInt_t iFold, UInt_t iMethod)
Evaluates each fold in turn.
std::unique_ptr< Factory > fFactory
std::shared_ptr< DataLoader > fDataLoader
Booked method information.
Definition: Envelope.h:47
void AddPreDefVal(const T &)
Definition: Configurable.h:168
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
void SetSplitExpr(TString splitExpr)
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition: TNamed.cxx:74
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
std::vector< Double_t > fEffAreas
virtual void Draw(Option_t *option="")
Draw a canvas.
Definition: TCanvas.cxx:826
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fSeps
Abstract ClassifierFactory template that handles arbitrary types.
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
Definition: TPad.cxx:494
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute func (with no arguments) nTimes in parallel.
void SetSilent(Bool_t s)
Definition: Config.h:69
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
TCanvas * Draw(const TString name="CrossValidation") const
std::vector< Double_t > fEff30s
virtual Double_t GetTrainingEfficiency(const TString &)
#define c(i)
Definition: RSha256.hxx:101
TString()
TString default ctor.
Definition: TString.cxx:87
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:428
Double_t Sqrt(Double_t x)
Definition: TMath.h:690
static void EnableOutput()
Definition: MsgLogger.cxx:75
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
const Bool_t kTRUE
Definition: RtypesCore.h:87
void CheckForUnusedOptions() const
checks for unused options in option string
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< OptionMap > fMethods
Definition: Envelope.h:46
char name[80]
Definition: TGX11.cxx:109
TString fJobName
If true: dataset is prepared.
TString fOutputFactoryOptions
Number of processes to use for fold evaluation.
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:917
const char * Data() const
Definition: TString.h:364