Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CrossValidation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou
3// Modified: Kim Albertsson 2017
4
5/*************************************************************************
6 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
7 * All rights reserved. *
8 * *
9 * For the licensing terms see $ROOTSYS/LICENSE. *
10 * For the list of contributors see $ROOTSYS/README/CREDITS. *
11 *************************************************************************/
12
14
16#include "TMVA/Config.h"
17#include "TMVA/CvSplit.h"
18#include "TMVA/DataSet.h"
19#include "TMVA/Event.h"
20#include "TMVA/MethodBase.h"
22#include "TMVA/MsgLogger.h"
25#include "TMVA/ROCCurve.h"
26#include "TMVA/Types.h"
27
28#include "TSystem.h"
29#include "TAxis.h"
30#include "TCanvas.h"
31#include "TGraph.h"
32#include "TLegend.h"
33#include "TMath.h"
34
35#include "ROOT/RMakeUnique.hxx"
36
37#include <iostream>
38#include <memory>
39
40//_______________________________________________________________________
42:fROCCurves(new TMultiGraph())
43{
44 fSigs.resize(numFolds);
45 fSeps.resize(numFolds);
46 fEff01s.resize(numFolds);
47 fEff10s.resize(numFolds);
48 fEff30s.resize(numFolds);
49 fEffAreas.resize(numFolds);
50 fTrainEff01s.resize(numFolds);
51 fTrainEff10s.resize(numFolds);
52 fTrainEff30s.resize(numFolds);
53}
54
55//_______________________________________________________________________
57{
58 fROCs=obj.fROCs;
59 fROCCurves = obj.fROCCurves;
60
61 fSigs = obj.fSigs;
62 fSeps = obj.fSeps;
63 fEff01s = obj.fEff01s;
64 fEff10s = obj.fEff10s;
65 fEff30s = obj.fEff30s;
66 fEffAreas = obj.fEffAreas;
67 fTrainEff01s = obj.fTrainEff01s;
68 fTrainEff10s = obj.fTrainEff10s;
69 fTrainEff30s = obj.fTrainEff30s;
70}
71
72//_______________________________________________________________________
74{
75 UInt_t iFold = fr.fFold;
76
77 fROCs[iFold] = fr.fROCIntegral;
78 fROCCurves->Add(dynamic_cast<TGraph *>(fr.fROC.Clone()));
79
80 fSigs[iFold] = fr.fSig;
81 fSeps[iFold] = fr.fSep;
82 fEff01s[iFold] = fr.fEff01;
83 fEff10s[iFold] = fr.fEff10;
84 fEff30s[iFold] = fr.fEff30;
85 fEffAreas[iFold] = fr.fEffArea;
86 fTrainEff01s[iFold] = fr.fTrainEff01;
87 fTrainEff10s[iFold] = fr.fTrainEff10;
88 fTrainEff30s[iFold] = fr.fTrainEff30;
89}
90
91//_______________________________________________________________________
93{
94 return fROCCurves.get();
95}
96
97////////////////////////////////////////////////////////////////////////////////
98/// \brief Generates a multigraph that contains an average ROC Curve.
99///
100/// \note You own the returned pointer.
101///
102/// \param[in] numSamples Number of samples used for generating the average ROC
103/// Curve. Avg. curve will be evaluated only at these
104/// points (using interpolation if necessary).
105///
106
108{
109 // `numSamples * increment` should equal 1.0!
110 Double_t increment = 1.0 / (numSamples-1);
111 std::vector<Double_t> x(numSamples), y(numSamples);
112
113 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
114
115 for(UInt_t iSample = 0; iSample < numSamples; iSample++) {
116 Double_t xPoint = iSample * increment;
117 Double_t rocSum = 0;
118
119 for(Int_t iGraph = 0; iGraph < rocCurveList->GetSize(); iGraph++) {
120 TGraph *foldROC = static_cast<TGraph *>(rocCurveList->At(iGraph));
121 rocSum += foldROC->Eval(xPoint);
122 }
123
124 x[iSample] = xPoint;
125 y[iSample] = rocSum/rocCurveList->GetSize();
126 }
127
128 return new TGraph(numSamples, &x[0], &y[0]);
129}
130
131//_______________________________________________________________________
133{
134 Float_t avg=0;
135 for(auto &roc : fROCs) {
136 avg+=roc.second;
137 }
138 return avg/fROCs.size();
139}
140
141//_______________________________________________________________________
143{
144 // NOTE: We are using here the unbiased estimation of the standard deviation.
145 Float_t std=0;
146 Float_t avg=GetROCAverage();
147 for(auto &roc : fROCs) {
148 std+=TMath::Power(roc.second-avg, 2);
149 }
150 return TMath::Sqrt(std/float(fROCs.size()-1.0));
151}
152
153//_______________________________________________________________________
155{
158
159 MsgLogger fLogger("CrossValidation");
160 fLogger << kHEADER << " ==== Results ====" << Endl;
161 for(auto &item:fROCs) {
162 fLogger << kINFO << Form("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
163 }
164
165 fLogger << kINFO << "------------------------" << Endl;
166 fLogger << kINFO << Form("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
167 fLogger << kINFO << Form("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
168
170}
171
172//_______________________________________________________________________
174{
175 auto *c = new TCanvas(name.Data());
176 fROCCurves->Draw("AL");
177 fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
178 fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
179 Float_t adjust=1+fROCs.size()*0.01;
180 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
181 c->SetTitle("Cross Validation ROC Curves");
182 c->Draw();
183 return c;
184}
185
186//
188{
189 // note this function will create memory leak for the TMultiGraph
190 // but it needs to be kept alive in order to display the canvas
191 TMultiGraph *rocs = new TMultiGraph();
192
193 // Potentially add the folds
194 if (drawFolds) {
195 for (auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
196 TGraph * foldRocGraph = dynamic_cast<TGraph *>(foldRocObj->Clone());
197 foldRocGraph->SetLineColor(1);
198 foldRocGraph->SetLineWidth(1);
199 rocs->Add(foldRocGraph);
200 }
201 }
202
203 // Add the average roc curve
204 TGraph *avgRocGraph = GetAvgROCCurve(100);
205 avgRocGraph->SetTitle("Avg ROC Curve");
206 avgRocGraph->SetLineColor(2);
207 avgRocGraph->SetLineWidth(3);
208 rocs->Add(avgRocGraph);
209
210 // Draw
211 TCanvas *c = new TCanvas();
212
213 if (title != "") {
214 title = "Cross Validation Average ROC Curve";
215 }
216
217 rocs->SetName("cv_rocs");
218 rocs->SetTitle(title);
219 rocs->GetXaxis()->SetTitle("Signal Efficiency");
220 rocs->GetYaxis()->SetTitle("Background Rejection");
221 rocs->DrawClone("AL");
222
223 // Build legend
224 TLegend *leg = new TLegend();
225 TList *ROCCurveList = rocs->GetListOfGraphs();
226
227 if (drawFolds) {
228 Int_t nCurves = ROCCurveList->GetSize();
229 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(nCurves-1)),
230 "Avg ROC Curve", "l");
231 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(0)),
232 "Fold ROC Curves", "l");
233 leg->Draw();
234 } else {
235 c->BuildLegend();
236 }
237
238 // Draw Canvas
239 c->SetTitle("Cross Validation Average ROC Curve");
240 c->Draw();
241 return c;
242}
243
244/**
245* \class TMVA::CrossValidation
246* \ingroup TMVA
* \brief Class to perform cross validation, splitting the dataloader into folds.

Typical usage:
251
252
253~~~{.cpp}
254ce->BookMethod(dataloader, options);
255ce->Evaluate();
256~~~
257
Cross-validation dynamically generates a new training and test set from the
`K` folds. These `K` folds are generated by splitting the input training
set. The input test set is currently ignored.
261
262This means that when you specify your DataSet you should include all events
263in your training set. One way of doing this would be the following:
264
265~~~{.cpp}
266dataloader->AddTree( signalTree, "cls1" );
267dataloader->AddTree( background, "cls2" );
268dataloader->PrepareTrainingAndTestTree( "", "", "nTest_cls1=1:nTest_cls2=1" );
269~~~
270
271## Split Expression
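The assignment of events to folds can be controlled with the `SplitExpr` option,
which may only be used together with `SplitType=Deterministic`. See the CvSplit
documentation for details.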
273
274*/
275
276////////////////////////////////////////////////////////////////////////////////
277/// \brief Constructs the cross-validation envelope.
///
/// \param[in] outputFile Combined TMVA output file. May be nullptr (the
///                       convenience constructor passes nullptr); then no
///                       output file is written and FoldFileOutput is rejected
///                       during option parsing.
/// \param[in] options    Option string, matched against the options declared
///                       in InitOptions().
///
/// Defaults before option parsing: 2 folds, "Random" splitting, "Auto"
/// analysis type, ROC output enabled and a single worker process.
278
280 TString options)
// The Envelope base gets no file; the output file is kept in fOutputFile and
// later handed to the output Factory.
281 : TMVA::Envelope(jobName, dataloader, nullptr, options),
// Placeholder: the real analysis type is derived from fAnalysisTypeStr when
// the options are parsed.
282 fAnalysisType(Types::kMaxAnalysisType),
283 fAnalysisTypeStr("Auto"),
284 fSplitTypeStr("Random"),
285 fCorrelations(kFALSE),
286 fCvFactoryOptions(""),
287 fDrawProgressBar(kFALSE),
288 fFoldFileOutput(kFALSE),
289 fFoldStatus(kFALSE),
290 fJobName(jobName),
291 fNumFolds(2),
292 fNumWorkerProcs(1),
293 fOutputFactoryOptions(""),
294 fOutputFile(outputFile),
295 fSilent(kFALSE),
296 fSplitExprString(""),
297 fROC(kTRUE),
298 fTransformations(""),
299 fVerbose(kFALSE),
300 fVerboseLevel(kINFO)
301{
// Register all configurable options (and their allowed values).
302 InitOptions();
305}
306
307////////////////////////////////////////////////////////////////////////////////
308///
309
311 : CrossValidation(jobName, dataloader, nullptr, options)
312{
313}
314
315////////////////////////////////////////////////////////////////////////////////
316///
317
319
320////////////////////////////////////////////////////////////////////////////////
321///
322
324{
325 // Forwarding of Factory options
326 DeclareOptionRef(fSilent, "Silent",
327 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
328 "class object (default: False)");
329 DeclareOptionRef(fVerbose, "V", "Verbose flag");
330 DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
331 AddPreDefVal(TString("Debug"));
332 AddPreDefVal(TString("Verbose"));
333 AddPreDefVal(TString("Info"));
334
335 DeclareOptionRef(fTransformations, "Transformations",
336 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
337 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
338 "transformations");
339
340 DeclareOptionRef(fDrawProgressBar, "DrawProgressBar", "Boolean to show draw progress bar");
341 DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation in output");
342 DeclareOptionRef(fROC, "ROC", "Boolean to show ROC in output");
343
344 TString analysisType("Auto");
345 DeclareOptionRef(fAnalysisTypeStr, "AnalysisType",
346 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
347 AddPreDefVal(TString("Classification"));
348 AddPreDefVal(TString("Regression"));
349 AddPreDefVal(TString("Multiclass"));
350 AddPreDefVal(TString("Auto"));
351
352 // Options specific to CE
353 DeclareOptionRef(fSplitTypeStr, "SplitType",
354 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
355 AddPreDefVal(TString("Deterministic"));
356 AddPreDefVal(TString("Random"));
357 AddPreDefVal(TString("RandomStratified"));
358
359 DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
360 DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
361 DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
362 "Determines how many processes to use for evaluation. 1 means no"
363 " parallelisation. 2 means use 2 processes. 0 means figure out the"
364 " number automatically based on the number of cpus available. Default"
365 " 1.");
366
367 DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
368 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
369 "specifed for the combined output with a _foldX suffix. (default: false)");
370
371 DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
372 "Combines output from contained methods. If None, no combination is performed. (default None)");
373 AddPreDefVal(TString("None"));
374 AddPreDefVal(TString("Avg"));
375}
376
377////////////////////////////////////////////////////////////////////////////////
378/// Derives the internal configuration from the parsed options:
/// - rejects SplitExpr unless SplitType=Deterministic,
/// - maps AnalysisType onto fAnalysisType,
/// - assembles the option strings of the per-fold factory (fCvFactoryOptions)
///   and of the output factory (fOutputFactoryOptions),
/// - creates both factories and the fold splitter fSplit.
/// NOTE(review): options are appended with `+=`, so running this more than once
/// would duplicate them — confirm it is only called once.
379
381{
383
384 if (fSplitTypeStr != "Deterministic" && fSplitExprString != "") {
385 Log() << kFATAL << "SplitExpr can only be used with Deterministic Splitting" << Endl;
386 }
387
388 // Factory options
// Lower-cased in place; this lower-case form is also what is forwarded as
// AnalysisType= to both factories below.
389 fAnalysisTypeStr.ToLower();
390 if (fAnalysisTypeStr == "classification") {
391 fAnalysisType = Types::kClassification;
392 } else if (fAnalysisTypeStr == "regression") {
393 fAnalysisType = Types::kRegression;
394 } else if (fAnalysisTypeStr == "multiclass") {
395 fAnalysisType = Types::kMulticlass;
396 } else if (fAnalysisTypeStr == "auto") {
397 fAnalysisType = Types::kNoAnalysisType;
398 }
399
// Both factories receive the same general options, each as a
// colon-terminated "Key=Value:" or "[!]Flag:" entry.
400 if (fVerbose) {
401 fCvFactoryOptions += "V:";
402 fOutputFactoryOptions += "V:";
403 } else {
404 fCvFactoryOptions += "!V:";
405 fOutputFactoryOptions += "!V:";
406 }
407
408 fCvFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
409 fOutputFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
410
411 fCvFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
412 fOutputFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
413
414 if (!fDrawProgressBar) {
415 fCvFactoryOptions += "!DrawProgressBar:";
416 fOutputFactoryOptions += "!DrawProgressBar:";
417 }
418
419 if (fTransformations != "") {
420 fCvFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
421 fOutputFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
422 }
423
424 if (fCorrelations) {
425 fCvFactoryOptions += "Correlations:";
426 fOutputFactoryOptions += "Correlations:";
427 } else {
428 fCvFactoryOptions += "!Correlations:";
429 fOutputFactoryOptions += "!Correlations:";
430 }
431
432 if (fROC) {
433 fCvFactoryOptions += "ROC:";
434 fOutputFactoryOptions += "ROC:";
435 } else {
436 fCvFactoryOptions += "!ROC:";
437 fOutputFactoryOptions += "!ROC:";
438 }
439
440 if (fSilent) {
441 fCvFactoryOptions += Form("Silent:");
442 fOutputFactoryOptions += Form("Silent:");
443 }
444
445 // CE specific options
// Per-fold output files are placed next to the combined output file, so one is required.
446 if (fFoldFileOutput && fOutputFile == nullptr) {
447 Log() << kFATAL << "No output file given, cannot generate per fold output." << Endl;
448 }
449
450 // Initialisations
451
// Factory used to train/test the individual folds.
452 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
453
454 // The fOutputFactory should always have !ModelPersistence set since we use a custom code path for this.
455 // In this case we create a special method (MethodCrossValidation) that can only be used by
456 // CrossValidation and the Reader.
// NOTE(review): "!ModelPersistence" is not appended to fOutputFactoryOptions in
// this block; confirm where that requirement is enforced.
457 if (fOutputFile == nullptr) {
458 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
459 } else {
460 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
461 }
462
// Third argument: stratification flag (kFALSE for Random, kTRUE for
// RandomStratified). Deterministic splitting uses the two-argument form.
463 if(fSplitTypeStr == "Random"){
464 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kFALSE));
465 } else if(fSplitTypeStr == "RandomStratified"){
466 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kTRUE));
467 } else {
468 fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
469 }
470
471}
472
473//_______________________________________________________________________
475{
476 if (i != fNumFolds) {
477 fNumFolds = i;
478 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
479 fDataLoader->MakeKFoldDataSet(*fSplit);
480 fFoldStatus = kTRUE;
481 }
482}
483
484////////////////////////////////////////////////////////////////////////////////
485///
486
488{
489 if (splitExpr != fSplitExprString) {
490 fSplitExprString = splitExpr;
491 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
492 fDataLoader->MakeKFoldDataSet(*fSplit);
493 fFoldStatus = kTRUE;
494 }
495}
496
497////////////////////////////////////////////////////////////////////////////////
498/// Evaluates a single fold (iFold) of the given method.
499/// - Prepares train and test data sets
500/// - Trains method
501/// - Evaluates on test set
502/// - Returns the evaluation as a CrossValidationFoldResult
503///
504/// @param iFold index of the fold to evaluate (0-based; fold titles use iFold + 1)
/// @param methodInfo booked method; its "MethodName", "MethodTitle" and
///                   "MethodOptions" entries are read
/// @return the fold's figures of merit (filled for classification/multiclass only)
505///
506
508{
509 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
510 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
511 TString methodOptions = methodInfo.GetValue<TString>("MethodOptions");
512 TString foldTitle = methodTitle + TString("_fold") + TString::Format("%i", iFold + 1);
513
514 Log() << kDEBUG << "Processing " << methodTitle << " fold " << iFold << Endl;
515
516 // Only used if fFoldFileOutput == true (and a combined output file was given)
517 TFile *foldOutputFile = nullptr;
518
// Per-fold output goes to "<dir of combined output>/<foldTitle>.root"; a fresh
// fold factory is created that writes into it.
519 if (fFoldFileOutput && fOutputFile != nullptr) {
520 TString path = gSystem->GetDirName(fOutputFile->GetName()) + "/" + foldTitle + ".root";
521 foldOutputFile = TFile::Open(path, "RECREATE");
522 Log() << kINFO << "Creating fold output at:" << path << Endl;
523 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
524 }
525
// Set up the data set for this fold — presumably the other folds for training
// and this fold for testing (TODO confirm in DataLoader::PrepareFoldDataSet).
526 fDataLoader->PrepareFoldDataSet(*fSplit, iFold, TMVA::Types::kTraining);
// NOTE(review): the return value of BookMethod is used unchecked below.
527 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
528
529 // Train method (train method and eval train set)
531 smethod->TrainMethod();
533
534 fFoldFactory->TestAllMethods();
535 fFoldFactory->EvaluateAllMethods();
536
538
539 // Results for aggregation (ROC integral, efficiencies etc.)
// NOTE(review): `result` is not declared in the lines of this block; it must be a
// CrossValidationFoldResult for fold iFold — confirm.
// NOTE(review): with the default AnalysisType=Auto, option parsing sets
// fAnalysisType to kNoAnalysisType, so nothing below is collected unless the type
// is resolved elsewhere — confirm.
540 if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass) {
541 result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
542
// NOTE(review): ownership of `gr` is unclear; it is copied into result.fROC and
// never deleted here.
543 TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
544 gr->SetLineColor(iFold + 1);
545 gr->SetLineWidth(2);
546 gr->SetTitle(foldTitle.Data());
547 result.fROC = *gr;
548
549 result.fSig = smethod->GetSignificance();
550 result.fSep = smethod->GetSeparation();
551
552 if (fAnalysisType == Types::kClassification) {
// `err` receives the efficiency uncertainty; it is not propagated.
553 Double_t err;
554 result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
555 result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
556 result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
557 result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
558 result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
559 result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
560 result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
561 } else if (fAnalysisType == Types::kMulticlass) {
562 // Nothing here for now
563 }
564 }
565
566 // Per-fold file output
// NOTE(review): the TFile returned by TFile::Open is closed but never deleted,
// and fFoldFactory keeps pointing at it until the next fold replaces the factory.
567 if (fFoldFileOutput && foldOutputFile != nullptr) {
568 foldOutputFile->Close();
569 }
570
571 // Clean-up for this fold
// Drop this fold's results and booked methods so the next fold starts from a
// clean factory.
572 {
573 smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
574 smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
575 }
576
577 fFoldFactory->DeleteAllMethods();
578 fFoldFactory->fMethodsMap.clear();
579
580 return result;
581}
582
583////////////////////////////////////////////////////////////////////////////////
584/// Does training, test set evaluation and performance evaluation using
585/// cross-validation.
/// For every booked method, all K folds are processed with ProcessFold (serially
/// or in worker processes) and collected into fResults. A MethodCrossValidation,
/// configured with the split expression, fold count and encapsulated method, is
/// then booked in the output factory. It is trained, tested and evaluated on the
/// recombined data set.
586///
587
589{
590 // Generate K folds on given dataset
591 if (!fFoldStatus) {
592 fDataLoader->MakeKFoldDataSet(*fSplit);
593 fFoldStatus = kTRUE;
594 }
595
596 fResults.reserve(fMethods.size());
597 for (auto & methodInfo : fMethods) {
598 CrossValidationResult result{fNumFolds};
599
600 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
601 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
602
603 if (methodTypeName == "") {
604 Log() << kFATAL << "No method booked for cross-validation" << Endl;
605 }
606
608 Log() << kINFO << Endl;
609 Log() << kINFO << Endl;
610 Log() << kINFO << "========================================" << Endl;
611 Log() << kINFO << "Processing folds for method " << methodTitle << Endl;
612 Log() << kINFO << "========================================" << Endl;
613 Log() << kINFO << Endl;
614
615 // Process K folds
// NOTE: NumWorkerProcs=1 (the default) defers to TMVA::gConfig(), so an explicit
// 1 does not force serial processing if the global config asks for more workers.
616 auto nWorkers = fNumWorkerProcs;
617 if (nWorkers == 1) {
618 // Fall back to global config
619 nWorkers = TMVA::gConfig().GetNumWorkers();
620 }
621 if (nWorkers == 1) {
622 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
623 auto fold_result = ProcessFold(iFold, methodInfo);
624 result.Fill(fold_result);
625 }
626 } else {
// NOTE(review): on MSVC this branch compiles to nothing, so with more than one
// worker no fold is processed and `result` stays empty.
627#ifndef _MSC_VER
628 ROOT::TProcessExecutor workers(nWorkers);
629 std::vector<CrossValidationFoldResult> result_vector;
630
// Each fold is processed by ProcessFold in a worker process; methodInfo is
// captured by value.
631 auto workItem = [this, methodInfo](UInt_t iFold) {
632 return ProcessFold(iFold, methodInfo);
633 };
634
635 result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
636
637 for (auto && fold_result : result_vector) {
638 result.Fill(fold_result);
639 }
640#endif
641 }
642
643 fResults.push_back(result);
644
645 // Serialise the cross evaluated method
646 TString options =
647 Form("SplitExpr=%s:NumFolds=%i"
648 ":EncapsulatedMethodName=%s"
649 ":EncapsulatedMethodTypeName=%s"
650 ":OutputEnsembling=%s",
651 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());
652
653 fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);
654
655 // Feed EventToFold mapping used when random fold assignments are used
656 // (when splitExpr="").
657 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
// NOTE(review): the dynamic_cast result is dereferenced without a nullptr check.
658 auto *method = dynamic_cast<MethodCrossValidation *>(method_interface);
659
660 method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
661 }
662
663 Log() << kINFO << Endl;
664 Log() << kINFO << Endl;
665 Log() << kINFO << "========================================" << Endl;
666 Log() << kINFO << "Folds processed for all methods, evaluating." << Endl;
667 Log() << kINFO << "========================================" << Endl;
668 Log() << kINFO << Endl;
669
670 // Recombination of data (making sure there is data in training and testing trees).
671 fDataLoader->RecombineKFoldDataSet(*fSplit);
672
673 // "Eval" on training set
// i.e. "train" the booked MethodCrossValidation wrappers on the recombined data.
674 for (auto & methodInfo : fMethods) {
// NOTE(review): methodTypeName is unused in this loop.
675 TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
676 TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
677
678 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
// NOTE(review): the dynamic_cast result is dereferenced without a nullptr check.
679 auto method = dynamic_cast<MethodCrossValidation *>(method_interface);
680
681 if (fOutputFile != nullptr) {
682 fFactory->WriteDataInformation(method->fDataSetInfo);
683 }
684
686 method->TrainMethod();
688 }
689
690 // Eval on Testing set
691 fFactory->TestAllMethods();
692
693 // Calc statistics
694 fFactory->EvaluateAllMethods();
695
696 Log() << kINFO << "Evaluation done." << Endl;
697}
698
699//_______________________________________________________________________
700const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
701{
702 if (fResults.empty()) {
703 Log() << kFATAL << "No cross-validation results available" << Endl;
704 }
705 return fResults;
706}
#define c(i)
Definition RSha256.hxx:101
const Bool_t kFALSE
Definition RtypesCore.h:92
double Double_t
Definition RtypesCore.h:59
float Float_t
Definition RtypesCore.h:57
const Bool_t kTRUE
Definition RtypesCore.h:91
char name[80]
Definition TGX11.cxx:110
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition TSystem.h:559
auto Map(F func, unsigned nTimes) -> std::vector< typename std::result_of< F()>::type >
Execute a function without arguments several times.
This class provides a simple interface to execute the same task multiple times in parallel,...
A pseudo container class which is a generator of indices.
Definition TSeq.hxx:66
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition TAttLine.h:43
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition TAttLine.h:40
The Canvas class.
Definition TCanvas.h:23
virtual Int_t GetSize() const
Return the capacity of the collection, i.e.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3997
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:879
A TGraph is an object made of two arrays X and Y with npoints each.
Definition TGraph.h:41
virtual void SetTitle(const char *title="")
Change (i.e.
Definition TGraph.cxx:2339
virtual Double_t Eval(Double_t x, TSpline *spline=nullptr, Option_t *option="") const
Interpolate points in this graph at x using a TSpline.
Definition TGraph.cxx:887
This class displays a legend box (TPaveText) containing several legend entries.
Definition TLegend.h:23
A doubly linked list.
Definition TList.h:44
virtual TObject * At(Int_t idx) const
Returns the object at position idx. Returns 0 if idx is out of range.
Definition TList.cxx:357
UInt_t GetNumWorkers() const
Definition Config.h:74
void SetSilent(Bool_t s)
Definition Config.h:65
void CheckForUnusedOptions() const
checks for unused options in option string
Class to save the results of cross validation, the metric for the classification ins ROC and you can ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEffAreas
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a multigraph that contains an average ROC Curve.
TCanvas * Draw(const TString name="CrossValidation") const
Class to perform cross validation, splitting the dataloader into folds.
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
void SetSplitExpr(TString splitExpr)
void Evaluate()
Does training, test set evaluation and performance evaluation of using cross-evalution.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
void DeleteAllResults(Types::ETreeType type, Types::EAnalysisType analysistype)
Deletes all results currently in the dataset.
Definition DataSet.cxx:343
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Definition Envelope.h:44
virtual void ParseOptions()
Method to parse the internal option string.
Definition Envelope.cxx:182
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition Event.cxx:391
Interface for all concrete MVA method implementations.
Definition IMethod.h:53
Virtual base Class for all MVA method.
Definition MethodBase.h:111
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
virtual Double_t GetSignificance() const
compute significance of mean difference
Types::EAnalysisType GetAnalysisType() const
Definition MethodBase.h:437
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual Double_t GetTrainingEfficiency(const TString &)
DataSet * Data() const
Definition MethodBase.h:408
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
ostringstream derivative to redirect and format output
Definition MsgLogger.h:59
static void EnableOutput()
Definition MsgLogger.cxx:74
class to storage options for the differents methods
Definition OptionMap.h:34
T GetValue(const TString &key)
Definition OptionMap.h:133
Singleton class for Global types used by TMVA.
Definition Types.h:73
@ kCrossValidation
Definition Types.h:111
@ kMulticlass
Definition Types.h:131
@ kNoAnalysisType
Definition Types.h:132
@ kClassification
Definition Types.h:129
@ kRegression
Definition Types.h:130
@ kTraining
Definition Types.h:145
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition TMultiGraph.h:36
TList * GetListOfGraphs() const
Definition TMultiGraph.h:70
virtual void Add(TGraph *graph, Option_t *chopt="")
Add a new graph to the list of graphs.
TAxis * GetYaxis()
Get y axis of the graph.
TAxis * GetXaxis()
Get x axis of the graph.
virtual void SetTitle(const char *title="")
Set the title of the TNamed.
Definition TNamed.cxx:164
virtual void SetName(const char *name)
Set the name of the TNamed.
Definition TNamed.cxx:140
virtual TObject * Clone(const char *newname="") const
Make a clone of an object using the Streamer facility.
Definition TNamed.cxx:74
virtual TObject * DrawClone(Option_t *option="") const
Draw a clone of this object in the current selected pad for instance with: gROOT->SetSelectedPad(gPad...
Definition TObject.cxx:221
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2331
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
Definition TSystem.cxx:1030
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
TGraphErrors * gr
Definition legend1.C:25
leg
Definition legend1.C:34
create variable transformations
Config & gConfig()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:158
Double_t Sqrt(Double_t x)
Definition TMath.h:691
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Definition TMath.h:735