ROOT 6.08/07
Reference Guide
HyperParameterOptimisation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson.
3 
5 
6 #include "TMVA/Configurable.h"
7 #include "TMVA/DataSet.h"
8 #include "TMVA/Event.h"
9 #include "TMVA/MethodBase.h"
11 #include "TMVA/Types.h"
12 
13 #include "TGraph.h"
14 #include "TMultiGraph.h"
15 #include "TString.h"
16 #include "TSystem.h"
17 
18 #include <iostream>
19 #include <vector>
20 
22 {
23 }
24 
26 {
27  fROCCurves=nullptr;
28 }
29 
31 {
32 
33  return fROCCurves.get();
34 }
35 
37 {
40 
41  MsgLogger fLogger("HyperParameterOptimisation");
42 
43  for(UInt_t j=0; j<fFoldParameters.size(); ++j) {
44  fLogger<<kHEADER<< "===========================================================" << Endl;
45  fLogger<<kINFO<< "Optimisation for " << fMethodName << " fold " << j+1 << Endl;
46 
47  for(auto &it : fFoldParameters.at(j)) {
48  fLogger<<kINFO<< it.first << " " << it.second << Endl;
49  }
50  }
51 
53 
54 }
55 
57  fFomType("Separation"),
58  fFitType("Minuit"),
59  fNumFolds(5),
60  fResults(),
61  fClassifier(new TMVA::Factory("HyperParameterOptimisation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
62 {
64 }
65 
67 {
68  fClassifier=nullptr;
69 }
70 
72 {
73  fNumFolds=i;
74  fDataLoader->MakeKFoldDataSet(fNumFolds);
76 }
77 
79 { // Per-fold optimisation: for each of fNumFolds folds, book the method, run OptimizeTuningParameters(fFomType, fFitType) and append the resulting parameter map to fResults.fFoldParameters.
80  TString methodName = fMethod.GetValue<TString>("MethodName"); // method configuration is read from the fMethod option map
81  TString methodTitle = fMethod.GetValue<TString>("MethodTitle");
82  TString methodOptions = fMethod.GetValue<TString>("MethodOptions");
83 
84  if(!fFoldStatus) // split the data into fNumFolds folds unless fFoldStatus says that was already done
85  {
86  fDataLoader->MakeKFoldDataSet(fNumFolds);
88  }
89  fResults.fMethodName = methodName;
90 
91  for(UInt_t i = 0; i < fNumFolds; ++i) {
92 
93  TString foldTitle = methodTitle; // NOTE(review): foldTitle (<title>_opt<fold number>) is built but never used - BookMethod below is passed methodTitle. Either pass foldTitle or drop these three lines.
94  foldTitle += "_opt";
95  foldTitle += i+1;
96 
98  fDataLoader->PrepareFoldDataSet(i, TMVA::Types::kTraining); // presumably selects fold i as the training sample - confirm against DataLoader::PrepareFoldDataSet
99 
100  auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions); // same methodTitle is booked on every fold
101 
102  auto params=smethod->OptimizeTuningParameters(fFomType,fFitType); // best-parameter map for this fold (stored as std::map<TString, Double_t> in fFoldParameters)
103  fResults.fFoldParameters.push_back(params);
104 
105  smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification); // drop this method's training results from the dataset before the next fold
106 
107  fClassifier->DeleteAllMethods(); // unregister the booked method; presumably needed so the same methodTitle can be booked again on the next fold
108 
109  fClassifier->fMethodsMap.clear();
110 
111  }
112 
113 }
Config & gConfig()
Definition: Config.cxx:43
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
T GetValue(const TString &key)
Definition: OptionMap.h:152
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:37
Basic string class.
Definition: TString.h:137
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
HyperParameterOptimisationResult fResults
static void SetIsTraining(Bool_t)
When this static function is called, it sets the flag that determines whether events with negative event weights should be ignored in the training.
Definition: Event.cxx:388
OptionMap fMethod
Definition: Envelope.h:58
Base class for all machine learning algorithms.
Definition: Envelope.h:55
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
HyperParameterOptimisation(DataLoader *dataloader)
unsigned int UInt_t
Definition: RtypesCore.h:42
std::shared_ptr< DataLoader > fDataLoader
Definition: Envelope.h:59
std::vector< std::map< TString, Double_t > > fFoldParameters
Abstract ClassifierFactory template that handles arbitrary types.
void SetSilent(Bool_t s)
Definition: Config.h:64
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
static void EnableOutput()
Definition: MsgLogger.cxx:70
const Bool_t kTRUE
Definition: Rtypes.h:91