Logo ROOT  
Reference Guide
CrossValidation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou
3// Modified: Kim Albertsson 2017
4
5/*************************************************************************
6 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
7 * All rights reserved. *
8 * *
9 * For the licensing terms see $ROOTSYS/LICENSE. *
10 * For the list of contributors see $ROOTSYS/README/CREDITS. *
11 *************************************************************************/
12
14
16#include "TMVA/Config.h"
17#include "TMVA/CvSplit.h"
18#include "TMVA/DataSet.h"
19#include "TMVA/Event.h"
20#include "TMVA/MethodBase.h"
22#include "TMVA/MsgLogger.h"
25#include "TMVA/ROCCurve.h"
26#include "TMVA/tmvaglob.h"
27#include "TMVA/Types.h"
28
29#include "TSystem.h"
30#include "TAxis.h"
31#include "TCanvas.h"
32#include "TGraph.h"
33#include "TMath.h"
34
35#include "ROOT/RMakeUnique.hxx"
36
37#include <iostream>
38#include <memory>
39
40//_______________________________________________________________________
42:fROCCurves(new TMultiGraph())
43{
44 fSigs.resize(numFolds);
45 fSeps.resize(numFolds);
46 fEff01s.resize(numFolds);
47 fEff10s.resize(numFolds);
48 fEff30s.resize(numFolds);
49 fEffAreas.resize(numFolds);
50 fTrainEff01s.resize(numFolds);
51 fTrainEff10s.resize(numFolds);
52 fTrainEff30s.resize(numFolds);
53}
54
55//_______________________________________________________________________
57{
58 fROCs=obj.fROCs;
59 fROCCurves = obj.fROCCurves;
60
61 fSigs = obj.fSigs;
62 fSeps = obj.fSeps;
63 fEff01s = obj.fEff01s;
64 fEff10s = obj.fEff10s;
65 fEff30s = obj.fEff30s;
66 fEffAreas = obj.fEffAreas;
67 fTrainEff01s = obj.fTrainEff01s;
68 fTrainEff10s = obj.fTrainEff10s;
69 fTrainEff30s = obj.fTrainEff30s;
70}
71
72//_______________________________________________________________________
74{
75 UInt_t iFold = fr.fFold;
76
77 fROCs[iFold] = fr.fROCIntegral;
78 fROCCurves->Add(dynamic_cast<TGraph *>(fr.fROC.Clone()));
79
80 fSigs[iFold] = fr.fSig;
81 fSeps[iFold] = fr.fSep;
82 fEff01s[iFold] = fr.fEff01;
83 fEff10s[iFold] = fr.fEff10;
84 fEff30s[iFold] = fr.fEff30;
85 fEffAreas[iFold] = fr.fEffArea;
86 fTrainEff01s[iFold] = fr.fTrainEff01;
87 fTrainEff10s[iFold] = fr.fTrainEff10;
88 fTrainEff30s[iFold] = fr.fTrainEff30;
89}
90
91//_______________________________________________________________________
93{
94 return fROCCurves.get();
95}
96
97////////////////////////////////////////////////////////////////////////////////
98/// \brief Generates a multigraph that contains an average ROC Curve.
99///
100/// \note You own the returned pointer.
101///
102/// \param numSamples[in] Number of samples used for generating the average ROC
103/// Curve. Avg. curve will be evaluated only at these
104/// points (using interpolation if necessary).
105///
106
108{
109 // `numSamples * increment` should equal 1.0!
110 Double_t increment = 1.0 / (numSamples-1);
111 std::vector<Double_t> x(numSamples), y(numSamples);
112
113 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
114
115 for(UInt_t iSample = 0; iSample < numSamples; iSample++) {
116 Double_t xPoint = iSample * increment;
117 Double_t rocSum = 0;
118
119 for(Int_t iGraph = 0; iGraph < rocCurveList->GetSize(); iGraph++) {
120 TGraph *foldROC = static_cast<TGraph *>(rocCurveList->At(iGraph));
121 rocSum += foldROC->Eval(xPoint);
122 }
123
124 x[iSample] = xPoint;
125 y[iSample] = rocSum/rocCurveList->GetSize();
126 }
127
128 return new TGraph(numSamples, &x[0], &y[0]);
129}
130
131//_______________________________________________________________________
133{
134 Float_t avg=0;
135 for(auto &roc : fROCs) {
136 avg+=roc.second;
137 }
138 return avg/fROCs.size();
139}
140
141//_______________________________________________________________________
143{
144 // NOTE: We are using here the unbiased estimation of the standard deviation.
145 Float_t std=0;
146 Float_t avg=GetROCAverage();
147 for(auto &roc : fROCs) {
148 std+=TMath::Power(roc.second-avg, 2);
149 }
150 return TMath::Sqrt(std/float(fROCs.size()-1.0));
151}
152
153//_______________________________________________________________________
155{
// Prints a summary of the per-fold ROC integrals through a local MsgLogger
// named "CrossValidation": one line per fold, followed by the average and the
// (unbiased) standard deviation over all folds.
158
159 MsgLogger fLogger("CrossValidation");
160 fLogger << kHEADER << " ==== Results ====" << Endl;
161 for(auto &item:fROCs) {
// item.first is the fold index, item.second that fold's ROC integral.
// NOTE(review): this line is terminated with std::endl instead of Endl like the
// others; presumably the text then stays in the logger's stream buffer until
// the next `<< Endl` — confirm and consider using Endl for consistency.
162 fLogger << kINFO << Form("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
163 }
164
165 fLogger << kINFO << "------------------------" << Endl;
166 fLogger << kINFO << Form("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
167 fLogger << kINFO << Form("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
168
170}
171
172//_______________________________________________________________________
174{
175 auto *c = new TCanvas(name.Data());
176 fROCCurves->Draw("AL");
177 fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
178 fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
179 Float_t adjust=1+fROCs.size()*0.01;
180 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
181 c->SetTitle("Cross Validation ROC Curves");
182 c->Draw();
183 return c;
184}
185
186//
188{
189 TMultiGraph rocs{};
190
191 // Potentially add the folds
192 if (drawFolds) {
193 for (auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
194 TGraph * foldRocGraph = dynamic_cast<TGraph *>(foldRocObj->Clone());
195 foldRocGraph->SetLineColor(1);
196 foldRocGraph->SetLineWidth(1);
197 rocs.Add(foldRocGraph);
198 }
199 }
200
201 // Add the average roc curve
202 TGraph *avgRocGraph = GetAvgROCCurve(100);
203 avgRocGraph->SetTitle("Avg ROC Curve");
204 avgRocGraph->SetLineColor(2);
205 avgRocGraph->SetLineWidth(3);
206 rocs.Add(avgRocGraph);
207
208 // Draw
209 TCanvas *c = new TCanvas();
210
211 if (title != "") {
212 title = "Cross Validation Average ROC Curve";
213 }
214
215 rocs.SetTitle(title);
216 rocs.GetXaxis()->SetTitle("Signal Efficiency");
217 rocs.GetYaxis()->SetTitle("Background Rejection");
218 rocs.DrawClone("AL");
219
220 // Build legend
221 TLegend *leg = new TLegend();
222 TList *ROCCurveList = rocs.GetListOfGraphs();
223
224 if (drawFolds) {
225 Int_t nCurves = ROCCurveList->GetSize();
226 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(nCurves-1)),
227 "Avg ROC Curve", "l");
228 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(0)),
229 "Fold ROC Curves", "l");
230 leg->Draw();
231 } else {
232 c->BuildLegend();
233 }
234
235 // Draw Canvas
236 c->SetTitle("Cross Validation Average ROC Curve");
237 c->Draw();
238 return c;
239}
240
241/**
242* \class TMVA::CrossValidation
243* \ingroup TMVA
* \brief Class to perform cross validation, splitting the dataloader into folds.

Typical usage (book the methods to cross-validate, then run the evaluation):
248
249
250~~~{.cpp}
251ce->BookMethod(dataloader, options);
252ce->Evaluate();
253~~~
254
Cross-validation generates new training and test sets dynamically from
`K` folds. These `K` folds are generated by splitting the input training
set. The input test set is currently ignored.
258
259This means that when you specify your DataSet you should include all events
260in your training set. One way of doing this would be the following:
261
262~~~{.cpp}
263dataloader->AddTree( signalTree, "cls1" );
264dataloader->AddTree( background, "cls2" );
265dataloader->PrepareTrainingAndTestTree( "", "", "nTest_cls1=1:nTest_cls2=1" );
266~~~
267
268## Split Expression
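The assignment of events to folds can be controlled with the `SplitExpr`
option, which is only accepted together with `SplitType=Deterministic`.
See the CvSplit documentation for details.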
270
271*/
272
273////////////////////////////////////////////////////////////////////////////////
274///
/// Constructs a cross-validation envelope.
///
/// \param jobName    Job name; forwarded to TMVA::Envelope and stored in fJobName.
/// \param dataloader Data loader whose events are split into folds; forwarded
///                   to TMVA::Envelope.
/// \param outputFile Output file for the combined results (may be nullptr). It is
///                   stored in fOutputFile only; TMVA::Envelope receives nullptr.
/// \param options    Option string; forwarded to TMVA::Envelope. The supported
///                   options are declared in InitOptions().
///
/// Defaults before option parsing: AnalysisType "Auto", SplitType "Random",
/// NumFolds 2, NumWorkerProcs 1, ROC enabled, all other flags disabled.
275
277 TString options)
278 : TMVA::Envelope(jobName, dataloader, nullptr, options),
// Placeholder; the real analysis type is derived from AnalysisType in ParseOptions().
279 fAnalysisType(Types::kMaxAnalysisType),
280 fAnalysisTypeStr("Auto"),
281 fSplitTypeStr("Random"),
282 fCorrelations(kFALSE),
283 fCvFactoryOptions(""),
284 fDrawProgressBar(kFALSE),
285 fFoldFileOutput(kFALSE),
286 fFoldStatus(kFALSE),
287 fJobName(jobName),
288 fNumFolds(2),
289 fNumWorkerProcs(1),
290 fOutputFactoryOptions(""),
291 fOutputFile(outputFile),
292 fSilent(kFALSE),
293 fSplitExprString(""),
294 fROC(kTRUE),
295 fTransformations(""),
296 fVerbose(kFALSE),
// NOTE(review): fVerboseLevel is used as a string elsewhere (.Data() in ParseOptions,
// assigned TString("Info") in InitOptions) but is initialised from the enum value
// kINFO here; InitOptions() overwrites it anyway — confirm the intended initialiser.
297 fVerboseLevel(kINFO)
298{
// Declare all options understood by CrossValidation.
299 InitOptions();
302}
303
304////////////////////////////////////////////////////////////////////////////////
305///
306
308 : CrossValidation(jobName, dataloader, nullptr, options)
309{
310}
311
312////////////////////////////////////////////////////////////////////////////////
313///
314
316
317////////////////////////////////////////////////////////////////////////////////
318///
319
321{
322 // Forwarding of Factory options
323 DeclareOptionRef(fSilent, "Silent",
324 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
325 "class object (default: False)");
326 DeclareOptionRef(fVerbose, "V", "Verbose flag");
327 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
328 AddPreDefVal(TString("Debug"));
329 AddPreDefVal(TString("Verbose"));
330 AddPreDefVal(TString("Info"));
331
332 DeclareOptionRef(fTransformations, "Transformations",
333 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
334 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
335 "transformations");
336
337 DeclareOptionRef(fDrawProgressBar, "DrawProgressBar", "Boolean to show draw progress bar");
338 DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation in output");
339 DeclareOptionRef(fROC, "ROC", "Boolean to show ROC in output");
340
341 TString analysisType("Auto");
342 DeclareOptionRef(fAnalysisTypeStr, "AnalysisType",
343 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
344 AddPreDefVal(TString("Classification"));
345 AddPreDefVal(TString("Regression"));
346 AddPreDefVal(TString("Multiclass"));
347 AddPreDefVal(TString("Auto"));
348
349 // Options specific to CE
350 DeclareOptionRef(fSplitTypeStr, "SplitType",
351 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
352 AddPreDefVal(TString("Deterministic"));
353 AddPreDefVal(TString("Random"));
354 AddPreDefVal(TString("RandomStratified"));
355
356 DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
357 DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
358 DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
359 "Determines how many processes to use for evaluation. 1 means no"
360 " parallelisation. 2 means use 2 processes. 0 means figure out the"
361 " number automatically based on the number of cpus available. Default"
362 " 1.");
363
364 DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
365 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
366 "specifed for the combined output with a _foldX suffix. (default: false)");
367
368 DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
369 "Combines output from contained methods. If None, no combination is performed. (default None)");
370 AddPreDefVal(TString("None"));
371 AddPreDefVal(TString("Avg"));
372}
373
374////////////////////////////////////////////////////////////////////////////////
375///
/// Interprets the parsed options and sets up the internal state:
/// - rejects SplitExpr unless SplitType is "Deterministic" (fatal),
/// - maps AnalysisType onto fAnalysisType,
/// - assembles the option strings of the per-fold factory (fCvFactoryOptions)
///   and of the output factory (fOutputFactoryOptions),
/// - creates fFoldFactory, fFactory (attached to fOutputFile when one was
///   given) and the fold splitter fSplit.
376
378{
380
// SplitExpr defines the fold assignment itself, so it is only accepted for deterministic splitting.
381 if (fSplitTypeStr != "Deterministic" && fSplitExprString != "") {
382 Log() << kFATAL << "SplitExpr can only be used with Deterministic Splitting" << Endl;
383 }
384
385 // Factory options
// The lower-cased string is also what gets forwarded as "AnalysisType=" below.
386 fAnalysisTypeStr.ToLower();
387 if (fAnalysisTypeStr == "classification") {
388 fAnalysisType = Types::kClassification;
389 } else if (fAnalysisTypeStr == "regression") {
390 fAnalysisType = Types::kRegression;
391 } else if (fAnalysisTypeStr == "multiclass") {
392 fAnalysisType = Types::kMulticlass;
393 } else if (fAnalysisTypeStr == "auto") {
394 fAnalysisType = Types::kNoAnalysisType;
395 }
396
// Both option strings are built in lock-step; every entry is terminated by ':'.
397 if (fVerbose) {
398 fCvFactoryOptions += "V:";
399 fOutputFactoryOptions += "V:";
400 } else {
401 fCvFactoryOptions += "!V:";
402 fOutputFactoryOptions += "!V:";
403 }
404
405 fCvFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
406 fOutputFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
407
408 fCvFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
409 fOutputFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
410
411 if (!fDrawProgressBar) {
412 fCvFactoryOptions += "!DrawProgressBar:";
413 fOutputFactoryOptions += "!DrawProgressBar:";
414 }
415
416 if (fTransformations != "") {
417 fCvFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
418 fOutputFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
419 }
420
421 if (fCorrelations) {
422 fCvFactoryOptions += "Correlations:";
423 fOutputFactoryOptions += "Correlations:";
424 } else {
425 fCvFactoryOptions += "!Correlations:";
426 fOutputFactoryOptions += "!Correlations:";
427 }
428
429 if (fROC) {
430 fCvFactoryOptions += "ROC:";
431 fOutputFactoryOptions += "ROC:";
432 } else {
433 fCvFactoryOptions += "!ROC:";
434 fOutputFactoryOptions += "!ROC:";
435 }
436
437 if (fSilent) {
438 fCvFactoryOptions += Form("Silent:");
439 fOutputFactoryOptions += Form("Silent:");
440 }
441
442 // CE specific options
// Per-fold files are written next to fOutputFile, so FoldFileOutput needs an output file.
443 if (fFoldFileOutput && fOutputFile == nullptr) {
444 Log() << kFATAL << "No output file given, cannot generate per fold output." << Endl;
445 }
446
447 // Initialisations
448
// Factory used to train and test the method on each individual fold.
449 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
450
451 // The output factory (fFactory) should always have !ModelPersistence set since we use a custom code path for this.
452 // In this case we create a special method (MethodCrossValidation) that can only be used by
453 // CrossValidation and the Reader.
// NOTE(review): neither option string assembled above contains "!ModelPersistence";
// confirm whether it should be appended to fOutputFactoryOptions as this comment states.
454 if (fOutputFile == nullptr) {
455 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
456 } else {
457 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
458 }
459
// "Random" passes kFALSE and "RandomStratified" kTRUE as the third constructor
// argument; any other split type (i.e. "Deterministic") relies on the
// CvSplitKFolds default for that argument.
460 if(fSplitTypeStr == "Random"){
461 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kFALSE));
462 } else if(fSplitTypeStr == "RandomStratified"){
463 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kTRUE));
464 } else {
465 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
466 }
467
468}
469
470//_______________________________________________________________________
472{
473 if (i != fNumFolds) {
474 fNumFolds = i;
475 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
476 fDataLoader->MakeKFoldDataSet(*fSplit);
477 fFoldStatus = kTRUE;
478 }
479}
480
481////////////////////////////////////////////////////////////////////////////////
482///
483
485{
486 if (splitExpr != fSplitExprString) {
487 fSplitExprString = splitExpr;
488 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
489 fDataLoader->MakeKFoldDataSet(*fSplit);
490 fFoldStatus = kTRUE;
491 }
492}
493
494////////////////////////////////////////////////////////////////////////////////
495/// Processes a single fold of one booked method:
496/// - Prepares train and test data sets
497/// - Trains method
498/// - Evaluates on test set
499/// - Collects the figures of merit into the returned result
500///
501/// @param iFold fold to evaluate
/// @param methodInfo booking information of the method; the keys "MethodName",
///                   "MethodTitle" and "MethodOptions" are read.
/// @return the per-fold result (ROC integral and curve, significance, separation
///         and, for classification, test/training efficiencies).
502///
503
505{
506 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
507 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
508 TString methodOptions = methodInfo.GetValue<TString>("MethodOptions");
// Fold-specific method title, e.g. "<MethodTitle>_fold1" (1-based fold number).
509 TString foldTitle = methodTitle + TString("_fold") + TString::Format("%i", iFold + 1);
510
511 Log() << kDEBUG << "Processing " << methodTitle << " fold " << iFold << Endl;
512
513 // Only used if fFoldFileOutput == true
514 TFile *foldOutputFile = nullptr;
515
// The fold file is created in the directory of fOutputFile as "<foldTitle>.root",
// and a fresh per-fold factory writing into it replaces fFoldFactory.
516 if (fFoldFileOutput && fOutputFile != nullptr) {
517 TString path = std::string("") + gSystem->DirName(fOutputFile->GetName()) + "/" + foldTitle + ".root";
518 foldOutputFile = TFile::Open(path, "RECREATE");
519 Log() << kINFO << "Creating fold output at:" << path << Endl;
520 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
521 }
522
// Select the training part of fold iFold, then book the method under its fold title.
523 fDataLoader->PrepareFoldDataSet(*fSplit, iFold, TMVA::Types::kTraining);
524 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
525
526 // Train method (train method and eval train set)
528 smethod->TrainMethod();
530
531 fFoldFactory->TestAllMethods();
532 fFoldFactory->EvaluateAllMethods();
533
535
536 // Results for aggregation (ROC integral, efficiencies etc.)
// Nothing is collected for any other analysis type (e.g. regression).
// NOTE(review): with AnalysisType=Auto (the default), ParseOptions() sets
// fAnalysisType to kNoAnalysisType, so this block is skipped as well — confirm
// whether the method's resolved type (smethod->GetAnalysisType()) should be used.
537 if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass) {
538 result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
539
// NOTE(review): `gr` is copied into result.fROC and never deleted, and a null
// return would be dereferenced. If Factory::GetROCCurve returns a caller-owned
// graph this leaks once per fold — confirm and hold it in a std::unique_ptr.
540 TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
541 gr->SetLineColor(iFold + 1);
542 gr->SetLineWidth(2);
543 gr->SetTitle(foldTitle.Data());
544 result.fROC = *gr;
545
546 result.fSig = smethod->GetSignificance();
547 result.fSep = smethod->GetSeparation();
548
549 if (fAnalysisType == Types::kClassification) {
// `err` receives the efficiency uncertainty; it is not stored.
550 Double_t err;
551 result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
552 result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
553 result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
554 result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
555 result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
556 result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
557 result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
558 } else if (fAnalysisType == Types::kMulticlass) {
559 // Nothing here for now
560 }
561 }
562
563 // Per-fold file output
// NOTE(review): the file is closed but never deleted, and fFoldFactory still refers
// to it after this call — confirm the intended ownership of foldOutputFile.
564 if (fFoldFileOutput && foldOutputFile != nullptr) {
565 foldOutputFile->Close();
566 }
567
568 // Clean-up for this fold
569 {
570 smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
571 smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
572 }
573
// Drop the booked method(s) from the per-fold factory before the next call.
574 fFoldFactory->DeleteAllMethods();
575 fFoldFactory->fMethodsMap.clear();
576
577 return result;
578}
579
580////////////////////////////////////////////////////////////////////////////////
581/// Does training, test set evaluation and performance evaluation using
582/// cross-validation.
583///
/// For every booked method, all folds are processed (serially, or with a
/// ROOT::TProcessExecutor when more than one worker is configured), the
/// per-fold results are aggregated into a CrossValidationResult appended to
/// fResults, and a MethodCrossValidation wrapping the fold models is booked on
/// the output factory. Afterwards the data set is recombined and the output
/// factory trains, tests and evaluates these wrapper methods.
584
586{
587 // Generate K folds on given dataset
588 if (!fFoldStatus) {
589 fDataLoader->MakeKFoldDataSet(*fSplit);
590 fFoldStatus = kTRUE;
591 }
592
593 fResults.reserve(fMethods.size());
594 for (auto & methodInfo : fMethods) {
595 CrossValidationResult result{fNumFolds};
596
597 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
598 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
599
600 if (methodTypeName == "") {
601 Log() << kFATAL << "No method booked for cross-validation" << Endl;
602 }
603
605 Log() << kINFO << Endl;
606 Log() << kINFO << Endl;
607 Log() << kINFO << "========================================" << Endl;
608 Log() << kINFO << "Processing folds for method " << methodTitle << Endl;
609 Log() << kINFO << "========================================" << Endl;
610 Log() << kINFO << Endl;
611
612 // Process K folds
// NumWorkerProcs=1 (the default) defers to TMVA::gConfig(); folds are processed
// serially only if that also yields 1 worker.
613 auto nWorkers = fNumWorkerProcs;
614 if (nWorkers == 1) {
615 // Fall back to global config
616 nWorkers = TMVA::gConfig().GetNumWorkers();
617 }
618 if (nWorkers == 1) {
619 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
620 auto fold_result = ProcessFold(iFold, methodInfo);
621 result.Fill(fold_result);
622 }
623 } else {
// NOTE(review): under MSVC this whole branch is compiled out, so with nWorkers != 1
// no fold is processed and an empty `result` is stored; consider running the
// serial loop above in that case.
624#ifndef _MSC_VER
625 ROOT::TProcessExecutor workers(nWorkers);
626 std::vector<CrossValidationFoldResult> result_vector;
627
// The lambda captures `this` and a copy of methodInfo; each work item processes one fold.
628 auto workItem = [this, methodInfo](UInt_t iFold) {
629 return ProcessFold(iFold, methodInfo);
630 };
631
632 result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
633
634 for (auto && fold_result : result_vector) {
635 result.Fill(fold_result);
636 }
637#endif
638 }
639
640 fResults.push_back(result);
641
642 // Serialise the cross evaluated method
643 TString options =
644 Form("SplitExpr=%s:NumFolds=%i"
645 ":EncapsulatedMethodName=%s"
646 ":EncapsulatedMethodTypeName=%s"
647 ":OutputEnsembling=%s",
648 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());
649
650 fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);
651
652 // Feed EventToFold mapping used when random fold assignments are used
653 // (when splitExpr="").
654 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
655 auto *method = dynamic_cast<MethodCrossValidation *>(method_interface);
656
// NOTE(review): `method` is dereferenced without a nullptr check.
657 method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
658 }
659
660 Log() << kINFO << Endl;
661 Log() << kINFO << Endl;
662 Log() << kINFO << "========================================" << Endl;
663 Log() << kINFO << "Folds processed for all methods, evaluating." << Endl;
664 Log() << kINFO << "========================================" << Endl;
665 Log() << kINFO << Endl;
666
667 // Recombination of data (making sure there is data in training and testing trees).
668 fDataLoader->RecombineKFoldDataSet(*fSplit);
669
670 // "Eval" on training set
671 for (auto & methodInfo : fMethods) {
// methodTypeName is not used in this loop.
672 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
673 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
674
675 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
676 auto method = dynamic_cast<MethodCrossValidation *>(method_interface);
677
// NOTE(review): `method` is dereferenced without a nullptr check.
678 if (fOutputFile != nullptr) {
679 fFactory->WriteDataInformation(method->fDataSetInfo);
680 }
681
683 method->TrainMethod();
685 }
686
687 // Eval on Testing set
688 fFactory->TestAllMethods();
689
690 // Calc statistics
691 fFactory->EvaluateAllMethods();
692
693 Log() << kINFO << "Evaluation done." << Endl;
694}
695
696//_______________________________________________________________________
697const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
698{
699 if (fResults.empty()) {
700 Log() << kFATAL << "No cross-validation results available" << Endl;
701 }
702 return fResults;
703}
#define c(i)
Definition: RSha256.hxx:101
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
char name[80]
Definition: TGX11.cxx:109
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:560
This class provides a simple interface to execute the same task multiple times in parallel,...
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute func (with no arguments) nTimes in parallel.
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition: TAttLine.h:43
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition: TAttLine.h:40
The Canvas class.
Definition: TCanvas.h:31
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
Definition: TCollection.h:182
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3923
void Close(Option_t *option="") override
Close a file.
Definition: TFile.cxx:856
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
virtual void SetTitle(const char *title="")
Change (i.e.
Definition: TGraph.cxx:2312
virtual Double_t Eval(Double_t x, TSpline *spline=0, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
Definition: TGraph.cxx:871
This class displays a legend box (TPaveText) containing several legend entries.
Definition: TLegend.h:23
A doubly linked list.
Definition: TList.h:44
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition: TList.cxx:354
UInt_t GetNumWorkers() const
Definition: Config.h:74
void SetSilent(Bool_t s)
Definition: Config.h:65
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation, the metric for the classification ins ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void SetNumFolds(UInt_t i)
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation of using cross-evalution.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Definition: DataSet.cxx:343
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition: Envelope.h:47
virtual void ParseOptions()
Method to parse the internal option string.
Definition: Envelope.cxx:187
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:392
Interface for all concrete MVA method implementations.
Definition: IMethod.h:54
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:437
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
DataSet * Data() const
Definition: MethodBase.h:408
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
static void EnableOutput()
Definition: MsgLogger.cxx:75
class to storage options for the differents methods
Definition: OptionMap.h:36
T GetValue(const TString &key)
Definition: OptionMap.h:135
Singleton class for Global types used by TMVA.
Definition: Types.h:73
@ kCrossValidation
Definition: Types.h:110
@ kMulticlass
Definition: Types.h:130
@ kNoAnalysisType
Definition: Types.h:131
@ kClassification
Definition: Types.h:128
@ kRegression
Definition: Types.h:129
@ kTraining
Definition: Types.h:144
@ kTesting
Definition: Types.h:145
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition: TNamed.cxx:74
Basic string class.
Definition: TString.h:131
const char * Data() const
Definition: TString.h:364
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2311
virtual const char * DirName(const char *pathname)
Return the directory name in pathname.
Definition: TSystem.cxx:1014
Double_t y[n]
Definition: legend1.C:17
Double_t x[n]
Definition: legend1.C:17
TGraphErrors * gr
Definition: legend1.C:25
leg
Definition: legend1.C:34
create variable transformations
Config & gConfig()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
Double_t Sqrt(Double_t x)
Definition: TMath.h:681
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Definition: TMath.h:725