Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CrossValidation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou
3// Modified: Kim Albertsson 2017
4
5/*************************************************************************
6 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
7 * All rights reserved. *
8 * *
9 * For the licensing terms see $ROOTSYS/LICENSE. *
10 * For the list of contributors see $ROOTSYS/README/CREDITS. *
11 *************************************************************************/
12
14
16#include "TMVA/Config.h"
17#include "TMVA/CvSplit.h"
18#include "TMVA/DataSet.h"
19#include "TMVA/Event.h"
20#include "TMVA/MethodBase.h"
22#include "TMVA/MsgLogger.h"
25#include "TMVA/ROCCurve.h"
26#include "TMVA/Types.h"
27
28#include "TSystem.h"
29#include "TAxis.h"
30#include "TCanvas.h"
31#include "TGraph.h"
32#include "TLegend.h"
33#include "TMath.h"
34
35#include <iostream>
36#include <memory>
37
38//_______________________________________________________________________
40:fROCCurves(new TMultiGraph())
41{
42 fSigs.resize(numFolds);
43 fSeps.resize(numFolds);
44 fEff01s.resize(numFolds);
45 fEff10s.resize(numFolds);
46 fEff30s.resize(numFolds);
47 fEffAreas.resize(numFolds);
48 fTrainEff01s.resize(numFolds);
49 fTrainEff10s.resize(numFolds);
50 fTrainEff30s.resize(numFolds);
51}
52
53//_______________________________________________________________________
// Presumably the CrossValidationResult copy constructor (its signature line is
// missing from this copy of the file): assigns every member from `obj`.
55{
56 fROCs=obj.fROCs;
 // NOTE(review): fROCCurves is a std::shared_ptr<TMultiGraph>, so the copy shares
 // the same multigraph with `obj` instead of deep-copying the curves; a later
 // Fill() on either object adds its curve to both.
57 fROCCurves = obj.fROCCurves;
58
 // Per-fold metric vectors are copied by value.
59 fSigs = obj.fSigs;
60 fSeps = obj.fSeps;
61 fEff01s = obj.fEff01s;
62 fEff10s = obj.fEff10s;
63 fEff30s = obj.fEff30s;
64 fEffAreas = obj.fEffAreas;
65 fTrainEff01s = obj.fTrainEff01s;
66 fTrainEff10s = obj.fTrainEff10s;
67 fTrainEff30s = obj.fTrainEff30s;
68}
69
70//_______________________________________________________________________
72{
73 UInt_t iFold = fr.fFold;
74
75 fROCs[iFold] = fr.fROCIntegral;
76 fROCCurves->Add(dynamic_cast<TGraph *>(fr.fROC.Clone()));
77
78 fSigs[iFold] = fr.fSig;
79 fSeps[iFold] = fr.fSep;
80 fEff01s[iFold] = fr.fEff01;
81 fEff10s[iFold] = fr.fEff10;
82 fEff30s[iFold] = fr.fEff30;
83 fEffAreas[iFold] = fr.fEffArea;
84 fTrainEff01s[iFold] = fr.fTrainEff01;
85 fTrainEff10s[iFold] = fr.fTrainEff10;
86 fTrainEff30s[iFold] = fr.fTrainEff30;
87}
88
89//_______________________________________________________________________
91{
92 return fROCCurves.get();
93}
94
95////////////////////////////////////////////////////////////////////////////////
96/// \brief Generates a multigraph that contains an average ROC Curve.
97///
98/// \note You own the returned pointer.
99///
100/// \param[in] numSamples Number of samples used for generating the average ROC
101/// Curve. Avg. curve will be evaluated only at these
102/// points (using interpolation if necessary).
103///
104
106{
107 // `numSamples * increment` should equal 1.0!
108 Double_t increment = 1.0 / (numSamples-1);
109 std::vector<Double_t> x(numSamples), y(numSamples);
110
111 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
112
113 for(UInt_t iSample = 0; iSample < numSamples; iSample++) {
114 Double_t xPoint = iSample * increment;
115 Double_t rocSum = 0;
116
117 for(Int_t iGraph = 0; iGraph < rocCurveList->GetSize(); iGraph++) {
118 TGraph *foldROC = static_cast<TGraph *>(rocCurveList->At(iGraph));
119 rocSum += foldROC->Eval(xPoint);
120 }
121
122 x[iSample] = xPoint;
123 y[iSample] = rocSum/rocCurveList->GetSize();
124 }
125
126 return new TGraph(numSamples, &x[0], &y[0]);
127}
128
129//_______________________________________________________________________
// Returns the arithmetic mean of the per-fold ROC integrals stored in fROCs.
// NOTE(review): if no fold has been filled, fROCs is empty and this evaluates
// 0.f / 0 (NaN).
131{
132 Float_t avg=0;
133 for(auto &roc : fROCs) {
134 avg+=roc.second;
135 }
136 return avg/fROCs.size();
137}
138
139//_______________________________________________________________________
141{
142 // NOTE: We are using here the unbiased estimation of the standard deviation.
143 Float_t std=0;
144 Float_t avg=GetROCAverage();
145 for(auto &roc : fROCs) {
146 std+=TMath::Power(roc.second-avg, 2);
147 }
148 return TMath::Sqrt(std/float(fROCs.size()-1.0));
149}
150
151//_______________________________________________________________________
// Logs a summary of this result through a MsgLogger named "CrossValidation":
// one line per fold with its ROC integral, then the average and the standard
// deviation of those integrals.
// NOTE(review): the signature and some body lines of this function are missing
// from this copy of the file.
153{
156
157 MsgLogger fLogger("CrossValidation");
158 fLogger << kHEADER << " ==== Results ====" << Endl;
159 for(auto &item:fROCs) {
 // NOTE(review): this line ends with std::endl, not the TMVA Endl manipulator
 // used by every other message here. std::endl presumably only writes a newline
 // into the logger's stream, so the fold lines would be emitted together with
 // the next Endl; consider Endl. `%i` also receives an unsigned fold index.
160 fLogger << kINFO << TString::Format("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
161 }
162
163 fLogger << kINFO << "------------------------" << Endl;
164 fLogger << kINFO << TString::Format("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
165 fLogger << kINFO << TString::Format("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
166
168}
169
170//_______________________________________________________________________
172{
173 auto *c = new TCanvas(name.Data());
174 fROCCurves->Draw("AL");
175 fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
176 fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
177 Float_t adjust=1+fROCs.size()*0.01;
178 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
179 c->SetTitle("Cross Validation ROC Curves");
180 c->Draw();
181 return c;
182}
183
184//
186{
187 // note this function will create memory leak for the TMultiGraph
188 // but it needs to be kept alive in order to display the canvas
189 TMultiGraph *rocs = new TMultiGraph();
190
191 // Potentially add the folds
192 if (drawFolds) {
193 for (auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
194 TGraph * foldRocGraph = dynamic_cast<TGraph *>(foldRocObj->Clone());
195 foldRocGraph->SetLineColor(1);
196 foldRocGraph->SetLineWidth(1);
197 rocs->Add(foldRocGraph);
198 }
199 }
200
201 // Add the average roc curve
202 TGraph *avgRocGraph = GetAvgROCCurve(100);
203 avgRocGraph->SetTitle("Avg ROC Curve");
204 avgRocGraph->SetLineColor(2);
205 avgRocGraph->SetLineWidth(3);
206 rocs->Add(avgRocGraph);
207
208 // Draw
209 TCanvas *c = new TCanvas();
210
211 if (title != "") {
212 title = "Cross Validation Average ROC Curve";
213 }
214
215 rocs->SetName("cv_rocs");
216 rocs->SetTitle(title);
217 rocs->GetXaxis()->SetTitle("Signal Efficiency");
218 rocs->GetYaxis()->SetTitle("Background Rejection");
219 rocs->DrawClone("AL");
220
221 // Build legend
222 TLegend *leg = new TLegend();
223 TList *ROCCurveList = rocs->GetListOfGraphs();
224
225 if (drawFolds) {
226 Int_t nCurves = ROCCurveList->GetSize();
227 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(nCurves-1)),
228 "Avg ROC Curve", "l");
229 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(0)),
230 "Fold ROC Curves", "l");
231 leg->Draw();
232 } else {
233 c->BuildLegend();
234 }
235
236 // Draw Canvas
237 c->SetTitle("Cross Validation Average ROC Curve");
238 c->Draw();
239 return c;
240}
241
242/**
243* \class TMVA::CrossValidation
244* \ingroup TMVA
245* \brief
246
247Runs k-fold cross-validation for every method booked on it and collects the
248per-fold results. Minimal usage, with `ce` a TMVA::CrossValidation instance:
249
250
251~~~{.cpp}
252ce->BookMethod(dataloader, options);
253ce->Evaluate();
254~~~
255
256Cross-validation generates new training and test sets dynamically from
257`K` folds. These `K` folds are generated by splitting the input training
258set. The input test set is currently ignored.
259
260This means that when you specify your DataSet you should include all events
261in your training set. One way of doing this would be the following:
262
263~~~{.cpp}
264dataloader->AddTree( signalTree, "cls1" );
265dataloader->AddTree( background, "cls2" );
266dataloader->PrepareTrainingAndTestTree( "", "", "nTest_cls1=1:nTest_cls2=1" );
267~~~
268
269## Split Expression
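270Events can be assigned to folds deterministically with the `SplitExpr` option, which is only allowed together with `SplitType=Deterministic`. See the CvSplit documentation for details.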
271
272*/
273
274////////////////////////////////////////////////////////////////////////////////
275/// Main constructor. Forwards jobName, dataloader and options to TMVA::Envelope
/// (with nullptr as third argument), stores `outputFile`, initialises all
/// option-backed members to their defaults (e.g. 2 folds, 1 worker process,
/// "Random" split type) and declares the accepted options via InitOptions().
/// NOTE(review): the first signature line and two body lines of this constructor
/// are missing from this copy of the file.
276
278 TString options)
279 : TMVA::Envelope(jobName, dataloader, nullptr, options),
 // Placeholder; ParseOptions() derives the real value from the AnalysisType option.
280 fAnalysisType(Types::kMaxAnalysisType),
281 fAnalysisTypeStr("Auto"),
282 fSplitTypeStr("Random"),
283 fCorrelations(kFALSE),
284 fCvFactoryOptions(""),
285 fDrawProgressBar(kFALSE),
286 fFoldFileOutput(kFALSE),
 // kFALSE until the K-fold data set has been generated (see Evaluate()).
287 fFoldStatus(kFALSE),
288 fJobName(jobName),
289 fNumFolds(2),
290 fNumWorkerProcs(1),
291 fOutputFactoryOptions(""),
292 fOutputFile(outputFile),
293 fSilent(kFALSE),
294 fSplitExprString(""),
295 fROC(kTRUE),
296 fTransformations(""),
297 fVerbose(kFALSE),
 // NOTE(review): fVerboseLevel is used as a string elsewhere (.Data()); its
 // effective default text "Info" is assigned in InitOptions().
298 fVerboseLevel(kINFO)
299{
 // Declare the configurable options and their allowed values.
300 InitOptions();
303}
304
305////////////////////////////////////////////////////////////////////////////////
306///
307
309 : CrossValidation(jobName, dataloader, nullptr, options)
310{
311}
312
313////////////////////////////////////////////////////////////////////////////////
314///
315
317
318////////////////////////////////////////////////////////////////////////////////
319/// Declares the options understood by CrossValidation (DeclareOptionRef) and
/// the allowed values of the enumerated ones (AddPreDefVal). The first group
/// mirrors TMVA::Factory options and is forwarded to the internal factories by
/// ParseOptions(); the second group is specific to cross-validation.
/// NOTE(review): presumably InitOptions() (called from the constructor); the
/// signature line is missing from this copy of the file.
320
322{
323 // Forwarding of Factory options
324 DeclareOptionRef(fSilent, "Silent",
325 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
326 "class object (default: False)");
327 DeclareOptionRef(fVerbose, "V", "Verbose flag");
328 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
329 AddPreDefVal(TString("Debug"));
330 AddPreDefVal(TString("Verbose"));
331 AddPreDefVal(TString("Info"));
332
333 DeclareOptionRef(fTransformations, "Transformations",
334 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
335 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
336 "transformations");
337
338 DeclareOptionRef(fDrawProgressBar, "DrawProgressBar", "Boolean to show draw progress bar");
339 DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation in output");
340 DeclareOptionRef(fROC, "ROC", "Boolean to show ROC in output");
341
 // NOTE(review): `analysisType` is never used; the option is bound to fAnalysisTypeStr.
342 TString analysisType("Auto");
343 DeclareOptionRef(fAnalysisTypeStr, "AnalysisType",
344 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
345 AddPreDefVal(TString("Classification"));
346 AddPreDefVal(TString("Regression"));
347 AddPreDefVal(TString("Multiclass"));
348 AddPreDefVal(TString("Auto"));
349
350 // Options specific to CE
351 DeclareOptionRef(fSplitTypeStr, "SplitType",
352 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
353 AddPreDefVal(TString("Deterministic"));
354 AddPreDefVal(TString("Random"));
355 AddPreDefVal(TString("RandomStratified"));
356
 // SplitExpr is only accepted together with SplitType=Deterministic (checked in ParseOptions()).
357 DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
358 DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
359 DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
360 "Determines how many processes to use for evaluation. 1 means no"
361 " parallelisation. 2 means use 2 processes. 0 means figure out the"
362 " number automatically based on the number of cpus available. Default"
363 " 1.");
364
 // NOTE(review): ProcessFold() names the per-fold files
 // "<dir of output file>/<MethodTitle>_fold<N>.root", which differs from this
 // help text ("same as the combined output with a _foldX suffix"); the help text
 // also misspells "specified".
365 DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
366 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
367 "specifed for the combined output with a _foldX suffix. (default: false)");
368
369 DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
370 "Combines output from contained methods. If None, no combination is performed. (default None)");
371 AddPreDefVal(TString("None"));
372 AddPreDefVal(TString("Avg"));
373}
374
375////////////////////////////////////////////////////////////////////////////////
376/// Validates the parsed options and prepares the internal state:
/// - resolves AnalysisType (lower-cased) into fAnalysisType ("auto" maps to
///   Types::kNoAnalysisType);
/// - appends the corresponding Factory flags to both fCvFactoryOptions (per-fold
///   factory) and fOutputFactoryOptions (output factory);
/// - creates the per-fold factory, the output factory (attached to fOutputFile
///   when one was given) and the fold splitter matching SplitType.
/// Emits kFATAL if SplitExpr is used without Deterministic splitting, or if
/// FoldFileOutput is requested without an output file.
/// NOTE(review): the option strings are appended with +=, so calling this twice
/// would duplicate the flags. One body line is missing from this copy of the file.
377
379{
381
382 if (fSplitTypeStr != "Deterministic" && fSplitExprString != "") {
383 Log() << kFATAL << "SplitExpr can only be used with Deterministic Splitting" << Endl;
384 }
385
386 // Factory options
387 fAnalysisTypeStr.ToLower();
388 if (fAnalysisTypeStr == "classification") {
389 fAnalysisType = Types::kClassification;
390 } else if (fAnalysisTypeStr == "regression") {
391 fAnalysisType = Types::kRegression;
392 } else if (fAnalysisTypeStr == "multiclass") {
393 fAnalysisType = Types::kMulticlass;
394 } else if (fAnalysisTypeStr == "auto") {
395 fAnalysisType = Types::kNoAnalysisType;
396 }
397
398 if (fVerbose) {
399 fCvFactoryOptions += "V:";
400 fOutputFactoryOptions += "V:";
401 } else {
402 fCvFactoryOptions += "!V:";
403 fOutputFactoryOptions += "!V:";
404 }
405
406 fCvFactoryOptions += TString::Format("VerboseLevel=%s:", fVerboseLevel.Data());
407 fOutputFactoryOptions += TString::Format("VerboseLevel=%s:", fVerboseLevel.Data());
408
409 fCvFactoryOptions += TString::Format("AnalysisType=%s:", fAnalysisTypeStr.Data());
410 fOutputFactoryOptions += TString::Format("AnalysisType=%s:", fAnalysisTypeStr.Data());
411
412 if (!fDrawProgressBar) {
413 fCvFactoryOptions += "!DrawProgressBar:";
414 fOutputFactoryOptions += "!DrawProgressBar:";
415 }
416
417 if (fTransformations != "") {
418 fCvFactoryOptions += TString::Format("Transformations=%s:", fTransformations.Data());
419 fOutputFactoryOptions += TString::Format("Transformations=%s:", fTransformations.Data());
420 }
421
422 if (fCorrelations) {
423 fCvFactoryOptions += "Correlations:";
424 fOutputFactoryOptions += "Correlations:";
425 } else {
426 fCvFactoryOptions += "!Correlations:";
427 fOutputFactoryOptions += "!Correlations:";
428 }
429
430 if (fROC) {
431 fCvFactoryOptions += "ROC:";
432 fOutputFactoryOptions += "ROC:";
433 } else {
434 fCvFactoryOptions += "!ROC:";
435 fOutputFactoryOptions += "!ROC:";
436 }
437
438 if (fSilent) {
439 fCvFactoryOptions += "Silent:";
440 fOutputFactoryOptions += "Silent:";
441 }
442
443 // CE specific options
444 if (fFoldFileOutput && fOutputFile == nullptr) {
445 Log() << kFATAL << "No output file given, cannot generate per fold output." << Endl;
446 }
447
448 // Initialisations
449
450 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
451
452 // The output factory (fFactory) uses a custom code path for model persistence.
453 // In this case we create a special method (MethodCrossValidation) that can only be used by
454 // CrossValidation and the Reader.
 // NOTE(review): this comment originally stated the output factory should always
 // have !ModelPersistence set, but no such flag is appended to
 // fOutputFactoryOptions above; confirm whether it is still required.
455 if (fOutputFile == nullptr) {
456 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
457 } else {
458 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
459 }
460
 // The third CvSplitKFolds argument selects stratification (kFALSE for Random,
 // kTRUE for RandomStratified); Deterministic relies on the constructor default.
461 if(fSplitTypeStr == "Random"){
462 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kFALSE));
463 } else if(fSplitTypeStr == "RandomStratified"){
464 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kTRUE));
465 } else {
466 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
467 }
468
469}
470
471//_______________________________________________________________________
// Changes the number of folds. When `i` differs from the current value, the
// fold splitter is rebuilt, the K-fold data set is regenerated on the data
// loader, and fFoldStatus is set so Evaluate() does not split again.
// NOTE(review): the new CvSplitKFolds gets only (numFolds, splitExpr), unlike
// ParseOptions() which passes the stratification flag for Random /
// RandomStratified, so the configured SplitType may be lost here; confirm.
473{
474 if (i != fNumFolds) {
475 fNumFolds = i;
476 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
477 fDataLoader->MakeKFoldDataSet(*fSplit);
478 fFoldStatus = kTRUE;
479 }
480}
481
482////////////////////////////////////////////////////////////////////////////////
483///
484
486{
487 if (splitExpr != fSplitExprString) {
488 fSplitExprString = splitExpr;
489 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
490 fDataLoader->MakeKFoldDataSet(*fSplit);
491 fFoldStatus = kTRUE;
492 }
493}
494
495////////////////////////////////////////////////////////////////////////////////
496/// Trains and evaluates a single fold of one booked method.
497/// - Prepares the train and test data sets for fold `iFold`
498/// - Books and trains the method on the per-fold factory
499/// - Evaluates it on the test set
500/// - Returns the collected per-fold metrics
501///
502/// @param iFold fold to evaluate (0-based; the fold title uses `iFold + 1`)
503/// @param methodInfo method metadata (MethodName, MethodTitle, MethodOptions)
/// @return the metrics of this fold (ROC integral and curve, significance,
///         separation, efficiencies) as a CrossValidationFoldResult
504///
505
507{
508 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
509 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
510 TString methodOptions = methodInfo.GetValue<TString>("MethodOptions");
511 TString foldTitle = methodTitle + TString("_fold") + TString::Format("%i", iFold + 1);
512
513 Log() << kDEBUG << "Processing " << methodTitle << " fold " << iFold << Endl;
514
515 // Only used if fFoldFileOutput is set (and an output file was given)
516 TFile *foldOutputFile = nullptr;
517
 // Per-fold output goes to "<dir of output file>/<foldTitle>.root"; the per-fold
 // factory is recreated so that it writes into that file.
518 if (fFoldFileOutput && fOutputFile != nullptr) {
519 TString path = gSystem->GetDirName(fOutputFile->GetName()) + "/" + foldTitle + ".root";
520 foldOutputFile = TFile::Open(path, "RECREATE");
521 Log() << kINFO << "Creating fold output at:" << path << Endl;
522 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
523 }
524
525 fDataLoader->PrepareFoldDataSet(*fSplit, iFold, TMVA::Types::kTraining);
526 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
527
528 // Train method (train method and eval train set)
530 smethod->TrainMethod();
532
533 fFoldFactory->TestAllMethods();
534 fFoldFactory->EvaluateAllMethods();
535
 // NOTE(review): the declaration of `result` and the lines around the
 // TrainMethod() call are missing from this copy of the file.
537
538 // Results for aggregation (ROC integral, efficiencies etc.)
 // NOTE(review): with AnalysisType=Auto, ParseOptions() sets kNoAnalysisType, so
 // unless it is resolved elsewhere no metrics are collected here; confirm.
539 if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass) {
540 result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
541
 // NOTE(review): `gr` is copied into the result and never deleted or
 // null-checked; if GetROCCurve() hands ownership to the caller, this leaks.
542 TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
543 gr->SetLineColor(iFold + 1);
544 gr->SetLineWidth(2);
545 gr->SetTitle(foldTitle.Data());
546 result.fROC = *gr;
547
548 result.fSig = smethod->GetSignificance();
549 result.fSep = smethod->GetSeparation();
550
551 if (fAnalysisType == Types::kClassification) {
552 Double_t err;
553 result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
554 result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
555 result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
556 result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
557 result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
558 result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
559 result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
560 } else if (fAnalysisType == Types::kMulticlass) {
561 // Nothing here for now
562 }
563 }
564
565 // Per-fold file output
 // NOTE(review): the file is closed but the TFile object from TFile::Open() is
 // never deleted; confirm ownership.
566 if (fFoldFileOutput && foldOutputFile != nullptr) {
567 foldOutputFile->Close();
568 }
569
570 // Clean-up for this fold
571 {
572 smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
573 smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
574 }
575
 // Remove the booked fold method so the next fold starts from an empty factory.
576 fFoldFactory->DeleteAllMethods();
577 fFoldFactory->fMethodsMap.clear();
578
579 return result;
580}
581
582////////////////////////////////////////////////////////////////////////////////
583/// Does training, test set evaluation and performance evaluation using
584/// cross-validation.
/// For every booked method:
/// - each of the K folds is processed (serially, or in worker processes when
///   more than one worker is configured);
/// - the aggregated CrossValidationResult is stored in fResults;
/// - a MethodCrossValidation wrapper is booked on the output factory.
/// The folds are then recombined, and the wrappers are trained, tested and
/// evaluated through the output factory.
/// NOTE(review): a few lines of this function are missing from this copy of the file.
585///
586
588{
589 // Generate K folds on given dataset
590 if (!fFoldStatus) {
591 fDataLoader->MakeKFoldDataSet(*fSplit);
592 fFoldStatus = kTRUE;
593 }
594
595 fResults.reserve(fMethods.size());
596 for (auto & methodInfo : fMethods) {
597 CrossValidationResult result{fNumFolds};
598
599 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
600 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
601
602 if (methodTypeName == "") {
603 Log() << kFATAL << "No method booked for cross-validation" << Endl;
604 }
605
607 Log() << kINFO << Endl;
608 Log() << kINFO << Endl;
609 Log() << kINFO << "========================================" << Endl;
610 Log() << kINFO << "Processing folds for method " << methodTitle << Endl;
611 Log() << kINFO << "========================================" << Endl;
612 Log() << kINFO << Endl;
613
614 // Process K folds
615 auto nWorkers = fNumWorkerProcs;
616 if (nWorkers == 1) {
617 // Fall back to global config
618 nWorkers = TMVA::gConfig().GetNumWorkers();
619 }
620 if (nWorkers == 1) {
621 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
622 auto fold_result = ProcessFold(iFold, methodInfo);
623 result.Fill(fold_result);
624 }
625 } else {
 // NOTE(review): on MSVC builds this whole branch compiles to nothing, so with
 // more than one worker no fold is processed and `result` is stored unfilled.
626#ifndef _MSC_VER
627 ROOT::TProcessExecutor workers(nWorkers);
628 std::vector<CrossValidationFoldResult> result_vector;
629
 // methodInfo is captured by copy for use in the worker processes.
630 auto workItem = [this, methodInfo](UInt_t iFold) {
631 return ProcessFold(iFold, methodInfo);
632 };
633
634 result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
635
636 for (auto && fold_result : result_vector) {
637 result.Fill(fold_result);
638 }
639#endif
640 }
641
642 fResults.push_back(result);
643
644 // Serialise the cross evaluated method
645 TString options =
646 TString::Format("SplitExpr=%s:NumFolds=%i"
647 ":EncapsulatedMethodName=%s"
648 ":EncapsulatedMethodTypeName=%s"
649 ":OutputEnsembling=%s",
650 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());
651
652 fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);
653
654 // Feed EventToFold mapping used when random fold assignments are used
655 // (when splitExpr="").
656 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
 // NOTE(review): the dynamic_cast result is dereferenced without a null check.
657 auto *method = dynamic_cast<MethodCrossValidation *>(method_interface);
658
659 method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
660 }
661
662 Log() << kINFO << Endl;
663 Log() << kINFO << Endl;
664 Log() << kINFO << "========================================" << Endl;
665 Log() << kINFO << "Folds processed for all methods, evaluating." << Endl;
666 Log() << kINFO << "========================================" << Endl;
667 Log() << kINFO << Endl;
668
669 // Recombination of data (making sure there is data in training and testing trees).
670 fDataLoader->RecombineKFoldDataSet(*fSplit);
671
672 // "Eval" on training set
673 for (auto & methodInfo : fMethods) {
674 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
675 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
676
677 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
 // NOTE(review): dereferenced below without a null check as well.
678 auto method = dynamic_cast<MethodCrossValidation *>(method_interface);
679
680 if (fOutputFile != nullptr) {
681 fFactory->WriteDataInformation(method->fDataSetInfo);
682 }
683
685 method->TrainMethod();
687 }
688
689 // Eval on Testing set
690 fFactory->TestAllMethods();
691
692 // Calc statistics
693 fFactory->EvaluateAllMethods();
694
695 Log() << kINFO << "Evaluation done." << Endl;
696}
697
698//_______________________________________________________________________
699const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
700{
701 if (fResults.empty()) {
702 Log() << kFATAL << "No cross-validation results available" << Endl;
703 }
704 return fResults;
705}
#define c(i)
Definition RSha256.hxx:101
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
char name[80]
Definition TGX11.cxx:110
R__EXTERN TSystem * gSystem
Definition TSystem.h:561
auto Map(F func, unsigned nTimes) -> std::vector< InvokeResult_t< F > >
Execute a function without arguments several times.
This class provides a simple interface to execute the same task multiple times in parallel,...
A pseudo container class which is a generator of indices.
Definition TSeq.hxx:67
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition TAttLine.h:43
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition TAttLine.h:40
The Canvas class.
Definition TCanvas.h:23
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4089
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:950
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
virtual Double_t Eval(Double_t x, TSpline *spline=nullptr, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
Definition TGraph.cxx:953
void SetTitle(const char *title="") override
Change (i.e.
Definition TGraph.cxx:2397
This class displays a legend box (TPaveText) containing several legend entries.
Definition TLegend.h:23
A doubly linked list.
Definition TList.h:38
TObject * At(Int_t idx) const override
Returns the object at position idx. Returns 0 if idx is out of range.
Definition TList.cxx:355
UInt_t GetNumWorkers() const
Definition Config.h:72
void SetSilent(Bool_t s)
Definition Config.h:63
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation, the metric for the classification ins ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation of using cross-evalution.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Definition DataSet.cxx:343
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition Envelope.h:44
virtual void ParseOptions()
Method to parse the internal option string.
Definition Envelope.cxx:182
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:399
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
Definition MethodBase.h:437
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
DataSet * Data() const
Definition MethodBase.h:409
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
static void EnableOutput()
Definition MsgLogger.cxx:68
class to storage options for the differents methods
Definition OptionMap.h:34
T GetValue(const TString &key)
Definition OptionMap.h:133
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kCrossValidation
Definition Types.h:109
@ kMulticlass
Definition Types.h:129
@ kNoAnalysisType
Definition Types.h:130
@ kClassification
Definition Types.h:127
@ kRegression
Definition Types.h:128
@ kTraining
Definition Types.h:143
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:34
TList * GetListOfGraphs() const
Definition TMultiGraph.h:68
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
TObject * Clone(const char *newname="") const override
Make a clone of an object using the Streamer facility.
Definition TNamed.cxx:74
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:164
virtual void SetName(const char *name)
Set the name of the TNamed.
Definition TNamed.cxx:140
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad with: gROOT->SetSelectedPad(c1).
Definition TObject.cxx:299
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:376
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
Definition TSystem.cxx:1032
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
TGraphErrors * gr
Definition legend1.C:25
leg
Definition legend1.C:34
create variable transformations
Config & gConfig()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:662
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.
Definition TMath.h:721