ROOT 6.08/07
Reference Guide
CrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou. 2016
3 
4 #ifndef ROOT_TMVA_CrossValidation
5 #define ROOT_TMVA_CrossValidation
6 
7 #ifndef ROOT_TString
8 #include "TString.h"
9 #endif
10 
11 #ifndef ROOT_TMultiGraph
12 #include "TMultiGraph.h"
13 #endif
14 
15 #ifndef ROOT_TMVA_IMethod
16 #include "TMVA/IMethod.h"
17 #endif
18 
19 #ifndef ROOT_TMVA_Configurable
20 #include "TMVA/Configurable.h"
21 #endif
22 
23 #ifndef ROOT_TMVA_Types
24 #include "TMVA/Types.h"
25 #endif
26 
27 #ifndef ROOT_TMVA_DataSet
28 #include "TMVA/DataSet.h"
29 #endif
30 
31 #ifndef ROOT_TMVA_Event
32 #include "TMVA/Event.h"
33 #endif
34 
35 #ifndef ROOT_TMVA_Results
36 #include <TMVA/Results.h>
37 #endif
38 
39 #ifndef ROOT_TMVA_Factory
40 #include <TMVA/Factory.h>
41 #endif
42 
43 #ifndef ROOT_TMVA_DataLoader
44 #include <TMVA/DataLoader.h>
45 #endif
46 
47 #ifndef ROOT_TMVA_OptionMap
48 #include <TMVA/OptionMap.h>
49 #endif
50 
51 #ifndef ROOT_TMVA_Envelope
52 #include <TMVA/Envelope.h>
53 #endif
54 
55 namespace TMVA {
56 
58  friend class CrossValidation;
59 
60  private:
61  std::map<UInt_t,Float_t> fROCs;
62  std::shared_ptr<TMultiGraph> fROCCurves;
63 
64  std::vector<Double_t> fSigs;
65  std::vector<Double_t> fSeps;
66  std::vector<Double_t> fEff01s;
67  std::vector<Double_t> fEff10s;
68  std::vector<Double_t> fEff30s;
69  std::vector<Double_t> fEffAreas;
70  std::vector<Double_t> fTrainEff01s;
71  std::vector<Double_t> fTrainEff10s;
72  std::vector<Double_t> fTrainEff30s;
73 
74  public:
77  ~CrossValidationResult(){fROCCurves=nullptr;}
78 
79  std::map<UInt_t,Float_t> GetROCValues(){return fROCs;}
80  Float_t GetROCAverage() const;
83  void Print() const ;
84 
85  TCanvas* Draw(const TString name="CrossValidation") const;
86 
87  std::vector<Double_t> GetSigValues() {return fSigs;}
88  std::vector<Double_t> GetSepValues() {return fSeps;}
89  std::vector<Double_t> GetEff01Values() {return fEff01s;}
90  std::vector<Double_t> GetEff10Values() {return fEff10s;}
91  std::vector<Double_t> GetEff30Values() {return fEff30s;}
92  std::vector<Double_t> GetEffAreaValues() {return fEffAreas;}
93  std::vector<Double_t> GetTrainEff01Values() {return fTrainEff01s;}
94  std::vector<Double_t> GetTrainEff10Values() {return fTrainEff10s;}
95  std::vector<Double_t> GetTrainEff30Values() {return fTrainEff30s;}
96  };
97 
98 
99  class CrossValidation : public Envelope { // Cross-validation driver: an Envelope subclass constructed from a DataLoader; results are retrieved via GetResults().
103  public:
104  explicit CrossValidation(DataLoader *loader); // NOTE(review): whether `loader` is owned or only borrowed is not visible in this header — confirm in the implementation.
105  ~CrossValidation(); // NOTE(review): presumably overrides a virtual Envelope destructor — confirm.
106 // Fold-count configuration.
107  void SetNumFolds(UInt_t i); // presumably stores `i` as the fold count read back by GetNumFolds(); any range check on `i` is not visible here.
108  UInt_t GetNumFolds() {return fNumFolds;} // returns fNumFolds (declared on a line not shown here); NOTE(review): could be marked const.
109 
110  virtual void Evaluate(); // runs the cross-validation; presumably overrides Envelope::Evaluate — TODO confirm and consider `override`.
111 // void EvaluateFold(UInt_t fold);//used in ParallelExecution
112 
113  const CrossValidationResult& GetResults() const; // returns a const reference, not a copy — presumably to member fResults, so it must not outlive this object; verify.
114 
115  private:
116  std::unique_ptr<Factory> fClassifier; // Factory exclusively owned by this object (unique_ptr); presumably used internally to train/test methods.
118  };
119 
120 } // namespace TMVA
121 
122 #endif // ROOT_TMVA_CrossValidation
std::vector< Double_t > fSigs
float Float_t
Definition: RtypesCore.h:53
std::map< UInt_t, Float_t > GetROCValues()
CrossValidationResult fResults
std::vector< Double_t > GetEff01Values()
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:37
Basic string class.
Definition: TString.h:137
bool Bool_t
Definition: RtypesCore.h:59
std::unique_ptr< Factory > fClassifier
std::vector< Double_t > GetEff10Values()
std::vector< Double_t > GetSigValues()
std::vector< Double_t > fEff10s
std::vector< Double_t > GetEff30Values()
#define ClassDef(name, id)
Definition: Rtypes.h:254
std::vector< Double_t > GetEffAreaValues()
Base class for all machine learning algorithms.
Definition: Envelope.h:55
std::vector< Double_t > fTrainEff01s
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEff01s
unsigned int UInt_t
Definition: RtypesCore.h:42
std::vector< Double_t > GetTrainEff01Values()
Float_t GetROCStandardDeviation() const
std::vector< Double_t > fTrainEff30s
The Canvas class.
Definition: TCanvas.h:41
std::vector< Double_t > GetTrainEff30Values()
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
std::vector< Double_t > fEffAreas
std::vector< Double_t > fSeps
Abstract ClassifierFactory template that handles arbitrary types.
std::map< UInt_t, Float_t > fROCs
TCanvas * Draw(const TString name="CrossValidation") const
std::vector< Double_t > fEff30s
std::vector< Double_t > GetTrainEff10Values()
const Bool_t kTRUE
Definition: Rtypes.h:91
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > GetSepValues()
char name[80]
Definition: TGX11.cxx:109