Logo ROOT   6.16/01
Reference Guide
CrossValidation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou
3// Modified: Kim Albertsson 2017
4
5/*************************************************************************
6 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
7 * All rights reserved. *
8 * *
9 * For the licensing terms see $ROOTSYS/LICENSE. *
10 * For the list of contributors see $ROOTSYS/README/CREDITS. *
11 *************************************************************************/
12
14
16#include "TMVA/Config.h"
17#include "TMVA/CvSplit.h"
18#include "TMVA/DataSet.h"
19#include "TMVA/Event.h"
20#include "TMVA/MethodBase.h"
22#include "TMVA/MsgLogger.h"
25#include "TMVA/ROCCurve.h"
26#include "TMVA/tmvaglob.h"
27#include "TMVA/Types.h"
28
29#include "TSystem.h"
30#include "TAxis.h"
31#include "TCanvas.h"
32#include "TGraph.h"
33#include "TMath.h"
34
35#include "ROOT/RMakeUnique.hxx"
36
37#include <iostream>
38#include <memory>
39
40//_______________________________________________________________________
42:fROCCurves(new TMultiGraph())
43{
44 fSigs.resize(numFolds);
45 fSeps.resize(numFolds);
46 fEff01s.resize(numFolds);
47 fEff10s.resize(numFolds);
48 fEff30s.resize(numFolds);
49 fEffAreas.resize(numFolds);
50 fTrainEff01s.resize(numFolds);
51 fTrainEff10s.resize(numFolds);
52 fTrainEff30s.resize(numFolds);
53}
54
55//_______________________________________________________________________
/// Copies every stored metric (ROC integrals and all per-fold vectors) from `obj`.
///
/// NOTE: fROCCurves is a std::shared_ptr, so the copy shares the same
/// TMultiGraph as `obj` instead of cloning it; curves added later via Fill()
/// on either object are visible through both.
57{
58 fROCs=obj.fROCs;
59 fROCCurves = obj.fROCCurves;
60
61 fSigs = obj.fSigs;
62 fSeps = obj.fSeps;
63 fEff01s = obj.fEff01s;
64 fEff10s = obj.fEff10s;
65 fEff30s = obj.fEff30s;
66 fEffAreas = obj.fEffAreas;
67 fTrainEff01s = obj.fTrainEff01s;
68 fTrainEff10s = obj.fTrainEff10s;
69 fTrainEff30s = obj.fTrainEff30s;
70}
71
72//_______________________________________________________________________
74{
75 UInt_t iFold = fr.fFold;
76
77 fROCs[iFold] = fr.fROCIntegral;
78 fROCCurves->Add(dynamic_cast<TGraph *>(fr.fROC.Clone()));
79
80 fSigs[iFold] = fr.fSig;
81 fSeps[iFold] = fr.fSep;
82 fEff01s[iFold] = fr.fEff01;
83 fEff10s[iFold] = fr.fEff10;
84 fEff30s[iFold] = fr.fEff30;
85 fEffAreas[iFold] = fr.fEffArea;
86 fTrainEff01s[iFold] = fr.fTrainEff01;
87 fTrainEff10s[iFold] = fr.fTrainEff10;
88 fTrainEff30s[iFold] = fr.fTrainEff30;
89}
90
91//_______________________________________________________________________
93{
94 return fROCCurves.get();
95}
96
97////////////////////////////////////////////////////////////////////////////////
98/// \brief Generates a multigraph that contains an average ROC Curve.
99///
100/// \note You own the returned pointer.
101///
102/// \param numSamples[in] Number of samples used for generating the average ROC
103/// Curve. Avg. curve will be evaluated only at these
104/// points (using interpolation if necessary).
105///
106
108{
109 // `numSamples * increment` should equal 1.0!
110 Double_t increment = 1.0 / (numSamples-1);
111 Double_t x[numSamples];
112 Double_t y[numSamples];
113
114 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
115
116 for(UInt_t iSample = 0; iSample < numSamples; iSample++) {
117 Double_t xPoint = iSample * increment;
118 Double_t rocSum = 0;
119
120 for(Int_t iGraph = 0; iGraph < rocCurveList->GetSize(); iGraph++) {
121 TGraph *foldROC = static_cast<TGraph *>(rocCurveList->At(iGraph));
122 rocSum += foldROC->Eval(xPoint);
123 }
124
125 x[iSample] = xPoint;
126 y[iSample] = rocSum/rocCurveList->GetSize();
127 }
128
129 return new TGraph(numSamples, x, y);
130}
131
132//_______________________________________________________________________
/// Arithmetic mean of the per-fold ROC integrals stored in fROCs.
/// NOTE(review): with no stored folds this evaluates 0/0 and returns NaN.
134{
135 Float_t avg=0;
136 for(auto &roc : fROCs) {
137 avg+=roc.second;
138 }
139 return avg/fROCs.size();
140}
141
142//_______________________________________________________________________
144{
145 // NOTE: We are using here the unbiased estimation of the standard deviation.
146 Float_t std=0;
147 Float_t avg=GetROCAverage();
148 for(auto &roc : fROCs) {
149 std+=TMath::Power(roc.second-avg, 2);
150 }
151 return TMath::Sqrt(std/float(fROCs.size()-1.0));
152}
153
154//_______________________________________________________________________
/// Prints the ROC integral of every fold, followed by their average and
/// (unbiased) standard deviation, through a MsgLogger named "CrossValidation".
156{
159
160 MsgLogger fLogger("CrossValidation");
161 fLogger << kHEADER << " ==== Results ====" << Endl;
162 for(auto &item:fROCs) {
// NOTE(review): this line is terminated with std::endl while every other
// message uses TMVA's Endl; std::endl only writes a newline into the
// underlying ostringstream, so the fold lines are presumably not emitted as
// separate MsgLogger messages. Consider `<< Endl` here.
163 fLogger << kINFO << Form("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
164 }
165
166 fLogger << kINFO << "------------------------" << Endl;
167 fLogger << kINFO << Form("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
168 fLogger << kINFO << Form("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
169
171}
172
173//_______________________________________________________________________
175{
176 auto *c = new TCanvas(name.Data());
177 fROCCurves->Draw("AL");
178 fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
179 fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
180 Float_t adjust=1+fROCs.size()*0.01;
181 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
182 c->SetTitle("Cross Validation ROC Curves");
183 c->Draw();
184 return c;
185}
186
187//
189{
190 TMultiGraph rocs{};
191
192 // Potentially add the folds
193 if (drawFolds) {
194 for (auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
195 TGraph * foldRocGraph = dynamic_cast<TGraph *>(foldRocObj->Clone());
196 foldRocGraph->SetLineColor(1);
197 foldRocGraph->SetLineWidth(1);
198 rocs.Add(foldRocGraph);
199 }
200 }
201
202 // Add the average roc curve
203 TGraph *avgRocGraph = GetAvgROCCurve(100);
204 avgRocGraph->SetTitle("Avg ROC Curve");
205 avgRocGraph->SetLineColor(2);
206 avgRocGraph->SetLineWidth(3);
207 rocs.Add(avgRocGraph);
208
209 // Draw
210 TCanvas *c = new TCanvas();
211
212 if (title != "") {
213 title = "Cross Validation Average ROC Curve";
214 }
215
216 rocs.SetTitle(title);
217 rocs.GetXaxis()->SetTitle("Signal Efficiency");
218 rocs.GetYaxis()->SetTitle("Background Rejection");
219 rocs.DrawClone("AL");
220
221 // Build legend
222 TLegend *leg = new TLegend();
223 TList *ROCCurveList = rocs.GetListOfGraphs();
224
225 if (drawFolds) {
226 Int_t nCurves = ROCCurveList->GetSize();
227 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(nCurves-1)),
228 "Avg ROC Curve", "l");
229 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(0)),
230 "Fold ROC Curves", "l");
231 leg->Draw();
232 } else {
233 c->BuildLegend();
234 }
235
236 // Draw Canvas
237 c->SetTitle("Cross Validation Average ROC Curve");
238 c->Draw();
239 return c;
240}
241
242/**
243* \class TMVA::CrossValidation
244* \ingroup TMVA
245* \brief Class to perform cross validation, splitting the dataloader into folds.
246
247Typical usage, with `ce` a pointer to a CrossValidation object:
249
250
251~~~{.cpp}
252ce->BookMethod(dataloader, options);
253ce->Evaluate();
254~~~
255
256Cross-validation generates a new training and test set dynamically from
257`K` folds. These `K` folds are generated by splitting the input training
258set. The input test set is currently ignored.
259
260This means that when you specify your DataSet you should include all events
261in your training set. One way of doing this would be the following:
262
263~~~{.cpp}
264dataloader->AddTree( signalTree, "cls1" );
265dataloader->AddTree( background, "cls2" );
266dataloader->PrepareTrainingAndTestTree( "", "", "nTest_cls1=1:nTest_cls2=1" );
267~~~
268
269## Split Expression
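270Events can be assigned to folds with an expression given through the `SplitExpr` option; this is only allowed together with `SplitType=Deterministic`. See the CvSplit documentation for details.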
271
272*/
273
274////////////////////////////////////////////////////////////////////////////////
275///
/// Constructs the cross-validation envelope for `dataloader`.
/// The Envelope base receives no output file; `outputFile` is kept in
/// fOutputFile instead. Members start from their defaults (NumFolds=2,
/// SplitType "Random", AnalysisType "Auto", ROC output on, the remaining flags
/// off), and InitOptions() then declares the configurable options.
/// NOTE(review): the first signature line and the last body lines are not
/// visible in this listing; these notes cover only the visible code.
276
278 TString options)
279 : TMVA::Envelope(jobName, dataloader, nullptr, options),
280 fAnalysisType(Types::kMaxAnalysisType),
281 fAnalysisTypeStr("Auto"),
282 fSplitTypeStr("Random"),
283 fCorrelations(kFALSE),
284 fCvFactoryOptions(""),
285 fDrawProgressBar(kFALSE),
286 fFoldFileOutput(kFALSE),
287 fFoldStatus(kFALSE),
288 fJobName(jobName),
289 fNumFolds(2),
290 fNumWorkerProcs(1),
291 fOutputFactoryOptions(""),
292 fOutputFile(outputFile),
293 fSilent(kFALSE),
294 fSplitExprString(""),
295 fROC(kTRUE),
296 fTransformations(""),
297 fVerbose(kFALSE),
298 fVerboseLevel(kINFO)
299{
300 InitOptions();
303}
304
305////////////////////////////////////////////////////////////////////////////////
306///
307
309 : CrossValidation(jobName, dataloader, nullptr, options)
310{
311}
312
313////////////////////////////////////////////////////////////////////////////////
314///
315
317
318////////////////////////////////////////////////////////////////////////////////
319///
/// Declares the configurable options of CrossValidation, with their help texts
/// and, where applicable, their allowed values (AddPreDefVal). The options
/// forwarded to the internal factories (Silent, V, VerboseLevel,
/// Transformations, DrawProgressBar, Correlations, ROC, AnalysisType) come
/// first, followed by the cross-validation specific ones.
320
322{
323 // Forwarding of Factory options
324 DeclareOptionRef(fSilent, "Silent",
325 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
326 "class object (default: False)");
327 DeclareOptionRef(fVerbose, "V", "Verbose flag");
328 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
329 AddPreDefVal(TString("Debug"));
330 AddPreDefVal(TString("Verbose"));
331 AddPreDefVal(TString("Info"));
332
333 DeclareOptionRef(fTransformations, "Transformations",
334 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
335 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
336 "transformations");
337
338 DeclareOptionRef(fDrawProgressBar, "DrawProgressBar", "Boolean to show draw progress bar");
339 DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation in output");
340 DeclareOptionRef(fROC, "ROC", "Boolean to show ROC in output");
341
// NOTE(review): `analysisType` is never used; the option is bound to
// fAnalysisTypeStr in the next statement.
342 TString analysisType("Auto");
343 DeclareOptionRef(fAnalysisTypeStr, "AnalysisType",
344 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
345 AddPreDefVal(TString("Classification"));
346 AddPreDefVal(TString("Regression"));
347 AddPreDefVal(TString("Multiclass"));
348 AddPreDefVal(TString("Auto"));
349
350 // Options specific to CE
351 DeclareOptionRef(fSplitTypeStr, "SplitType",
352 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
353 AddPreDefVal(TString("Deterministic"));
354 AddPreDefVal(TString("Random"));
355 AddPreDefVal(TString("RandomStratified"));
356
357 DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
358 DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
// NOTE(review): the help text says 1 means "no parallelisation", but
// Evaluate() treats NumWorkerProcs=1 as "use TMVA::gConfig().GetNumWorkers()".
359 DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
360 "Determines how many processes to use for evaluation. 1 means no"
361 " parallelisation. 2 means use 2 processes. 0 means figure out the"
362 " number automatically based on the number of cpus available. Default"
363 " 1.");
364
// NOTE(review): ProcessFold() names the per-fold file "<MethodTitle>_fold<N>.root"
// in the directory of the combined output file, not "<combined name>_foldX" as
// the help text says; the help text also has the typo "specifed".
365 DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
366 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
367 "specifed for the combined output with a _foldX suffix. (default: false)");
368
369 DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
370 "Combines output from contained methods. If None, no combination is performed. (default None)");
371 AddPreDefVal(TString("None"));
372 AddPreDefVal(TString("Avg"));
373}
374
375////////////////////////////////////////////////////////////////////////////////
376///
/// Validates the parsed options and turns them into the objects used by
/// Evaluate():
/// - builds fCvFactoryOptions for the per-fold factory (always ending in
///   !Correlations:!ROC:!Color:!DrawProgressBar:Silent) and
///   fOutputFactoryOptions for the combined output factory;
/// - creates fFoldFactory, and fFactory (attached to fOutputFile when one was
///   given);
/// - creates the fold splitter fSplit according to SplitType.
/// Logs kFATAL when SplitExpr is combined with a non-Deterministic SplitType,
/// or when FoldFileOutput is requested without an output file.
377
379{
381
382 if (fSplitTypeStr != "Deterministic" and fSplitExprString != "") {
383 Log() << kFATAL << "SplitExpr can only be used with Deterministic Splitting" << Endl;
384 }
385
386 // Factory options
// Map the (lower-cased) AnalysisType option onto Types::EAnalysisType;
// "auto" leaves it undetermined here (kNoAnalysisType).
387 fAnalysisTypeStr.ToLower();
388 if (fAnalysisTypeStr == "classification") {
389 fAnalysisType = Types::kClassification;
390 } else if (fAnalysisTypeStr == "regression") {
391 fAnalysisType = Types::kRegression;
392 } else if (fAnalysisTypeStr == "multiclass") {
393 fAnalysisType = Types::kMulticlass;
394 } else if (fAnalysisTypeStr == "auto") {
395 fAnalysisType = Types::kNoAnalysisType;
396 }
397
398 if (fVerbose) {
399 fCvFactoryOptions += "V:";
400 fOutputFactoryOptions += "V:";
401 } else {
402 fCvFactoryOptions += "!V:";
403 fOutputFactoryOptions += "!V:";
404 }
405
406 fCvFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
407 fOutputFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
408
409 fCvFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
410 fOutputFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
411
412 if (not fDrawProgressBar) {
413 fOutputFactoryOptions += "!DrawProgressBar:";
414 }
415
416 if (fTransformations != "") {
417 fCvFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
418 fOutputFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
419 }
420
421 if (fCorrelations) {
422 // fCvFactoryOptions += "Correlations:";
423 fOutputFactoryOptions += "Correlations:";
424 } else {
425 // fCvFactoryOptions += "!Correlations:";
426 fOutputFactoryOptions += "!Correlations:";
427 }
428
429 if (fROC) {
430 // fCvFactoryOptions += "ROC:";
431 fOutputFactoryOptions += "ROC:";
432 } else {
433 // fCvFactoryOptions += "!ROC:";
434 fOutputFactoryOptions += "!ROC:";
435 }
436
437 if (fSilent) {
438 // fCvFactoryOptions += Form("Silent:");
439 fOutputFactoryOptions += Form("Silent:");
440 }
441
// The per-fold factory is kept quiet regardless of the user's output settings.
442 fCvFactoryOptions += "!Correlations:!ROC:!Color:!DrawProgressBar:Silent";
443
444 // CE specific options
445 if (fFoldFileOutput and fOutputFile == nullptr) {
446 Log() << kFATAL << "No output file given, cannot generate per fold output." << Endl;
447 }
448
449 // Initialisations
450
451 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
452
453 // The fOutputFactory should always have !ModelPersistence set since we use a custom code path for this.
454 // In this case we create a special method (MethodCrossValidation) that can only be used by
455 // CrossValidation and the Reader.
456 if (fOutputFile == nullptr) {
457 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
458 } else {
459 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
460 }
461
// SplitType "Random" -> unstratified, "RandomStratified" -> stratified; any other
// SplitType (Deterministic) uses the two-argument constructor and its default.
462 if(fSplitTypeStr == "Random"){
463 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kFALSE));
464 } else if(fSplitTypeStr == "RandomStratified"){
465 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kTRUE));
466 } else {
467 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
468 }
469
470}
471
472//_______________________________________________________________________
474{
475 if (i != fNumFolds) {
476 fNumFolds = i;
477 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
478 fDataLoader->MakeKFoldDataSet(*fSplit);
479 fFoldStatus = kTRUE;
480 }
481}
482
483////////////////////////////////////////////////////////////////////////////////
484///
485
487{
488 if (splitExpr != fSplitExprString) {
489 fSplitExprString = splitExpr;
490 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
491 fDataLoader->MakeKFoldDataSet(*fSplit);
492 fFoldStatus = kTRUE;
493 }
494}
495
496////////////////////////////////////////////////////////////////////////////////
497/// Trains and evaluates a single fold of one booked method.
498/// - Prepares train and test data sets
499/// - Trains method
500/// - Evaluates on test set
501/// - Returns the evaluation results
502///
503/// @param iFold fold to evaluate
/// @param methodInfo options of the booked method; the keys "MethodName",
///                   "MethodTitle" and "MethodOptions" are read.
/// @return the metrics of this fold. ROC integral, ROC curve, significance and
///         separation are only filled for classification and multiclass
///         analyses; the efficiencies only for classification.
504///
505
507{
508 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
509 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
510 TString methodOptions = methodInfo.GetValue<TString>("MethodOptions");
511 TString foldTitle = methodTitle + TString("_fold") + TString::Format("%i", iFold + 1);
512
513 Log() << kDEBUG << "Processing " << methodTitle << " fold " << iFold << Endl;
514
515 // Only used if fFoldFileOutput == true
516 TFile *foldOutputFile = nullptr;
517
518 if (fFoldFileOutput and fOutputFile != nullptr) {
// Per-fold output goes to "<directory of fOutputFile>/<foldTitle>.root", written
// by a fresh fold factory attached to that file.
519 TString path = std::string("") + gSystem->DirName(fOutputFile->GetName()) + "/" + foldTitle + ".root";
// NOTE(review): leftover debug print to std::cout; consider Log() << kINFO.
520 std::cout << "PATH: " << path << std::endl;
521 foldOutputFile = TFile::Open(path, "RECREATE");
522 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
523 }
524
525 fDataLoader->PrepareFoldDataSet(*fSplit, iFold, TMVA::Types::kTraining);
526 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
527
528 // Train method (train method and eval train set)
530 smethod->TrainMethod();
532
533 fFoldFactory->TestAllMethods();
534 fFoldFactory->EvaluateAllMethods();
535
537
538 // Results for aggregation (ROC integral, efficiencies etc.)
// NOTE(review): with AnalysisType=Auto (the default), ParseOptions() sets
// fAnalysisType to kNoAnalysisType, so none of these metrics are filled.
539 if (fAnalysisType == Types::kClassification or fAnalysisType == Types::kMulticlass) {
540 result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
541
// NOTE(review): `gr` is copied into result.fROC below and never deleted here;
// confirm whether Factory::GetROCCurve hands ownership to the caller.
542 TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
543 gr->SetLineColor(iFold + 1);
544 gr->SetLineWidth(2);
545 gr->SetTitle(foldTitle.Data());
546 result.fROC = *gr;
547
548 result.fSig = smethod->GetSignificance();
549 result.fSep = smethod->GetSeparation();
550
551 if (fAnalysisType == Types::kClassification) {
552 Double_t err;
553 result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
554 result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
555 result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
556 result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
557 result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
558 result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
559 result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
560 } else if (fAnalysisType == Types::kMulticlass) {
561 // Nothing here for now
562 }
563 }
564
565 // Per-fold file output
// NOTE(review): the file is closed but the TFile object returned by
// TFile::Open is not deleted here; confirm ownership.
566 if (fFoldFileOutput and foldOutputFile != nullptr) {
567 foldOutputFile->Close();
568 }
569
570 // Clean-up for this fold
571 {
572 smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
573 smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
574 }
575
576 fFoldFactory->DeleteAllMethods();
577 fFoldFactory->fMethodsMap.clear();
578
579 return result;
580}
581
582////////////////////////////////////////////////////////////////////////////////
583/// Does training, test set evaluation and performance evaluation using
584/// cross-validation.
585///
/// For every booked method:
/// - the K folds are generated once, if that has not happened yet;
/// - each fold is processed with ProcessFold(), either sequentially or in
///   parallel with ROOT::TProcessExecutor when more than one worker is
///   configured;
/// - the aggregated CrossValidationResult is appended to fResults;
/// - a MethodCrossValidation wrapper is booked on fFactory.
/// Afterwards the data set is recombined, and the wrappers are trained, tested
/// and evaluated through fFactory.
586
588{
589 // Generate K folds on given dataset
590 if (!fFoldStatus) {
591 fDataLoader->MakeKFoldDataSet(*fSplit);
592 fFoldStatus = kTRUE;
593 }
594
595 fResults.reserve(fMethods.size());
596 for (auto & methodInfo : fMethods) {
597 CrossValidationResult result{fNumFolds};
598
599 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
600 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
601
602 if (methodTypeName == "") {
603 Log() << kFATAL << "No method booked for cross-validation" << Endl;
604 }
605
607 Log() << kINFO << "Evaluate method: " << methodTitle << Endl;
608
609 // Process K folds
// NumWorkerProcs=1 (the default) defers to the global TMVA::gConfig() worker count.
610 auto nWorkers = fNumWorkerProcs;
611 if (nWorkers == 1) {
612 // Fall back to global config
613 nWorkers = TMVA::gConfig().GetNumWorkers();
614 }
615 if (nWorkers == 1) {
616 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
617 auto fold_result = ProcessFold(iFold, methodInfo);
618 result.Fill(fold_result);
619 }
620 } else {
621 ROOT::TProcessExecutor workers(nWorkers);
622 std::vector<CrossValidationFoldResult> result_vector;
623
624 auto workItem = [this, methodInfo](UInt_t iFold) {
625 return ProcessFold(iFold, methodInfo);
626 };
627
628 result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
629
630 for (auto && fold_result : result_vector) {
631 result.Fill(fold_result);
632 }
633 }
634
635 fResults.push_back(result);
636
637 // Serialise the cross evaluated method
638 TString options =
639 Form("SplitExpr=%s:NumFolds=%i"
640 ":EncapsulatedMethodName=%s"
641 ":EncapsulatedMethodTypeName=%s"
642 ":OutputEnsembling=%s",
643 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());
644
645 fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);
646
647 // Feed EventToFold mapping used when random fold assignments are used
648 // (when splitExpr="").
649 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
650 auto *method = dynamic_cast<MethodCrossValidation *>(method_interface);
651
// NOTE(review): the dynamic_cast result is not checked; a null `method` would
// be dereferenced here.
652 method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
653 }
654
655 // Recombination of data (making sure there is data in training and testing trees).
656 fDataLoader->RecombineKFoldDataSet(*fSplit);
657
658 // "Eval" on training set
659 for (auto & methodInfo : fMethods) {
// NOTE(review): methodTypeName is not used in this loop.
660 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
661 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
662
663 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
664 auto method = dynamic_cast<MethodCrossValidation *>(method_interface);
665
// NOTE(review): `method` is dereferenced below without a null check.
666 if (fOutputFile != nullptr) {
667 fFactory->WriteDataInformation(method->fDataSetInfo);
668 }
669
671 method->TrainMethod();
673 }
674
675 // Eval on Testing set
676 fFactory->TestAllMethods();
677
678 // Calc statistics
679 fFactory->EvaluateAllMethods();
680
681 Log() << kINFO << "Evaluation done." << Endl;
682}
683
684//_______________________________________________________________________
685const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
686{
687 if (fResults.empty()) {
688 Log() << kFATAL << "No cross-validation results available" << Endl;
689 }
690 return fResults;
691}
#define c(i)
Definition: RSha256.hxx:101
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
This class provides a simple interface to execute the same task multiple times in parallel,...
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute func (with no arguments) nTimes in parallel.
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition: TAttLine.h:43
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition: TAttLine.h:40
The Canvas class.
Definition: TCanvas.h:31
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
Definition: TCollection.h:182
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
virtual void Close(Option_t *option="")
Close a file.
Definition: TFile.cxx:912
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3975
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
virtual void SetTitle(const char *title="")
Set graph title.
Definition: TGraph.cxx:2232
virtual Double_t Eval(Double_t x, TSpline *spline=0, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
Definition: TGraph.cxx:865
This class displays a legend box (TPaveText) containing several legend entries.
Definition: TLegend.h:23
A doubly linked list.
Definition: TList.h:44
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:354
UInt_t GetNumWorkers() const
Definition: Config.h:78
void SetSilent(Bool_t s)
Definition: Config.h:69
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation; the metric for the classification is ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void SetNumFolds(UInt_t i)
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation using cross-validation.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Definition: DataSet.cxx:343
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition: Envelope.h:44
virtual void ParseOptions()
Method to parse the internal option string.
Definition: Envelope.cxx:187
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:392
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:428
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
DataSet * Data() const
Definition: MethodBase.h:400
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
static void EnableOutput()
Definition: MsgLogger.cxx:75
Class to store options for the different methods
Definition: OptionMap.h:36
T GetValue(const TString &key)
Definition: OptionMap.h:145
Singleton class for Global types used by TMVA.
Definition: Types.h:73
@ kCrossValidation
Definition: Types.h:110
@ kMulticlass
Definition: Types.h:130
@ kNoAnalysisType
Definition: Types.h:131
@ kClassification
Definition: Types.h:128
@ kRegression
Definition: Types.h:129
@ kTraining
Definition: Types.h:144
@ kTesting
Definition: Types.h:145
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition: TNamed.cxx:74
Basic string class.
Definition: TString.h:131
const char * Data() const
Definition: TString.h:364
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2286
virtual const char * DirName(const char *pathname)
Return the directory name in pathname.
Definition: TSystem.cxx:1013
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
TGraphErrors * gr
Definition: legend1.C:25
leg
Definition: legend1.C:34
Abstract ClassifierFactory template that handles arbitrary types.
Config & gConfig()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:748
Double_t Sqrt(Double_t x)
Definition: TMath.h:679
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Definition: TMath.h:723
STL namespace.