Logo ROOT   6.16/01
Reference Guide
CrossValidation.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson, Pourya Vakilipourtakalou, Kim Albertsson
3
4/*************************************************************************
5 * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12#ifndef ROOT_TMVA_CROSS_EVALUATION
13#define ROOT_TMVA_CROSS_EVALUATION
14
15#include "TGraph.h"
16#include "TMultiGraph.h"
17#include "TString.h"
18
19#include "TMVA/IMethod.h"
20#include "TMVA/Configurable.h"
21#include "TMVA/Types.h"
22#include "TMVA/DataSet.h"
23#include "TMVA/Event.h"
24#include <TMVA/Results.h>
25#include <TMVA/Factory.h>
26#include <TMVA/DataLoader.h>
27#include <TMVA/OptionMap.h>
28#include <TMVA/Envelope.h>
29
30/*! \class TMVA::CrossValidationResult
31 * Class to save the results of cross validation. The metric for
32 * classification is ROC; the class provides ROC curves, ROC integrals,
33 * ROC average and ROC standard deviation.
34\ingroup TMVA
35*/
36
37/*! \class TMVA::CrossValidation
38 * Class to perform cross validation, splitting the dataloader into folds.
39\ingroup TMVA
40*/
41
42namespace TMVA {
43
44class CvSplitKFolds;
45
46using EventCollection_t = std::vector<Event *>; // Collection of Event pointers; ownership is not shown here (TODO: confirm who deletes the events)
47using EventTypes_t = std::vector<Bool_t>; // One Bool_t per event; presumably a signal/background flag (verify with callers)
48using EventOutputs_t = std::vector<Float_t>; // One Float_t per event; presumably the classifier response (verify with callers)
49using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>; // Per-event vector of Float_t; presumably one response per class (verify with callers)
50
52public:
53 CrossValidationFoldResult() {} // For multi-proc serialisation
55 : fFold(iFold)
56 {}
57
59
62
72};
73
74// Used internally to keep per-fold aggregate statistics
75// such as ROC curves, ROC integrals and efficiencies.
77 friend class CrossValidation;
78
79private:
80 std::map<UInt_t, Float_t> fROCs;
81 std::shared_ptr<TMultiGraph> fROCCurves;
82
83 std::vector<Double_t> fSigs;
84 std::vector<Double_t> fSeps;
85 std::vector<Double_t> fEff01s;
86 std::vector<Double_t> fEff10s;
87 std::vector<Double_t> fEff30s;
88 std::vector<Double_t> fEffAreas;
89 std::vector<Double_t> fTrainEff01s;
90 std::vector<Double_t> fTrainEff10s;
91 std::vector<Double_t> fTrainEff30s;
92
93public:
97
98 std::map<UInt_t, Float_t> GetROCValues() const { return fROCs; }
99 Float_t GetROCAverage() const;
102 TGraph *GetAvgROCCurve(UInt_t numSamples = 100) const;
103 void Print() const;
104
105 TCanvas *Draw(const TString name = "CrossValidation") const;
106 TCanvas *DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const;
107
108 std::vector<Double_t> GetSigValues() const { return fSigs; }
109 std::vector<Double_t> GetSepValues() const { return fSeps; }
110 std::vector<Double_t> GetEff01Values() const { return fEff01s; }
111 std::vector<Double_t> GetEff10Values() const { return fEff10s; }
112 std::vector<Double_t> GetEff30Values() const { return fEff30s; }
113 std::vector<Double_t> GetEffAreaValues() const { return fEffAreas; }
114 std::vector<Double_t> GetTrainEff01Values() const { return fTrainEff01s; }
115 std::vector<Double_t> GetTrainEff10Values() const { return fTrainEff10s; }
116 std::vector<Double_t> GetTrainEff30Values() const { return fTrainEff30s; }
117
118private:
119 void Fill(CrossValidationFoldResult const & fr);
120};
121
122class CrossValidation : public Envelope {
123
124public:
125 explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options);
126 explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options);
128
129 void InitOptions();
130 void ParseOptions();
131
132 void SetNumFolds(UInt_t i);
133 void SetSplitExpr(TString splitExpr);
134
137
139
140 const std::vector<CrossValidationResult> &GetResults() const;
141
142 void Evaluate();
143
144private:
145 CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo);
146
153 Bool_t fFoldFileOutput; //! If true: generate output file for each fold
154 Bool_t fFoldStatus; //! If true: dataset is prepared
156 UInt_t fNumFolds; //! Number of folds to prepare
157 UInt_t fNumWorkerProcs; //! Number of processes to use for fold evaluation.
158 //!(Default, no parallel evaluation)
160 TString fOutputEnsembling; //! How to combine output of individual folds
164 std::vector<CrossValidationResult> fResults; //!
169
170 std::unique_ptr<Factory> fFoldFactory;
171 std::unique_ptr<Factory> fFactory;
172 std::unique_ptr<CvSplitKFolds> fSplit;
173
175 };
176
177} // namespace TMVA
178
179#endif // ROOT_TMVA_CROSS_EVALUATION
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassDef(name, id)
Definition: Rtypes.h:324
The Canvas class.
Definition: TCanvas.h:31
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:48
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
Class to save the results of cross validation. The metric for classification is ROC; the class provides ...
std::vector< Double_t > fSeps
std::vector< Double_t > fEff01s
CrossValidationResult(UInt_t numFolds)
std::vector< Double_t > GetTrainEff10Values() const
std::vector< Double_t > fTrainEff30s
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > GetTrainEff30Values() const
std::vector< Double_t > fSigs
std::vector< Double_t > fEff30s
void Fill(CrossValidationFoldResult const &fr)
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fEff10s
std::vector< Double_t > fTrainEff01s
std::vector< Double_t > GetEff10Values() const
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > GetTrainEff01Values() const
std::vector< Double_t > fEffAreas
std::vector< Double_t > GetEff01Values() const
TCanvas * DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const
std::vector< Double_t > GetSigValues() const
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
TGraph * GetAvgROCCurve(UInt_t numSamples=100) const
Generates a graph that contains an average ROC curve.
std::map< UInt_t, Float_t > GetROCValues() const
std::vector< Double_t > GetEffAreaValues() const
TCanvas * Draw(const TString name="CrossValidation") const
std::vector< Double_t > GetSepValues() const
std::vector< Double_t > GetEff30Values() const
Class to perform cross validation, splitting the dataloader into folds.
void SetNumFolds(UInt_t i)
void ParseOptions()
Method to parse the internal option string.
const std::vector< CrossValidationResult > & GetResults() const
std::vector< CrossValidationResult > fResults
std::unique_ptr< Factory > fFoldFactory
CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
Bool_t fFoldStatus
If true: dataset is prepared.
std::unique_ptr< CvSplitKFolds > fSplit
TFile * fOutputFile
How to combine output of individual folds.
Types::EAnalysisType fAnalysisType
void SetSplitExpr(TString splitExpr)
void Evaluate()
Performs training, test-set evaluation and performance evaluation using cross-validation.
TString fOutputFactoryOptions
Number of processes to use for fold evaluation.
std::unique_ptr< Factory > fFactory
UInt_t fNumWorkerProcs
Number of processes to use for fold evaluation (default: no parallel evaluation).
TString fJobName
If true: dataset is prepared.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap &methodInfo)
Evaluates each fold in turn.
Abstract base class for all high-level ML algorithms; you can book ML methods like BDT, ...
Definition: Envelope.h:44
This is the main MVA steering class.
Definition: Factory.h:81
Class to store options for the different methods.
Definition: OptionMap.h:36
EAnalysisType
Definition: Types.h:127
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Basic string class.
Definition: TString.h:131
Abstract ClassifierFactory template that handles arbitrary types.