ROOT 6.07/09
Reference Guide
CrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou. 2016
3 
4 
5 #ifndef ROOT_TMVA_CrossValidation
6 #define ROOT_TMVA_CrossValidation
7 
8 
9 #ifndef ROOT_TString
10 #include "TString.h"
11 #endif
12 
13 #ifndef ROOT_TMultiGraph
14 #include "TMultiGraph.h"
15 #endif
16 
17 #ifndef ROOT_TMVA_IMethod
18 #include "TMVA/IMethod.h"
19 #endif
20 #ifndef ROOT_TMVA_Configurable
21 #include "TMVA/Configurable.h"
22 #endif
23 #ifndef ROOT_TMVA_Types
24 #include "TMVA/Types.h"
25 #endif
26 #ifndef ROOT_TMVA_DataSet
27 #include "TMVA/DataSet.h"
28 #endif
29 #ifndef ROOT_TMVA_Event
30 #include "TMVA/Event.h"
31 #endif
32 #ifndef ROOT_TMVA_Results
33 #include<TMVA/Results.h>
34 #endif
35 
36 #ifndef ROOT_TMVA_Factory
37 #include<TMVA/Factory.h>
38 #endif
39 
40 #ifndef ROOT_TMVA_DataLoader
41 #include<TMVA/DataLoader.h>
42 #endif
43 
44 #ifndef ROOT_TMVA_OptionMap
45 #include<TMVA/OptionMap.h>
46 #endif
47 
48 #ifndef ROOT_TMVA_Envelope
49 #include<TMVA/Envelope.h>
50 #endif
51 
52 namespace TMVA {
53 
/// NOTE(review): this listing omits header lines 54, 72-73 and 79 (the
/// `class CrossValidationResult` head and, presumably, its constructors and
/// GetROCCurves(Bool_t)); confirm against the real CrossValidation.h.
///
/// CrossValidationResult: container for cross-validation figures of merit.
/// All data members are private; CrossValidation is declared a friend so it
/// can write them directly (presumably while running its Evaluate()).
55  {
56  friend class CrossValidation;
57  private:
/// ROC value per entry, keyed by an unsigned integer (presumably the fold
/// index -- TODO confirm in the implementation).
58  std::map<UInt_t,Float_t> fROCs;
/// ROC curves held in a TMultiGraph under shared ownership (std::shared_ptr).
59  std::shared_ptr<TMultiGraph> fROCCurves;
60 
/// Per-fold figures of merit, presumably one entry per fold in fold order
/// (verify in the implementation). The names suggest significance,
/// separation, efficiencies at three working points (01/10/30) and an
/// efficiency area; the fTrain* vectors presumably hold the same efficiencies
/// evaluated on the training sample -- assumptions, not visible here.
61  std::vector<Double_t> fSigs;
62  std::vector<Double_t> fSeps;
63  std::vector<Double_t> fEff01s;
64  std::vector<Double_t> fEff10s;
65  std::vector<Double_t> fEff30s;
66  std::vector<Double_t> fEffAreas;
67  std::vector<Double_t> fTrainEff01s;
68  std::vector<Double_t> fTrainEff10s;
69  std::vector<Double_t> fTrainEff30s;
70 
71  public:
/// Destructor: resets fROCCurves to nullptr.
/// NOTE(review): redundant -- std::shared_ptr drops its reference when it is
/// destroyed anyway, so this could simply be defaulted.
74  ~CrossValidationResult(){fROCCurves=nullptr;}
75 
76 
/// Returns a copy of the per-entry ROC map (fROCs).
77  std::map<UInt_t,Float_t> GetROCValues(){return fROCs;}
/// Aggregate ROC value; defined out of line (presumably the mean of fROCs --
/// TODO confirm).
78  Float_t GetROCAverage() const;
/// Prints the results; defined out of line.
80  void Print() const ;
81 
/// Draws the results and returns the canvas. `name` defaults to
/// "CrossValidation". Ownership of the returned TCanvas is not visible here --
/// check the implementation before deleting or keeping it.
/// NOTE(review): `name` is passed by const value; `const TString&` would avoid
/// a copy. TCanvas is not declared by any include visible in this file, so it
/// must arrive transitively (a `class TCanvas;` forward declaration would be
/// more robust).
82  TCanvas* Draw(const TString name="CrossValidation") const;
83 
84 
/// Value accessors: each returns a copy of the matching private vector.
/// NOTE(review): these (and GetROCValues) are not const-qualified, so they
/// cannot be called through a const CrossValidationResult& such as the one
/// returned by CrossValidation::GetResults(); consider `const` member
/// functions returning `const std::vector<Double_t>&`.
85  std::vector<Double_t> GetSigValues(){return fSigs;}
86  std::vector<Double_t> GetSepValues(){return fSeps;}
87  std::vector<Double_t> GetEff01Values(){return fEff01s;}
88  std::vector<Double_t> GetEff10Values(){return fEff10s;}
89  std::vector<Double_t> GetEff30Values(){return fEff30s;}
90  std::vector<Double_t> GetEffAreaValues(){return fEffAreas;}
91  std::vector<Double_t> GetTrainEff01Values(){return fTrainEff01s;}
92  std::vector<Double_t> GetTrainEff10Values(){return fTrainEff10s;}
93  std::vector<Double_t> GetTrainEff30Values(){return fTrainEff30s;}
94 
95  };
96 
97 
/// CrossValidation: cross-validation driver derived from Envelope. The number
/// of folds is controlled through SetNumFolds()/GetNumFolds(), and the
/// outcome is exposed through GetResults().
///
/// NOTE(review): this listing omits header lines 99-101 and 116. fNumFolds and
/// fResults are used below but not declared in the visible lines; presumably
/// they (and a ClassDef) live in the omitted lines -- confirm against the real
/// header.
98  class CrossValidation : public Envelope {
102  public:
/// Constructs the cross-validation for `loader`. `explicit` blocks implicit
/// conversion from DataLoader*. Whether `loader` is owned or only borrowed is
/// not visible here -- check the implementation.
103  explicit CrossValidation(DataLoader *loader);
/// NOTE(review): not marked `override`; whether Envelope's destructor is
/// virtual is not visible from this header.
104  ~CrossValidation();
105 
/// Sets the number of folds. The trailing comment preserves an earlier inline
/// body; the definition is now out of line.
106  void SetNumFolds(UInt_t i);//{fNumFolds=i;}
/// Returns the current number of folds (fNumFolds).
/// NOTE(review): could be a `const` member function.
107  UInt_t GetNumFolds(){return fNumFolds;}
108 
/// Performs the cross-validation; implementation not visible here. Declared
/// virtual and presumably overriding an Envelope hook -- consider `override`
/// once confirmed.
109  virtual void Evaluate();
110 // void EvaluateFold(UInt_t fold);//used in ParallelExecution
111 
/// Read-only access to the accumulated results. The returned reference is
/// bound to this object's fResults and must not outlive the CrossValidation.
/// NOTE(review): CrossValidationResult's value getters (GetROCValues,
/// GetSigValues, ...) are non-const, so they cannot be called on this const
/// reference without first copying the result object.
112  const CrossValidationResult& GetResults() const {return fResults;}//I need to think about this which is the best way to get the results
113 
114  private:
/// Factory owned exclusively via std::unique_ptr. The `//!` marker is ROOT's
/// comment convention for a transient (non-persistified) data member.
115  std::unique_ptr<Factory> fClassifier; //!
117  };
118 }
119 
120 
121 #endif
122 
123 
124 
const CrossValidationResult & GetResults() const
std::vector< Double_t > fSigs
float Float_t
Definition: RtypesCore.h:53
std::map< UInt_t, Float_t > GetROCValues()
TCanvas * Draw(const TString name="CrossValidation") const
CrossValidationResult fResults
std::vector< Double_t > GetEff01Values()
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:37
Basic string class.
Definition: TString.h:137
bool Bool_t
Definition: RtypesCore.h:59
std::unique_ptr< Factory > fClassifier
std::vector< Double_t > GetEff10Values()
std::vector< Double_t > GetSigValues()
std::vector< Double_t > fEff10s
std::vector< Double_t > GetEff30Values()
#define ClassDef(name, id)
Definition: Rtypes.h:254
std::vector< Double_t > GetEffAreaValues()
Base class for all machine learning algorithms.
Definition: Envelope.h:55
std::vector< Double_t > fTrainEff01s
std::vector< Double_t > fTrainEff10s
std::vector< Double_t > fEff01s
unsigned int UInt_t
Definition: RtypesCore.h:42
std::vector< Double_t > GetTrainEff01Values()
std::vector< Double_t > fTrainEff30s
The Canvas class.
Definition: TCanvas.h:41
std::vector< Double_t > GetTrainEff30Values()
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
std::vector< Double_t > fEffAreas
std::vector< Double_t > fSeps
Abstract ClassifierFactory template that handles arbitrary types.
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fEff30s
std::vector< Double_t > GetTrainEff10Values()
const Bool_t kTRUE
Definition: Rtypes.h:91
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< Double_t > GetSepValues()
char name[80]
Definition: TGX11.cxx:109