Logo ROOT   6.07/09
Reference Guide
HyperParameterOptimisation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson.
3 
5 
6 #include "TMVA/Configurable.h"
7 #include "TMVA/CrossValidation.h"
8 #include "TMVA/DataSet.h"
9 #include "TMVA/Event.h"
10 #include "TMVA/MethodBase.h"
12 #include "TMVA/Types.h"
13 
14 #include "TGraph.h"
15 #include "TMultiGraph.h"
16 #include "TString.h"
17 #include "TSystem.h"
18 
19 #include <iostream>
20 #include <vector>
21 
22 //HyperParameterOptimisationResult stuff
23 // ClassImp(TMVA::HyperParameterOptimisationResult)
24 
26 {
27  fROCCurves = new TMultiGraph("ROCCurves","ROCCurves");
28 
29 }
30 
32 {
33  if(fROCCurves) delete fROCCurves;
34 }
35 
37 {
38 
39  return fROCCurves;
40 }
41 
42 //HyperParameterOptimisation class stuff
43 // ClassImp(TMVA::HyperParameterOptimisation) // serialization is not yet supported in many TMVA classes
44 
45 /*TMVA::HyperParameterOptimisation::HyperParameterOptimisation():Configurable( ),
46 fDataLoader(0)
47 {
48  fClassifier=new TMVA::Factory("CrossValidation","!V:Silent:Color:DrawProgressBar:AnalysisType=Classification");
49  }*/
50 
51 
53 fDataLoader(loader),
54 fFomType(fomType),
55 fFitType(fitType)
56 {
57  fClassifier=new TMVA::Factory("CrossValidation","!V:Silent:Color:DrawProgressBar:AnalysisType=Classification");
58 }
59 
61 {
62  if(fClassifier) delete fClassifier;
63 }
64 
66 {
67  //TODO by Thomas Stevenson
68  //
69 
71 
72  fDataLoader->MakeKFoldDataSet(NumFolds);
73 
74  for(Int_t i = 0; i < NumFolds; ++i){
75 
76  TString foldTitle = methodTitle;
77  foldTitle += "_opt";
78  foldTitle += i+1;
79 
81 
83 
84  fClassifier->BookMethod(fDataLoader, theMethodName, methodTitle, theOption);
85 
86  TMVA::MethodBase * smethod = dynamic_cast<TMVA::MethodBase*>(fClassifier->fMethodsMap[fDataLoader->GetName()][0][0]);
87 
88  result->fFoldParameters.push_back(smethod->OptimizeTuningParameters(fFomType,fFitType));
89 
90  //smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification);
92 
94 
95  fClassifier->fMethodsMap.clear();
96 
97  }
98 
99  for(UInt_t j=0; j<result->fFoldParameters.size(); ++j){
100  std::cout << "===========================================================" << std::endl;
101  std::cout << "Optimisation for " << theMethodName << " fold " << j+1 << std::endl;
102 
103  std::map<TString,Double_t>::iterator iter;
104  for(iter=result->fFoldParameters.at(j).begin(); iter!=result->fFoldParameters.at(j).end(); iter++){
105  std::cout << iter->first << " " << iter->second << std::endl;
106  }
107  }
108 
109  return result;
110 }
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Definition: Factory.cxx:337
virtual std::map< TString, Double_t > OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA")
Call the optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodBase.cxx:617
DataSet * Data() const
Definition: MethodBase.h:405
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:37
Basic string class.
Definition: TString.h:137
int Int_t
Definition: RtypesCore.h:41
void MakeKFoldDataSet(UInt_t numberFolds, bool validationSet=false)
Definition: DataLoader.cxx:610
bool Bool_t
Definition: RtypesCore.h:59
const TString & GetMethodName() const
Definition: MethodBase.h:327
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Definition: Event.cxx:388
void PrepareFoldDataSet(UInt_t foldNumber, Types::ETreeType tt)
Definition: DataLoader.cxx:661
std::map< TString, MVector * > fMethodsMap
Definition: Factory.h:91
HyperParameterOptimisation(DataLoader *loader, TString fomType="Separation", TString fitType="Minuit")
HyperParameterOptimisationResult * Optimise(TString theMethodName, TString methodTitle, TString theOption="", int NumFolds=5)
void DeleteResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
Delete the results stored for this particular Method instance (here apparently called resultsName i...
Definition: DataSet.cxx:337
unsigned int UInt_t
Definition: RtypesCore.h:42
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:51
std::vector< std::map< TString, Double_t > > fFoldParameters
Mother of all ROOT objects.
Definition: TObject.h:44
void DeleteAllMethods(void)
delete methods
Definition: Factory.cxx:311
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
double result[121]
const Bool_t kTRUE
Definition: Rtypes.h:91