Logo ROOT  
Reference Guide
CrossValidation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou
3 // Modified: Kim Albertsson 2017
4 
5 /*************************************************************************
6  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
7  * All rights reserved. *
8  * *
9  * For the licensing terms see $ROOTSYS/LICENSE. *
10  * For the list of contributors see $ROOTSYS/README/CREDITS. *
11  *************************************************************************/
12 
13 #include "TMVA/CrossValidation.h"
14 
15 #include "TMVA/ClassifierFactory.h"
16 #include "TMVA/Config.h"
17 #include "TMVA/CvSplit.h"
18 #include "TMVA/DataSet.h"
19 #include "TMVA/Event.h"
20 #include "TMVA/MethodBase.h"
22 #include "TMVA/MsgLogger.h"
24 #include "TMVA/ResultsMulticlass.h"
25 #include "TMVA/ROCCurve.h"
26 #include "TMVA/Types.h"
27 
28 #include "TSystem.h"
29 #include "TAxis.h"
30 #include "TCanvas.h"
31 #include "TGraph.h"
32 #include "TLegend.h"
33 #include "TMath.h"
34 
35 #include "ROOT/RMakeUnique.hxx"
36 
37 #include <iostream>
38 #include <memory>
39 
40 //_______________________________________________________________________
42 :fROCCurves(new TMultiGraph())
43 {
44  fSigs.resize(numFolds);
45  fSeps.resize(numFolds);
46  fEff01s.resize(numFolds);
47  fEff10s.resize(numFolds);
48  fEff30s.resize(numFolds);
49  fEffAreas.resize(numFolds);
50  fTrainEff01s.resize(numFolds);
51  fTrainEff10s.resize(numFolds);
52  fTrainEff30s.resize(numFolds);
53 }
54 
55 //_______________________________________________________________________
57 {
58  fROCs=obj.fROCs;
59  fROCCurves = obj.fROCCurves;
60 
61  fSigs = obj.fSigs;
62  fSeps = obj.fSeps;
63  fEff01s = obj.fEff01s;
64  fEff10s = obj.fEff10s;
65  fEff30s = obj.fEff30s;
66  fEffAreas = obj.fEffAreas;
67  fTrainEff01s = obj.fTrainEff01s;
68  fTrainEff10s = obj.fTrainEff10s;
69  fTrainEff30s = obj.fTrainEff30s;
70 }
71 
72 //_______________________________________________________________________
74 {
75  UInt_t iFold = fr.fFold;
76 
77  fROCs[iFold] = fr.fROCIntegral;
78  fROCCurves->Add(dynamic_cast<TGraph *>(fr.fROC.Clone()));
79 
80  fSigs[iFold] = fr.fSig;
81  fSeps[iFold] = fr.fSep;
82  fEff01s[iFold] = fr.fEff01;
83  fEff10s[iFold] = fr.fEff10;
84  fEff30s[iFold] = fr.fEff30;
85  fEffAreas[iFold] = fr.fEffArea;
86  fTrainEff01s[iFold] = fr.fTrainEff01;
87  fTrainEff10s[iFold] = fr.fTrainEff10;
88  fTrainEff30s[iFold] = fr.fTrainEff30;
89 }
90 
91 //_______________________________________________________________________
93 {
94  return fROCCurves.get();
95 }
96 
97 ////////////////////////////////////////////////////////////////////////////////
98 /// \brief Generates a multigraph that contains an average ROC Curve.
99 ///
100 /// \note You own the returned pointer.
101 ///
102 /// \param[in] numSamples Number of samples used for generating the average ROC
103 /// Curve. Avg. curve will be evaluated only at these
104 /// points (using interpolation if necessary).
105 ///
106 
108 {
109  // `(numSamples - 1) * increment` should equal 1.0!
110  Double_t increment = 1.0 / (numSamples-1);
111  std::vector<Double_t> x(numSamples), y(numSamples);
112 
113  TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
114 
115  for(UInt_t iSample = 0; iSample < numSamples; iSample++) {
116  Double_t xPoint = iSample * increment;
117  Double_t rocSum = 0;
118 
119  for(Int_t iGraph = 0; iGraph < rocCurveList->GetSize(); iGraph++) {
120  TGraph *foldROC = static_cast<TGraph *>(rocCurveList->At(iGraph));
121  rocSum += foldROC->Eval(xPoint);
122  }
123 
124  x[iSample] = xPoint;
125  y[iSample] = rocSum/rocCurveList->GetSize();
126  }
127 
128  return new TGraph(numSamples, &x[0], &y[0]);
129 }
130 
131 //_______________________________________________________________________
133 {
134  Float_t avg=0;
135  for(auto &roc : fROCs) {
136  avg+=roc.second;
137  }
138  return avg/fROCs.size();
139 }
140 
141 //_______________________________________________________________________
143 {
144  // NOTE: We are using here the unbiased estimation of the standard deviation.
145  Float_t std=0;
146  Float_t avg=GetROCAverage();
147  for(auto &roc : fROCs) {
148  std+=TMath::Power(roc.second-avg, 2);
149  }
150  return TMath::Sqrt(std/float(fROCs.size()-1.0));
151 }
152 
153 //_______________________________________________________________________
155 {
158 
159  MsgLogger fLogger("CrossValidation");
160  fLogger << kHEADER << " ==== Results ====" << Endl;
161  for(auto &item:fROCs) {
162  fLogger << kINFO << Form("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
163  }
164 
165  fLogger << kINFO << "------------------------" << Endl;
166  fLogger << kINFO << Form("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
167  fLogger << kINFO << Form("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
168 
170 }
171 
172 //_______________________________________________________________________
174 {
175  auto *c = new TCanvas(name.Data());
176  fROCCurves->Draw("AL");
177  fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
178  fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
179  Float_t adjust=1+fROCs.size()*0.01;
180  c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
181  c->SetTitle("Cross Validation ROC Curves");
182  c->Draw();
183  return c;
184 }
185 
186 //
188 {
189  // note this function will create memory leak for the TMultiGraph
190  // but it needs to be kept alive in order to display the canvas
191  TMultiGraph *rocs = new TMultiGraph();
192 
193  // Potentially add the folds
194  if (drawFolds) {
195  for (auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
196  TGraph * foldRocGraph = dynamic_cast<TGraph *>(foldRocObj->Clone());
197  foldRocGraph->SetLineColor(1);
198  foldRocGraph->SetLineWidth(1);
199  rocs->Add(foldRocGraph);
200  }
201  }
202 
203  // Add the average roc curve
204  TGraph *avgRocGraph = GetAvgROCCurve(100);
205  avgRocGraph->SetTitle("Avg ROC Curve");
206  avgRocGraph->SetLineColor(2);
207  avgRocGraph->SetLineWidth(3);
208  rocs->Add(avgRocGraph);
209 
210  // Draw
211  TCanvas *c = new TCanvas();
212 
213  if (title != "") {
214  title = "Cross Validation Average ROC Curve";
215  }
216 
217  rocs->SetName("cv_rocs");
218  rocs->SetTitle(title);
219  rocs->GetXaxis()->SetTitle("Signal Efficiency");
220  rocs->GetYaxis()->SetTitle("Background Rejection");
221  rocs->DrawClone("AL");
222 
223  // Build legend
224  TLegend *leg = new TLegend();
225  TList *ROCCurveList = rocs->GetListOfGraphs();
226 
227  if (drawFolds) {
228  Int_t nCurves = ROCCurveList->GetSize();
229  leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(nCurves-1)),
230  "Avg ROC Curve", "l");
231  leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(0)),
232  "Fold ROC Curves", "l");
233  leg->Draw();
234  } else {
235  c->BuildLegend();
236  }
237 
238  // Draw Canvas
239  c->SetTitle("Cross Validation Average ROC Curve");
240  c->Draw();
241  return c;
242 }
243 
244 /**
245 * \class TMVA::CrossValidation
246 * \ingroup TMVA
247 * \brief
248 
249 Use html for explicit line breaking<br>
250 Markdown links? [class reference](#reference)?
251 
252 
253 ~~~{.cpp}
254 ce->BookMethod(dataloader, options);
255 ce->Evaluate();
256 ~~~
257 
258 Cross-evaluation will generate a new training and a test set dynamically from
259 `K` folds. These `K` folds are generated by splitting the input training
260 set. The input test set is currently ignored.
261 
262 This means that when you specify your DataSet you should include all events
263 in your training set. One way of doing this would be the following:
264 
265 ~~~{.cpp}
266 dataloader->AddTree( signalTree, "cls1" );
267 dataloader->AddTree( background, "cls2" );
268 dataloader->PrepareTrainingAndTestTree( "", "", "nTest_cls1=1:nTest_cls2=1" );
269 ~~~
270 
271 ## Split Expression
272 See CVSplit documentation?
273 
274 */
275 
276 ////////////////////////////////////////////////////////////////////////////////
277 ///
278 
280  TString options)
281  : TMVA::Envelope(jobName, dataloader, nullptr, options),
282  fAnalysisType(Types::kMaxAnalysisType),
283  fAnalysisTypeStr("Auto"),
284  fSplitTypeStr("Random"),
285  fCorrelations(kFALSE),
286  fCvFactoryOptions(""),
287  fDrawProgressBar(kFALSE),
288  fFoldFileOutput(kFALSE),
289  fFoldStatus(kFALSE),
290  fJobName(jobName),
291  fNumFolds(2),
292  fNumWorkerProcs(1),
293  fOutputFactoryOptions(""),
294  fOutputFile(outputFile),
295  fSilent(kFALSE),
296  fSplitExprString(""),
297  fROC(kTRUE),
298  fTransformations(""),
299  fVerbose(kFALSE),
300  fVerboseLevel(kINFO)
301 {
302  InitOptions();
305 }
306 
307 ////////////////////////////////////////////////////////////////////////////////
308 ///
309 
311  : CrossValidation(jobName, dataloader, nullptr, options)
312 {
313 }
314 
315 ////////////////////////////////////////////////////////////////////////////////
316 ///
317 
319 
320 ////////////////////////////////////////////////////////////////////////////////
321 ///
322 
324 {
325  // Forwarding of Factory options
326  DeclareOptionRef(fSilent, "Silent",
327  "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
328  "class object (default: False)");
329  DeclareOptionRef(fVerbose, "V", "Verbose flag");
330  DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
331  AddPreDefVal(TString("Debug"));
332  AddPreDefVal(TString("Verbose"));
333  AddPreDefVal(TString("Info"));
334 
335  DeclareOptionRef(fTransformations, "Transformations",
336  "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
337  "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
338  "transformations");
339 
340  DeclareOptionRef(fDrawProgressBar, "DrawProgressBar", "Boolean to show draw progress bar");
341  DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation in output");
342  DeclareOptionRef(fROC, "ROC", "Boolean to show ROC in output");
343 
344  TString analysisType("Auto");
345  DeclareOptionRef(fAnalysisTypeStr, "AnalysisType",
346  "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
347  AddPreDefVal(TString("Classification"));
348  AddPreDefVal(TString("Regression"));
349  AddPreDefVal(TString("Multiclass"));
350  AddPreDefVal(TString("Auto"));
351 
352  // Options specific to CE
353  DeclareOptionRef(fSplitTypeStr, "SplitType",
354  "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
355  AddPreDefVal(TString("Deterministic"));
356  AddPreDefVal(TString("Random"));
357  AddPreDefVal(TString("RandomStratified"));
358 
359  DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
360  DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
361  DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
362  "Determines how many processes to use for evaluation. 1 means no"
363  " parallelisation. 2 means use 2 processes. 0 means figure out the"
364  " number automatically based on the number of cpus available. Default"
365  " 1.");
366 
367  DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
368  "If given a TMVA output file will be generated for each fold. Filename will be the same as "
369  "specifed for the combined output with a _foldX suffix. (default: false)");
370 
371  DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
372  "Combines output from contained methods. If None, no combination is performed. (default None)");
373  AddPreDefVal(TString("None"));
374  AddPreDefVal(TString("Avg"));
375 }
376 
377 ////////////////////////////////////////////////////////////////////////////////
378 ///
379 
381 {
382  this->Envelope::ParseOptions();
383 
384  if (fSplitTypeStr != "Deterministic" && fSplitExprString != "") {
385  Log() << kFATAL << "SplitExpr can only be used with Deterministic Splitting" << Endl;
386  }
387 
388  // Factory options
389  fAnalysisTypeStr.ToLower();
390  if (fAnalysisTypeStr == "classification") {
391  fAnalysisType = Types::kClassification;
392  } else if (fAnalysisTypeStr == "regression") {
393  fAnalysisType = Types::kRegression;
394  } else if (fAnalysisTypeStr == "multiclass") {
395  fAnalysisType = Types::kMulticlass;
396  } else if (fAnalysisTypeStr == "auto") {
397  fAnalysisType = Types::kNoAnalysisType;
398  }
399 
400  if (fVerbose) {
401  fCvFactoryOptions += "V:";
402  fOutputFactoryOptions += "V:";
403  } else {
404  fCvFactoryOptions += "!V:";
405  fOutputFactoryOptions += "!V:";
406  }
407 
408  fCvFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
409  fOutputFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
410 
411  fCvFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
412  fOutputFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
413 
414  if (!fDrawProgressBar) {
415  fCvFactoryOptions += "!DrawProgressBar:";
416  fOutputFactoryOptions += "!DrawProgressBar:";
417  }
418 
419  if (fTransformations != "") {
420  fCvFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
421  fOutputFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
422  }
423 
424  if (fCorrelations) {
425  fCvFactoryOptions += "Correlations:";
426  fOutputFactoryOptions += "Correlations:";
427  } else {
428  fCvFactoryOptions += "!Correlations:";
429  fOutputFactoryOptions += "!Correlations:";
430  }
431 
432  if (fROC) {
433  fCvFactoryOptions += "ROC:";
434  fOutputFactoryOptions += "ROC:";
435  } else {
436  fCvFactoryOptions += "!ROC:";
437  fOutputFactoryOptions += "!ROC:";
438  }
439 
440  if (fSilent) {
441  fCvFactoryOptions += Form("Silent:");
442  fOutputFactoryOptions += Form("Silent:");
443  }
444 
445  // CE specific options
446  if (fFoldFileOutput && fOutputFile == nullptr) {
447  Log() << kFATAL << "No output file given, cannot generate per fold output." << Endl;
448  }
449 
450  // Initialisations
451 
452  fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
453 
454  // The fOutputFactory should always have !ModelPersistence set since we use a custom code path for this.
455  // In this case we create a special method (MethodCrossValidation) that can only be used by
456  // CrossValidation and the Reader.
457  if (fOutputFile == nullptr) {
458  fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
459  } else {
460  fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
461  }
462 
463  if(fSplitTypeStr == "Random"){
464  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kFALSE));
465  } else if(fSplitTypeStr == "RandomStratified"){
466  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kTRUE));
467  } else {
468  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
469  }
470 
471 }
472 
473 //_______________________________________________________________________
475 {
476  if (i != fNumFolds) {
477  fNumFolds = i;
478  fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
479  fDataLoader->MakeKFoldDataSet(*fSplit);
480  fFoldStatus = kTRUE;
481  }
482 }
483 
484 ////////////////////////////////////////////////////////////////////////////////
485 ///
486 
488 {
489  if (splitExpr != fSplitExprString) {
490  fSplitExprString = splitExpr;
491  fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
492  fDataLoader->MakeKFoldDataSet(*fSplit);
493  fFoldStatus = kTRUE;
494  }
495 }
496 
497 ////////////////////////////////////////////////////////////////////////////////
498 /// Evaluates each fold in turn.
499 /// - Prepares train and test data sets
500 /// - Trains method
501 /// - Evaluates on test set
502 /// - Stores the evaluation internally
503 ///
504 /// @param iFold fold to evaluate
505 ///
506 
508 {
509  TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
510  TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
511  TString methodOptions = methodInfo.GetValue<TString>("MethodOptions");
512  TString foldTitle = methodTitle + TString("_fold") + TString::Format("%i", iFold + 1);
513 
514  Log() << kDEBUG << "Processing " << methodTitle << " fold " << iFold << Endl;
515 
516  // Only used if fFoldFileOutput == true
517  TFile *foldOutputFile = nullptr;
518 
519  if (fFoldFileOutput && fOutputFile != nullptr) {
520  TString path = gSystem->GetDirName(fOutputFile->GetName()) + "/" + foldTitle + ".root";
521  foldOutputFile = TFile::Open(path, "RECREATE");
522  Log() << kINFO << "Creating fold output at:" << path << Endl;
523  fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
524  }
525 
526  fDataLoader->PrepareFoldDataSet(*fSplit, iFold, TMVA::Types::kTraining);
527  MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
528 
529  // Train method (train method and eval train set)
531  smethod->TrainMethod();
533 
534  fFoldFactory->TestAllMethods();
535  fFoldFactory->EvaluateAllMethods();
536 
537  TMVA::CrossValidationFoldResult result(iFold);
538 
539  // Results for aggregation (ROC integral, efficiencies etc.)
540  if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass) {
541  result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
542 
543  TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
544  gr->SetLineColor(iFold + 1);
545  gr->SetLineWidth(2);
546  gr->SetTitle(foldTitle.Data());
547  result.fROC = *gr;
548 
549  result.fSig = smethod->GetSignificance();
550  result.fSep = smethod->GetSeparation();
551 
552  if (fAnalysisType == Types::kClassification) {
553  Double_t err;
554  result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
555  result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
556  result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
557  result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
558  result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
559  result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
560  result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
561  } else if (fAnalysisType == Types::kMulticlass) {
562  // Nothing here for now
563  }
564  }
565 
566  // Per-fold file output
567  if (fFoldFileOutput && foldOutputFile != nullptr) {
568  foldOutputFile->Close();
569  }
570 
571  // Clean-up for this fold
572  {
573  smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
574  smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
575  }
576 
577  fFoldFactory->DeleteAllMethods();
578  fFoldFactory->fMethodsMap.clear();
579 
580  return result;
581 }
582 
583 ////////////////////////////////////////////////////////////////////////////////
584 /// Does training, test set evaluation and performance evaluation using
585 /// cross-evaluation.
586 ///
587 
589 {
590  // Generate K folds on given dataset
591  if (!fFoldStatus) {
592  fDataLoader->MakeKFoldDataSet(*fSplit);
593  fFoldStatus = kTRUE;
594  }
595 
596  fResults.reserve(fMethods.size());
597  for (auto & methodInfo : fMethods) {
598  CrossValidationResult result{fNumFolds};
599 
600  TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
601  TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
602 
603  if (methodTypeName == "") {
604  Log() << kFATAL << "No method booked for cross-validation" << Endl;
605  }
606 
608  Log() << kINFO << Endl;
609  Log() << kINFO << Endl;
610  Log() << kINFO << "========================================" << Endl;
611  Log() << kINFO << "Processing folds for method " << methodTitle << Endl;
612  Log() << kINFO << "========================================" << Endl;
613  Log() << kINFO << Endl;
614 
615  // Process K folds
616  auto nWorkers = fNumWorkerProcs;
617  if (nWorkers == 1) {
618  // Fall back to global config
619  nWorkers = TMVA::gConfig().GetNumWorkers();
620  }
621  if (nWorkers == 1) {
622  for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
623  auto fold_result = ProcessFold(iFold, methodInfo);
624  result.Fill(fold_result);
625  }
626  } else {
627 #ifndef _MSC_VER
628  ROOT::TProcessExecutor workers(nWorkers);
629  std::vector<CrossValidationFoldResult> result_vector;
630 
631  auto workItem = [this, methodInfo](UInt_t iFold) {
632  return ProcessFold(iFold, methodInfo);
633  };
634 
635  result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
636 
637  for (auto && fold_result : result_vector) {
638  result.Fill(fold_result);
639  }
640 #endif
641  }
642 
643  fResults.push_back(result);
644 
645  // Serialise the cross evaluated method
646  TString options =
647  Form("SplitExpr=%s:NumFolds=%i"
648  ":EncapsulatedMethodName=%s"
649  ":EncapsulatedMethodTypeName=%s"
650  ":OutputEnsembling=%s",
651  fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());
652 
653  fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);
654 
655  // Feed EventToFold mapping used when random fold assignments are used
656  // (when splitExpr="").
657  IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
658  auto *method = dynamic_cast<MethodCrossValidation *>(method_interface);
659 
660  method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
661  }
662 
663  Log() << kINFO << Endl;
664  Log() << kINFO << Endl;
665  Log() << kINFO << "========================================" << Endl;
666  Log() << kINFO << "Folds processed for all methods, evaluating." << Endl;
667  Log() << kINFO << "========================================" << Endl;
668  Log() << kINFO << Endl;
669 
670  // Recombination of data (making sure there is data in training and testing trees).
671  fDataLoader->RecombineKFoldDataSet(*fSplit);
672 
673  // "Eval" on training set
674  for (auto & methodInfo : fMethods) {
675  TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
676  TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
677 
678  IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
679  auto method = dynamic_cast<MethodCrossValidation *>(method_interface);
680 
681  if (fOutputFile != nullptr) {
682  fFactory->WriteDataInformation(method->fDataSetInfo);
683  }
684 
686  method->TrainMethod();
688  }
689 
690  // Eval on Testing set
691  fFactory->TestAllMethods();
692 
693  // Calc statistics
694  fFactory->EvaluateAllMethods();
695 
696  Log() << kINFO << "Evaluation done." << Endl;
697 }
698 
699 //_______________________________________________________________________
700 const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
701 {
702  if (fResults.empty()) {
703  Log() << kFATAL << "No cross-validation results available" << Endl;
704  }
705  return fResults;
706 }
c
#define c(i)
Definition: RSha256.hxx:101
TMVA::OptionMap
Class to store options for the different methods
Definition: OptionMap.h:34
TMVA::CrossValidationFoldResult::fTrainEff01
Double_t fTrainEff01
Definition: CrossValidation.h:71
TMVA::OptionMap::GetValue
T GetValue(const TString &key)
Definition: OptionMap.h:133
TMVA::CrossValidationResult::fROCCurves
std::shared_ptr< TMultiGraph > fROCCurves
Definition: CrossValidation.h:83
CrossValidation.h
TMVA::CrossValidation::CrossValidation
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
Definition: CrossValidation.cxx:310
ResultsClassification.h
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:91
TMVA::Types::kMulticlass
@ kMulticlass
Definition: Types.h:131
TMVA::Types::kCrossValidation
@ kCrossValidation
Definition: Types.h:111
TMVA::MethodBase::GetSeparation
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
Definition: MethodBase.cxx:2778
TMVA::CrossValidationResult::fROCs
std::map< UInt_t, Float_t > fROCs
Definition: CrossValidation.h:82
TMVA::Envelope
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition: Envelope.h:44
TGraph::SetTitle
virtual void SetTitle(const char *title="")
Change (i.e.
Definition: TGraph.cxx:2329
TMVA::MethodBase::Data
DataSet * Data() const
Definition: MethodBase.h:408
TMVA::CrossValidationFoldResult::fEff01
Double_t fEff01
Definition: CrossValidation.h:67
TMVA::MethodBase::GetAnalysisType
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:437
TNamed::SetName
virtual void SetName(const char *name)
Set the name of the TNamed.
Definition: TNamed.cxx:140
TMVA::Types::kRegression
@ kRegression
Definition: Types.h:130
TMVA::CrossValidationResult::GetROCAverage
Float_t GetROCAverage() const
Definition: CrossValidation.cxx:132
TMVA::CrossValidationResult::fEff10s
std::vector< Double_t > fEff10s
Definition: CrossValidation.h:88
TString::Data
const char * Data() const
Definition: TString.h:369
Form
char * Form(const char *fmt,...)
TMVA::CrossValidationFoldResult::fEff30
Double_t fEff30
Definition: CrossValidation.h:69
TGraph.h
TMVA::CrossValidation::~CrossValidation
~CrossValidation()
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TMVA::Types::kTesting
@ kTesting
Definition: Types.h:146
TLegend.h
TMath::Sqrt
Double_t Sqrt(Double_t x)
Definition: TMath.h:691
TMVA::CrossValidationFoldResult::fEffArea
Double_t fEffArea
Definition: CrossValidation.h:70
TMVA::MethodCrossValidation
Definition: MethodCrossValidation.h:38
TMVA::MethodBase::TrainMethod
void TrainMethod()
Definition: MethodBase.cxx:650
Float_t
float Float_t
Definition: RtypesCore.h:57
TObject::DrawClone
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad for instance with: gROOT->SetSelectedPad(gPad...
Definition: TObject.cxx:221
TFile::Open
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3995
TMVA::CrossValidationFoldResult
Definition: CrossValidation.h:53
x
Double_t x[n]
Definition: legend1.C:17
MethodCrossValidation.h
TMVA::Config::SetSilent
void SetSilent(Bool_t s)
Definition: Config.h:65
MethodBase.h
TMVA::Event::SetIsTraining
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:391
TMVA::CrossValidation::SetNumFolds
void SetNumFolds(UInt_t i)
Definition: CrossValidation.cxx:474
TAttLine::SetLineColor
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition: TAttLine.h:40
TString::Format
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2311
TCanvas.h
CvSplit.h
TString
Basic string class.
Definition: TString.h:136
TMVA::CrossValidationFoldResult::fSep
Double_t fSep
Definition: CrossValidation.h:66
TMVA::CrossValidationFoldResult::fEff10
Double_t fEff10
Definition: CrossValidation.h:68
TSystem::GetDirName
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
Definition: TSystem.cxx:1030
bool
TMVA::CrossValidationFoldResult::fSig
Double_t fSig
Definition: CrossValidation.h:65
TMVA::CrossValidationResult::fTrainEff10s
std::vector< Double_t > fTrainEff10s
Definition: CrossValidation.h:92
TMVA::CrossValidationFoldResult::fROC
TGraph fROC
Definition: CrossValidation.h:63
TMVA::CvSplitKFolds
Definition: CvSplit.h:92
TMVA::CrossValidationResult::DrawAvgROCCurve
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
Definition: CrossValidation.cxx:187
TList::At
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:357
TMultiGraph::GetXaxis
TAxis * GetXaxis()
Get x axis of the graph.
Definition: TMultiGraph.cxx:1127
TMVA::CrossValidationResult::fEff01s
std::vector< Double_t > fEff01s
Definition: CrossValidation.h:87
MsgLogger.h
TMultiGraph::GetYaxis
TAxis * GetYaxis()
Get y axis of the graph.
Definition: TMultiGraph.cxx:1139
TMVA::CrossValidation::GetResults
const std::vector< CrossValidationResult > & GetResults() const
Definition: CrossValidation.cxx:700
TSystem.h
TMVA::Configurable::CheckForUnusedOptions
void CheckForUnusedOptions() const
checks for unused options in option string
Definition: Configurable.cxx:270
TMVA::DataSet::DeleteAllResults
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Definition: DataSet.cxx:343
TMVA::CrossValidationResult::GetROCStandardDeviation
Float_t GetROCStandardDeviation() const
Definition: CrossValidation.cxx:142
gr
TGraphErrors * gr
Definition: legend1.C:25
TMVA::CrossValidationResult::fTrainEff30s
std::vector< Double_t > fTrainEff30s
Definition: CrossValidation.h:93
TMVA::CrossValidationResult::Fill
void Fill(CrossValidationFoldResult const &fr)
Definition: CrossValidation.cxx:73
TMVA::gConfig
Config & gConfig()
TMVA::CrossValidation::Evaluate
void Evaluate()
Does training, test set evaluation and performance evaluation using cross-evaluation.
Definition: CrossValidation.cxx:588
kFALSE
const Bool_t kFALSE
Definition: RtypesCore.h:92
TGraph::Eval
virtual Double_t Eval(Double_t x, TSpline *spline=0, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
Definition: TGraph.cxx:878
TString::TString
TString()
TString default ctor.
Definition: TString.cxx:87
TNamed::Clone
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition: TNamed.cxx:74
TMVA::CrossValidationFoldResult::fTrainEff30
Double_t fTrainEff30
Definition: CrossValidation.h:73
TMath::Power
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Definition: TMath.h:735
TMVA::CrossValidationResult::fTrainEff01s
std::vector< Double_t > fTrainEff01s
Definition: CrossValidation.h:91
TMVA::Types::kClassification
@ kClassification
Definition: Types.h:129
Event.h
TMVA::CrossValidation::ProcessFold
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
Definition: CrossValidation.cxx:507
TMVA::CrossValidationFoldResult::fFold
UInt_t fFold
Definition: CrossValidation.h:60
y
Double_t y[n]
Definition: legend1.C:17
TMVA::CrossValidationResult
Class to save the results of cross validation, the metric for the classification ins ROC and you can ...
Definition: CrossValidation.h:78
ROOT::TProcessExecutor::Map
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute func (with no arguments) nTimes in parallel.
Definition: TProcessExecutor.hxx:98
TMVA::MethodBase
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
ROCCurve.h
TMVA::Types
Singleton class for Global types used by TMVA.
Definition: Types.h:73
Types.h
TNamed::SetTitle
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition: TNamed.cxx:164
TMVA::CrossValidation::SetSplitExpr
void SetSplitExpr(TString splitExpr)
Definition: CrossValidation.cxx:487
TFile
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:54
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Config.h
unsigned int
TMVA::Envelope::ParseOptions
virtual void ParseOptions()
Method to parse the internal option string.
Definition: Envelope.cxx:182
TMVA::IMethod
Interface for all concrete MVA method implementations.
Definition: IMethod.h:53
TMVA::CrossValidationResult::CrossValidationResult
CrossValidationResult(UInt_t numFolds)
Definition: CrossValidation.cxx:41
TMVA::Types::kTraining
@ kTraining
Definition: Types.h:145
gSystem
R__EXTERN TSystem * gSystem
Definition: TSystem.h:559
TMultiGraph
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:36
TAttLine::SetLineWidth
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition: TAttLine.h:43
TMultiGraph::GetListOfGraphs
TList * GetListOfGraphs() const
Definition: TMultiGraph.h:70
TMVA::MsgLogger::EnableOutput
static void EnableOutput()
Definition: MsgLogger.cxx:74
TMVA::MethodBase::GetSignificance
virtual Double_t GetSignificance() const
compute significance of mean difference
Definition: MethodBase.cxx:2765
TMVA::Config::GetNumWorkers
UInt_t GetNumWorkers() const
Definition: Config.h:74
Double_t
double Double_t
Definition: RtypesCore.h:59
TGraph
A TGraph is an object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
TMVA::MsgLogger
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
ROOT::TProcessExecutor
This class provides a simple interface to execute the same task multiple times in parallel,...
Definition: TProcessExecutor.hxx:35
TCanvas
The Canvas class.
Definition: TCanvas.h:23
TCollection::GetSize
virtual Int_t GetSize() const
Return the capacity of the collection.
Definition: TCollection.h:182
TMVA::MethodCrossValidation::fEventToFoldMapping
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
Definition: MethodCrossValidation.h:117
TAxis.h
TMVA::CrossValidation::ParseOptions
void ParseOptions()
Method to parse the internal option string.
Definition: CrossValidation.cxx:380
TFile::Close
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:876
TMultiGraph::Add
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
Definition: TMultiGraph.cxx:451
TMVA::CrossValidationResult::Print
void Print() const
Definition: CrossValidation.cxx:154
leg
leg
Definition: legend1.C:34
TMVA::MethodBase::GetEfficiency
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
Fill background efficiency (resp. rejection) versus signal efficiency plots.
Definition: MethodBase.cxx:2291
name
char name[80]
Definition: TGX11.cxx:110
TMVA::CrossValidation::InitOptions
void InitOptions()
Definition: CrossValidation.cxx:323
TMVA::CrossValidationFoldResult::fTrainEff10
Double_t fTrainEff10
Definition: CrossValidation.h:72
TMVA::MethodBase::GetTrainingEfficiency
virtual Double_t GetTrainingEfficiency(const TString &)
Definition: MethodBase.cxx:2517
RMakeUnique.hxx
ResultsMulticlass.h
TLegend
This class displays a legend box (TPaveText) containing several legend entries.
Definition: TLegend.h:23
ROOT::TSeq
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
TMVA::CrossValidationResult::GetROCCurves
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
Definition: CrossValidation.cxx:92
ClassifierFactory.h
TMVA::CrossValidationResult::fEffAreas
std::vector< Double_t > fEffAreas
Definition: CrossValidation.h:90
TMVA::CrossValidationResult::fEff30s
std::vector< Double_t > fEff30s
Definition: CrossValidation.h:89
TMVA::CrossValidationResult::GetAvgROCCurve
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a graph that contains an average ROC curve.
Definition: CrossValidation.cxx:107
TMVA::CrossValidationResult::fSeps
std::vector< Double_t > fSeps
Definition: CrossValidation.h:86
TMVA::CrossValidationFoldResult::fROCIntegral
Float_t fROCIntegral
Definition: CrossValidation.h:62
DataSet.h
TMVA::CrossValidationResult::fSigs
std::vector< Double_t > fSigs
Definition: CrossValidation.h:85
TMVA::CrossValidationResult::Draw
TCanvas * Draw(const TString name="CrossValidation") const
Definition: CrossValidation.cxx:173
TMVA::Types::kNoAnalysisType
@ kNoAnalysisType
Definition: Types.h:132
TList
A doubly linked list.
Definition: TList.h:44
TMVA::CrossValidation
Class to perform cross validation, splitting the dataloader into folds.
Definition: CrossValidation.h:124
TMath.h
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
int
TMVA::DataLoader
Definition: DataLoader.h:50