Logo ROOT  
Reference Guide
HyperParameterOptimisation.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Omar Zapata, Thomas James Stevenson.
3
5
6#include "TMVA/Configurable.h"
7#include "TMVA/CvSplit.h"
8#include "TMVA/DataSet.h"
9#include "TMVA/Event.h"
10#include "TMVA/MethodBase.h"
12#include "TMVA/Types.h"
13
14#include "TGraph.h"
15#include "TMultiGraph.h"
16#include "TString.h"
17#include "TSystem.h"
18
19#include <iostream>
20#include <memory>
21#include <vector>
22
23/*! \class TMVA::HyperParameterOptimisationResult
24\ingroup TMVA
25
26*/
27
28/*! \class TMVA::HyperParameterOptimisation
29\ingroup TMVA
30
31*/
32
33//_______________________________________________________________________
// Constructor initialiser list: fROCAVG starts at 0.0 and fROCCurves is bound to
// a default-constructed TMultiGraph allocated through std::make_shared.
// NOTE(review): the signature line is absent from this listing; the class
// documentation above suggests this is the
// TMVA::HyperParameterOptimisationResult constructor - confirm.
35 : fROCAVG(0.0), fROCCurves(std::make_shared<TMultiGraph>())
36{
// Empty body: all visible initialisation happens in the initialiser list.
37}
38
39//_______________________________________________________________________
// Empty function body: no explicit work is performed here.
// NOTE(review): the signature line is missing from this listing; its position
// (right after the result constructor) suggests the
// HyperParameterOptimisationResult destructor - confirm.
41{
42}
43
44//_______________________________________________________________________
// Returns the raw TMultiGraph pointer held by fROCCurves via get(); the
// smart pointer keeps ownership, so the caller receives a non-owning pointer.
// NOTE(review): the reference list on this page declares
// GetROCCurves(Bool_t fLegend=kTRUE); no argument is used in the visible body.
46{
47
48 return fROCCurves.get();
49}
50
51//_______________________________________________________________________
// Logs the stored per-fold parameters through a local MsgLogger.
// NOTE(review): the signature line and listing lines 54, 55 and 68 are missing
// here; presumably they toggle gConfig().SetSilent / MsgLogger::EnableOutput
// (both appear in this page's reference list) - confirm against the real file.
53{
56
// Function-local logger labelled "HyperParameterOptimisation".
57 MsgLogger fLogger("HyperParameterOptimisation");
58
// One report section per entry of fFoldParameters (one entry per fold).
59 for(UInt_t j=0; j<fFoldParameters.size(); ++j) {
60 fLogger<<kHEADER<< "===========================================================" << Endl;
// Folds are reported 1-based (j+1) although they are stored 0-based.
61 fLogger<<kINFO<< "Optimisation for " << fMethodName << " fold " << j+1 << Endl;
62
// Each element is a pair: its first and second members are printed separated by a space.
63 for(auto &it : fFoldParameters.at(j)) {
64 fLogger<<kINFO<< it.first << " " << it.second << Endl;
65 }
66 }
67
69
70}
71
72//_______________________________________________________________________
// Tail of the constructor initialiser list.
// NOTE(review): the signature and the first initialiser (listing lines 72-73)
// are missing here; the reference list on this page gives the signature as
// HyperParameterOptimisation(DataLoader *dataloader) - confirm.
// Default figure-of-merit type passed later to OptimizeTuningParameters.
74 fFomType("Separation"),
// Default fit type passed later to OptimizeTuningParameters.
75 fFitType("Minuit"),
// Default number of folds.
76 fNumFolds(5),
// Value-initialised result container.
77 fResults(),
// Factory allocated with new and configured by the option string below
// (silent, no ROC, no model persistence, no colour, no progress bar,
// classification analysis type).
78 fClassifier(new TMVA::Factory("HyperParameterOptimisation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
79{
// NOTE(review): listing line 80 (the only body line) is missing from this view.
81}
82
83//_______________________________________________________________________
// Clears the fClassifier member.
// NOTE(review): the signature line is missing; position suggests the
// HyperParameterOptimisation destructor. If fClassifier is a smart pointer this
// assignment releases the Factory created in the constructor; if it is a raw
// pointer the Factory allocated with new is never deleted - confirm the member type.
85{
86 fClassifier=nullptr;
87}
88
89//_______________________________________________________________________
// Stores the requested fold count i in fNumFolds and resets fFoldStatus so
// that the next Evaluate() call sees !fFoldStatus and rebuilds the k-fold
// dataset via MakeKFoldDataSet.
// NOTE(review): the signature line is missing from this listing; the name
// SetNumFolds is inferred from the body - confirm.
91{
92 fNumFolds = i;
// The fold dataset is deliberately not rebuilt here (call left commented out).
93 // fDataLoader->MakeKFoldDataSet(fNumFolds);
// Mark the existing folds as stale.
94 fFoldStatus = kFALSE;
95}
96
97//_______________________________________________________________________
// For every booked method, and for every fold, books the method on the fold's
// training data, runs OptimizeTuningParameters and appends the returned
// parameters to fResults.fFoldParameters.
// NOTE(review): the signature line is missing; the reference list on this page
// names it "virtual void Evaluate()" - confirm.
99{
100 for (auto &meth : fMethods) {
// Pull the booking information out of the stored option map.
101 TString methodName = meth.GetValue<TString>("MethodName");
102 TString methodTitle = meth.GetValue<TString>("MethodTitle");
103 TString methodOptions = meth.GetValue<TString>("MethodOptions");
104
// A split object is constructed for every method with identical arguments.
105 CvSplitKFolds split{fNumFolds, "", kFALSE, 0};
// The k-fold dataset is only built while fFoldStatus is false; the flag is
// then latched so subsequent methods skip this step.
106 if (!fFoldStatus) {
107 fDataLoader->MakeKFoldDataSet(split);
108 fFoldStatus = kTRUE;
109 }
// Overwritten on every method iteration, whereas fFoldParameters below keeps
// accumulating entries across methods.
110 fResults.fMethodName = methodName;
111
112 for (UInt_t i = 0; i < fNumFolds; ++i) {
// Builds "<methodTitle>_opt<i+1>".
// NOTE(review): foldTitle is not used anywhere in the visible code -
// BookMethod below receives methodTitle. Confirm whether that is intended.
113 TString foldTitle = methodTitle;
114 foldTitle += "_opt";
115 foldTitle += i + 1;
116
// NOTE(review): listing line 117 is missing here; Event::SetIsTraining appears
// in this page's reference list and may be the omitted call - confirm.
// Select fold i as the training dataset.
118 fDataLoader->PrepareFoldDataSet(split, i, TMVA::Types::kTraining);
119
120 auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
121
// Optimise with the configured figure of merit and fit type, then record
// this fold's result.
122 auto params = smethod->OptimizeTuningParameters(fFomType, fFitType);
123 fResults.fFoldParameters.push_back(params);
124
// Drop the training/classification results stored for this method.
125 smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
126
// Remove the booked method(s) from the factory before the next fold.
127 fClassifier->DeleteAllMethods();
128
129 fClassifier->fMethodsMap.clear();
130 }
131 }
132}
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kTRUE
Definition: RtypesCore.h:87
void SetSilent(Bool_t s)
Definition: Config.h:65
Abstract base class for all high-level ML algorithms; you can book ML methods like BDT, ...
Definition: Envelope.h:47
static void SetIsTraining(Bool_t)
When this static function is called, it sets the flag whether events with negative event weight should be ignored in the training, or not.
Definition: Event.cxx:392
This is the main MVA steering class.
Definition: Factory.h:81
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
HyperParameterOptimisation(DataLoader *dataloader)
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
static void EnableOutput()
Definition: MsgLogger.cxx:75
@ kClassification
Definition: Types.h:128
@ kTraining
Definition: Types.h:144
A TMultiGraph is a collection of TGraph (or derived) objects.
Definition: TMultiGraph.h:35
Basic string class.
Definition: TString.h:131
create variable transformations
Config & gConfig()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158