Logo ROOT   6.14/05
Reference Guide
CrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson, Pourya Vakilipourtakalou, Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #ifndef ROOT_TMVA_CROSS_EVALUATION
13 #define ROOT_TMVA_CROSS_EVALUATION
14 
15 #include "TGraph.h"
16 #include "TMultiGraph.h"
17 #include "TString.h"
18 
19 #include "TMVA/IMethod.h"
20 #include "TMVA/Configurable.h"
21 #include "TMVA/Types.h"
22 #include "TMVA/DataSet.h"
23 #include "TMVA/Event.h"
24 #include <TMVA/Results.h>
25 #include <TMVA/Factory.h>
26 #include <TMVA/DataLoader.h>
27 #include <TMVA/OptionMap.h>
28 #include <TMVA/Envelope.h>
29 
30 /*! \class TMVA::CrossValidationResult
31  * Class to save the results of cross validation.
32  * The metric for classification is the ROC, and you can retrieve ROC curves,
33  * ROC integrals, ROC average and ROC standard deviation.
34 \ingroup TMVA
35 */
36 
37 /*! \class TMVA::CrossValidation
38  * Class to perform cross validation, splitting the dataloader into folds.
39 \ingroup TMVA
40 */
41 
42 namespace TMVA {
43 
44 class CvSplitKFolds;
45 
46 using EventCollection_t = std::vector<Event *>; // List of Event pointers (ownership is not expressed by this alias).
47 using EventTypes_t = std::vector<Bool_t>; // One Bool_t per entry; presumably a per-event type flag -- TODO confirm.
48 using EventOutputs_t = std::vector<Float_t>; // One Float_t per entry; presumably a per-event method output -- TODO confirm.
49 using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>; // Vector of Float_t vectors; presumably per-event, per-class outputs -- TODO confirm.
50 
52 public:
53  CrossValidationFoldResult() {} // For multi-proc serialisation
55  : fFold(iFold)
56  {}
57 
59 
62 
72 };
73 
74 // Used internally to keep per-fold aggregate statistics
75 // such as ROC curves, ROC integrals and efficiencies.
77  friend class CrossValidation;
78 
79 private:
80  std::map<UInt_t, Float_t> fROCs;
81  std::shared_ptr<TMultiGraph> fROCCurves;
82 
83  std::vector<Double_t> fSigs;
84  std::vector<Double_t> fSeps;
85  std::vector<Double_t> fEff01s;
86  std::vector<Double_t> fEff10s;
87  std::vector<Double_t> fEff30s;
88  std::vector<Double_t> fEffAreas;
89  std::vector<Double_t> fTrainEff01s;
90  std::vector<Double_t> fTrainEff10s;
91  std::vector<Double_t> fTrainEff30s;
92 
93 public:
94  CrossValidationResult(UInt_t numFolds);
96  ~CrossValidationResult() { fROCCurves = nullptr; }
97 
98  std::map<UInt_t, Float_t> GetROCValues() const { return fROCs; }
99  Float_t GetROCAverage() const;
100  Float_t GetROCStandardDeviation() const;
101  TMultiGraph *GetROCCurves(Bool_t fLegend = kTRUE);
102  void Print() const;
103 
104  TCanvas *Draw(const TString name = "CrossValidation") const;
105 
106  std::vector<Double_t> GetSigValues() const { return fSigs; }
107  std::vector<Double_t> GetSepValues() const { return fSeps; }
108  std::vector<Double_t> GetEff01Values() const { return fEff01s; }
109  std::vector<Double_t> GetEff10Values() const { return fEff10s; }
110  std::vector<Double_t> GetEff30Values() const { return fEff30s; }
111  std::vector<Double_t> GetEffAreaValues() const { return fEffAreas; }
112  std::vector<Double_t> GetTrainEff01Values() const { return fTrainEff01s; }
113  std::vector<Double_t> GetTrainEff10Values() const { return fTrainEff10s; }
114  std::vector<Double_t> GetTrainEff30Values() const { return fTrainEff30s; }
115 
116 private:
117  void Fill(CrossValidationFoldResult const & fr);
118 };
119 
120 class CrossValidation : public Envelope { // Performs cross validation, splitting the dataloader into folds.
121 
122 public:
123  explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options); // Construct from a job name, a dataloader and an option string.
124  explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options); // Same as above, additionally taking an output file.
125  ~CrossValidation();
126 
127  void InitOptions(); // Defined out of line; presumably declares the supported options -- TODO confirm.
128  void ParseOptions(); // Defined out of line; presumably parses the option string given at construction -- TODO confirm.
129 
130  void SetNumFolds(UInt_t i); // Defined out of line; presumably updates fNumFolds -- TODO confirm.
131  void SetSplitExpr(TString splitExpr); // Defined out of line; presumably updates the split expression -- TODO confirm.
132 
133  UInt_t GetNumFolds() { return fNumFolds; } // Returns fNumFolds (number of folds to prepare). NOTE(review): not const-qualified.
134  TString GetSplitExpr() { return fSplitExprString; } // Returns a copy of fSplitExprString (its declaration is not shown in this listing). NOTE(review): not const-qualified.
135 
136  Factory &GetFactory() { return *fFactory; } // Dereferences fFactory without a null check; the caller must ensure it has been set.
137 
138  const std::vector<CrossValidationResult> &GetResults() const; // Returns a const reference to a vector of CrossValidationResult; defined out of line.
139 
140  void Evaluate(); // Defined out of line; presumably runs the cross validation over all folds -- TODO confirm.
141 
142 private:
143  CrossValidationFoldResult ProcessFold(UInt_t iFold, UInt_t iMethod); // Takes a fold index and a method index and returns that fold's result; defined out of line.
144 
150  Bool_t fFoldFileOutput; //! If true: generate output file for each fold
151  Bool_t fFoldStatus; //! If true: dataset is prepared
153  UInt_t fNumFolds; //! Number of folds to prepare
154  UInt_t fNumWorkerProcs; //! Number of processes to use for fold evaluation.
155  //!(Default, no parallel evaluation)
157  TString fOutputEnsembling; //! How to combine output of individual folds
161  std::vector<CrossValidationResult> fResults; //!
166 
167  std::unique_ptr<Factory> fFoldFactory;
168  std::unique_ptr<Factory> fFactory;
169  std::unique_ptr<CvSplitKFolds> fSplit;
170 
172  };
173 
174 } // namespace TMVA
175 
176 #endif // ROOT_TMVA_CROSS_EVALUATION
std::vector< Double_t > fSigs
std::vector< Double_t > GetSepValues() const
TFile * fOutputFile
How to combine output of individual folds.
std::vector< Double_t > GetEffAreaValues() const
Bool_t fFoldStatus
If true: generate output file for each fold.
std::unique_ptr< CvSplitKFolds > fSplit
float Float_t
Definition: RtypesCore.h:53
std::map< UInt_t, Float_t > fROCs
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:47
EAnalysisType
Definition: Types.h:127
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Class to save the results of cross validation, the metric for the classification is ROC and you can ...
std::vector< Double_t > GetEff30Values() const
Basic string class.
Definition: TString.h:131
std::vector< Double_t > GetEff10Values() const
bool Bool_t
Definition: RtypesCore.h:59
Types::EAnalysisType fAnalysisType
std::vector< Double_t > fEff10s
std::vector< Double_t > GetEff01Values() const
std::vector< Double_t > GetSigValues() const
#define ClassDef(name, id)
Definition: Rtypes.h:320
std::vector< CrossValidationResult > fResults
Abstract base class for all high level ml algorithms, you can book ml methods like BDT...
Definition: Envelope.h:43
std::vector< Double_t > fTrainEff01s
UInt_t fNumWorkerProcs
Number of folds to prepare.
std::vector< Double_t > fTrainEff10s
th1 Draw()
std::vector< Double_t > GetTrainEff10Values() const
std::vector< Double_t > fEff01s
unsigned int UInt_t
Definition: RtypesCore.h:42
This is the main MVA steering class.
Definition: Factory.h:81
std::vector< Double_t > fTrainEff30s
std::unique_ptr< Factory > fFoldFactory
The Canvas class.
Definition: TCanvas.h:31
void Print(std::ostream &os, const OptionType &opt)
double Double_t
Definition: RtypesCore.h:55
Class to perform cross validation, splitting the dataloader into folds.
std::unique_ptr< Factory > fFactory
std::vector< Double_t > GetTrainEff01Values() const
std::vector< Double_t > fEffAreas
std::vector< Double_t > fSeps
Abstract ClassifierFactory template that handles arbitrary types.
std::vector< Double_t > GetTrainEff30Values() const
A Graph is a graphics object made of two arrays X and Y with npoints each.
Definition: TGraph.h:41
std::vector< Double_t > fEff30s
const Bool_t kTRUE
Definition: RtypesCore.h:87
std::shared_ptr< TMultiGraph > fROCCurves
char name[80]
Definition: TGX11.cxx:109
TString fJobName
If true: dataset is prepared.
TString fOutputFactoryOptions
Number of processes to use for fold evaluation.
std::map< UInt_t, Float_t > GetROCValues() const