Logo ROOT   6.07/09
Reference Guide
CrossValidation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson.
3 
4 #include "TMVA/CrossValidation.h"
5 
6 #include "TMVA/Config.h"
7 #include "TMVA/DataSet.h"
8 #include "TMVA/Event.h"
9 #include "TMVA/MethodBase.h"
10 #include "TMVA/MsgLogger.h"
12 #include "TMVA/tmvaglob.h"
13 #include "TMVA/Types.h"
14 
15 #include "TSystem.h"
16 #include "TAxis.h"
17 #include "TCanvas.h"
18 #include "TGraph.h"
19 
20 #include <iostream>
21 #include <memory>
22 
24 {
   // Intentionally empty body. NOTE(review): the definition line owning this
   // body is not shown in this listing; confirm what it initialises.
25 }
26 
28 {
29  fROCs=obj.fROCs;
30  fROCCurves = obj.fROCCurves;
31 }
32 
33 
35 {
36  return fROCCurves.get();
37 }
38 
40 {
41  Float_t avg=0;
42  for(auto &roc:fROCs) avg+=roc.second;
43  return avg/fROCs.size();
44 }
45 
46 
48 {
51 
52  MsgLogger fLogger("CrossValidation");
53  for(auto &item:fROCs)
54  fLogger<<kINFO<<Form("Fold %i ROC-Int : %f",item.first,item.second)<<std::endl;
55 
56  fLogger<<kINFO<<Form("Average ROC-Int : %f",GetROCAverage())<<Endl;
57 
59 
60 }
61 
62 
64 {
65  TCanvas *c=new TCanvas(name.Data());
66  fROCCurves->Draw("AL");
67  fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
68  fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
69  Float_t adjust=1+fROCs.size()*0.01;
70  c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
71  c->SetTitle("Cross Validation ROC Curves");
72  c->Draw();
73  return c;
74 }
75 
76 TMVA::CrossValidation::CrossValidation(TMVA::DataLoader *dataloader):TMVA::Envelope("CrossValidation",dataloader),
77 fNumFolds(5),fClassifier(new TMVA::Factory("CrossValidation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
78 {
80 }
81 
83 {
84  fClassifier=nullptr;
85 }
86 
88 {
89  fNumFolds=i;
90  fDataLoader->MakeKFoldDataSet(fNumFolds);
92 }
93 
94 
96 {
97  TString methodName = fMethod.GetValue<TString>("MethodName");
98  TString methodTitle = fMethod.GetValue<TString>("MethodTitle");
99  TString methodOptions = fMethod.GetValue<TString>("MethodOptions");
100  if(!fFoldStatus)
101  {
102  fDataLoader->MakeKFoldDataSet(fNumFolds);
104  }
105 
106  for(UInt_t i = 0; i < fNumFolds; ++i){
107  TString foldTitle = methodTitle;
108  foldTitle += "_fold";
109  foldTitle += i+1;
110 
111  fDataLoader->PrepareFoldDataSet(i, TMVA::Types::kTesting);
112 
113 
114  auto smethod=fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
115 
117  smethod->TrainMethod();
118 
120  smethod->AddOutput(Types::kTesting, smethod->GetAnalysisType());
121  smethod->TestClassification();
122 
123 
124  fResults.fROCs[i]=fClassifier->GetROCIntegral(fDataLoader->GetName(),methodTitle);
125 
126  auto gr=fClassifier->GetROCCurve(fDataLoader->GetName(), methodTitle, true);
127 
128  gr->SetLineColor(i+1);
129  gr->SetLineWidth(2);
130  gr->SetTitle(foldTitle.Data());
131 
132  fResults.fROCCurves->Add(gr);
133 
134  fResults.fSigs.push_back(smethod->GetSignificance());
135  fResults.fSeps.push_back(smethod->GetSeparation());
136 
137  Double_t err;
138  fResults.fEff01s.push_back(smethod->GetEfficiency("Efficiency:0.01",Types::kTesting, err));
139  fResults.fEff10s.push_back(smethod->GetEfficiency("Efficiency:0.10",Types::kTesting,err));
140  fResults.fEff30s.push_back(smethod->GetEfficiency("Efficiency:0.30",Types::kTesting,err));
141  fResults.fEffAreas.push_back(smethod->GetEfficiency("" ,Types::kTesting,err));
142  fResults.fTrainEff01s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.01"));
143  fResults.fTrainEff10s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.10"));
144  fResults.fTrainEff30s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.30"));
145 
146  smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification);
147  smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
148 
149  fClassifier->DeleteAllMethods();
150  fClassifier->fMethodsMap.clear();
151  }
154  Log()<<kINFO<<"Evaluation done."<<Endl;
156 
157 
158 }
Config & gConfig()
Definition: Config.cxx:43
std::vector< Double_t > fSigs
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
Definition: TAttLine.h:49
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
float Float_t
Definition: RtypesCore.h:53
void SetTitle(const char *title="")
Set canvas title.
Definition: TCanvas.cxx:1917
return c
T GetValue(const TString &key)
Definition: OptionMap.h:152
TCanvas * Draw(const TString name="CrossValidation") const
CrossValidationResult fResults
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:37
Basic string class.
Definition: TString.h:137
bool Bool_t
Definition: RtypesCore.h:59
std::unique_ptr< Factory > fClassifier
const Bool_t kFALSE
Definition: Rtypes.h:92
virtual void SetTitle(const char *title="")
Set graph title.
Definition: TGraph.cxx:2176
std::vector< Double_t > fEff10s
const char * Data() const
Definition: TString.h:349
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:388
OptionMap fMethod
Definition: Envelope.h:58
void SetNumFolds(UInt_t i)
Base class for all machine learning algorithms.
Definition: Envelope.h:55
std::vector< Double_t > fTrainEff01s
std::vector< Double_t > fTrainEff10s
virtual void SetLineColor(Color_t lcolor)
Set the line color.
Definition: TAttLine.h:46
std::vector< Double_t > fEff01s
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
std::vector< Double_t > fTrainEff30s
TGraphErrors * gr
Definition: legend1.C:25
CrossValidation(DataLoader *loader)
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
The Canvas class.
Definition: TCanvas.h:41
double Double_t
Definition: RtypesCore.h:55
std::shared_ptr< DataLoader > fDataLoader
Definition: Envelope.h:59
MsgLogger & Log() const
Definition: Configurable.h:128
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
std::vector< Double_t > fEffAreas
virtual void Draw(Option_t *option="")
Draw a canvas.
Definition: TCanvas.cxx:795
std::vector< Double_t > fSeps
Abstract ClassifierFactory template that handles arbitrary types.
void SetSilent(Bool_t s)
Definition: Config.h:64
std::map< UInt_t, Float_t > fROCs
std::vector< Double_t > fEff30s
virtual TLegend * BuildLegend(Double_t x1=0.5, Double_t y1=0.67, Double_t x2=0.88, Double_t y2=0.88, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
Definition: TPad.cxx:426
static void EnableOutput()
Definition: MsgLogger.cxx:70
const Bool_t kTRUE
Definition: Rtypes.h:91
std::shared_ptr< TMultiGraph > fROCCurves
char name[80]
Definition: TGX11.cxx:109