Logo ROOT   6.16/01
Reference Guide
MethodCrossValidation.cxx
Go to the documentation of this file.
// @(#)root/tmva $Id$
// Author: Kim Albertsson

/*************************************************************************
 * Copyright (C) 2018, Rene Brun and Fons Rademakers.                    *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/*! \class TMVA::MethodCrossValidation
\ingroup TMVA

Encapsulates one trained method per cross-validation fold. An event is evaluated
either by the method of the fold it belongs to (`OutputEnsembling=None`) or by the
average of all fold methods (`OutputEnsembling=Avg`).
*/
16
18#include "TMVA/Config.h"
19#include "TMVA/CvSplit.h"
20#include "TMVA/MethodCategory.h"
21#include "TMVA/Tools.h"
22#include "TMVA/Types.h"
23
24#include "TSystem.h"
25
26REGISTER_METHOD(CrossValidation)
27
29
30////////////////////////////////////////////////////////////////////////////////
31///
32
34 DataSetInfo &theData, const TString &theOption)
35 : TMVA::MethodBase(jobName, Types::kCrossValidation, methodTitle, theData, theOption), fSplitExpr(nullptr)
36{
37}
38
39////////////////////////////////////////////////////////////////////////////////
40
42 : TMVA::MethodBase(Types::kCrossValidation, theData, theWeightFile), fSplitExpr(nullptr)
43{
44}
45
46////////////////////////////////////////////////////////////////////////////////
47/// Destructor.
48///
49
51
52////////////////////////////////////////////////////////////////////////////////
53
55{
56 DeclareOptionRef(fEncapsulatedMethodName, "EncapsulatedMethodName", "");
57 DeclareOptionRef(fEncapsulatedMethodTypeName, "EncapsulatedMethodTypeName", "");
58 DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
59 DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
60 "Combines output from contained methods. If None, no combination is performed. (default None)");
61 AddPreDefVal(TString("None"));
62 AddPreDefVal(TString("Avg"));
63 DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
64}
65
66////////////////////////////////////////////////////////////////////////////////
67/// Options that are used ONLY for the READER to ensure backward compatibility.
68
70{
72}
73
74////////////////////////////////////////////////////////////////////////////////
75/// The option string is decoded, for available options see "DeclareOptions".
76
78{
79 Log() << kDEBUG << "ProcessOptions -- fNumFolds: " << fNumFolds << Endl;
80 Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodName: " << fEncapsulatedMethodName << Endl;
81 Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodTypeName: " << fEncapsulatedMethodTypeName << Endl;
82
83 if (fSplitExprString != TString("")) {
84 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
85 }
86
87 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
88 TString weightfile = GetWeightFileNameForFold(iFold);
89
90 Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
91 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
92 fEncapsulatedMethods.push_back(fold_method);
93 }
94}
95
96////////////////////////////////////////////////////////////////////////////////
97/// Common initialisation with defaults for the Method.
98
100{
101 fMulticlassValues = std::vector<Float_t>(DataInfo().GetNClasses());
102 fRegressionValues = std::vector<Float_t>(DataInfo().GetNTargets());
103}
104
105////////////////////////////////////////////////////////////////////////////////
106/// Reset the method, as if it had just been instantiated (forget all training etc.).
107
109
110////////////////////////////////////////////////////////////////////////////////
111/// \brief Returns filename of weight file for a given fold.
112/// \param[in] iFold Ordinal of the fold. Range: 0 to NumFolds exclusive.
113///
115{
116 if (iFold >= fNumFolds) {
117 Log() << kFATAL << iFold << " out of range. "
118 << "Should be < " << fNumFolds << "." << Endl;
119 }
120
121 TString foldStr = Form("fold%i", iFold + 1);
122 TString fileDir = gSystem->DirName(GetWeightFileName());
123 TString weightfile = fileDir + "/" + fJobName + "_" + fEncapsulatedMethodName + "_" + foldStr + ".weights.xml";
124
125 return weightfile;
126}
127
128////////////////////////////////////////////////////////////////////////////////
129/// Call the Optimizer with the set of parameters and ranges that
130/// are meant to be tuned.
131
132// std::map<TString,Double_t> TMVA::MethodCrossValidation::OptimizeTuningParameters(TString fomType, TString fitType)
133// {
134// }
135
136////////////////////////////////////////////////////////////////////////////////
137/// Set the tuning parameters according to the argument.
138
139// void TMVA::MethodCrossValidation::SetTuneParameters(std::map<TString,Double_t> tuneParameters)
140// {
141// }
142
143////////////////////////////////////////////////////////////////////////////////
144/// training.
145
147
148////////////////////////////////////////////////////////////////////////////////
149/// \brief Reads in a weight file an instantiates the corresponding method
150/// \param[in] methodTypeName Canonical name of the method type. E.g. `"BDT"`
151/// for Boosted Decision Trees.
152/// \param[in] weightfile File to read method parameters from
155{
156 TMVA::MethodBase *m = dynamic_cast<MethodBase *>(
157 ClassifierFactory::Instance().Create(std::string(methodTypeName.Data()), DataInfo(), weightfile));
158
159 if (m->GetMethodType() == Types::kCategory) {
160 Log() << kFATAL << "MethodCategory not supported for the moment." << Endl;
161 }
162
163 TString fileDir = TString(DataInfo().GetName()) + "/" + gConfig().GetIONames().fWeightFileDir;
164 m->SetWeightFileDir(fileDir);
165 // m->SetModelPersistence(fModelPersistence);
166 // m->SetSilentFile(IsSilentFile());
167 m->SetAnalysisType(fAnalysisType);
168 m->SetupMethod();
169 m->ReadStateFromFile();
170 // m->SetTestvarName(testvarName);
171
172 return m;
173}
174
175////////////////////////////////////////////////////////////////////////////////
176/// Write weights to XML.
177
179{
180 void *wght = gTools().AddChild(parent, "Weights");
181
182 gTools().AddAttr(wght, "JobName", fJobName);
183 gTools().AddAttr(wght, "SplitExpr", fSplitExprString);
184 gTools().AddAttr(wght, "NumFolds", fNumFolds);
185 gTools().AddAttr(wght, "EncapsulatedMethodName", fEncapsulatedMethodName);
186 gTools().AddAttr(wght, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
187 gTools().AddAttr(wght, "OutputEnsembling", fOutputEnsembling);
188
189 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
190 TString weightfile = GetWeightFileNameForFold(iFold);
191
192 // TODO: Add a swithch in options for using either split files or only one.
193 // TODO: This would store the method inside MethodCrossValidation
194 // Another option is to store the folds as separate files.
195 // // Retrieve encap. method for fold n
196 // MethodBase * method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
197 //
198 // // Serialise encapsulated method for fold n
199 // void* foldNode = gTools().AddChild(parent, foldStr);
200 // method->WriteStateToXML(foldNode);
201 }
202}
203
204////////////////////////////////////////////////////////////////////////////////
205/// Reads from the xml file.
206///
207
209{
210 gTools().ReadAttr(parent, "JobName", fJobName);
211 gTools().ReadAttr(parent, "SplitExpr", fSplitExprString);
212 gTools().ReadAttr(parent, "NumFolds", fNumFolds);
213 gTools().ReadAttr(parent, "EncapsulatedMethodName", fEncapsulatedMethodName);
214 gTools().ReadAttr(parent, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
215 gTools().ReadAttr(parent, "OutputEnsembling", fOutputEnsembling);
216
217 // Read in methods for all folds
218 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
219 TString weightfile = GetWeightFileNameForFold(iFold);
220
221 Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
222 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
223 fEncapsulatedMethods.push_back(fold_method);
224 }
225
226 // SplitExpr
227 if (fSplitExprString != TString("")) {
228 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
229 }
230}
231
232////////////////////////////////////////////////////////////////////////////////
233/// Read the weights
234///
235
237{
238 Log() << kFATAL << "CrossValidation currently supports only reading from XML." << Endl;
239}
240
241////////////////////////////////////////////////////////////////////////////////
242///
243
245{
246 const Event *ev = GetEvent();
247
248 if (fOutputEnsembling == "None") {
249 if (fSplitExpr != nullptr) {
250 // K-folds with a deterministic split
251 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
252 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
253 } else {
254 // K-folds with a random split was used
255 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
256 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
257 }
258 } else if (fOutputEnsembling == "Avg") {
259 Double_t val = 0.0;
260 for (auto &m : fEncapsulatedMethods) {
261 val += m->GetMvaValue(err, errUpper);
262 }
263 return val / fEncapsulatedMethods.size();
264 } else {
265 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
266 return 0; // Cannot happen
267 }
268}
269
270////////////////////////////////////////////////////////////////////////////////
271/// Get the multiclass MVA response.
272
274{
275 const Event *ev = GetEvent();
276
277 if (fOutputEnsembling == "None") {
278 if (fSplitExpr != nullptr) {
279 // K-folds with a deterministic split
280 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
281 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
282 } else {
283 // K-folds with a random split was used
284 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
285 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
286 }
287 } else if (fOutputEnsembling == "Avg") {
288
289 for (auto &e : fMulticlassValues) {
290 e = 0;
291 }
292
293 for (auto &m : fEncapsulatedMethods) {
294 auto methodValues = m->GetMulticlassValues();
295 for (size_t i = 0; i < methodValues.size(); ++i) {
296 fMulticlassValues[i] += methodValues[i];
297 }
298 }
299
300 for (auto &e : fMulticlassValues) {
301 e /= fEncapsulatedMethods.size();
302 }
303
304 return fMulticlassValues;
305
306 } else {
307 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
308 return fMulticlassValues; // Cannot happen
309 }
310}
311
312////////////////////////////////////////////////////////////////////////////////
313/// Get the regression value generated by the containing methods.
314
316{
317 const Event *ev = GetEvent();
318
319 if (fOutputEnsembling == "None") {
320 if (fSplitExpr != nullptr) {
321 // K-folds with a deterministic split
322 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
323 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
324 } else {
325 // K-folds with a random split was used
326 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
327 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
328 }
329 } else if (fOutputEnsembling == "Avg") {
330
331 for (auto &e : fRegressionValues) {
332 e = 0;
333 }
334
335 for (auto &m : fEncapsulatedMethods) {
336 auto methodValues = m->GetRegressionValues();
337 for (size_t i = 0; i < methodValues.size(); ++i) {
338 fRegressionValues[i] += methodValues[i];
339 }
340 }
341
342 for (auto &e : fRegressionValues) {
343 e /= fEncapsulatedMethods.size();
344 }
345
346 return fRegressionValues;
347
348 } else {
349 Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
350 return fRegressionValues; // Cannot happen
351 }
352}
353
354////////////////////////////////////////////////////////////////////////////////
355///
/// NOTE(review): the line holding this definition's declarator is missing from
/// this listing, so which member function this is cannot be determined here —
/// confirm against the header. The body does nothing. It only keeps, commented
/// out, a kFATAL guard against creating the method manually instead of via
/// TMVA::Reader; the inner comment notes the function is used for evaluation.
356
358{
359 // // Used for evaluation, which is outside the life time of MethodCrossEval.
360 // Log() << kFATAL << "Method CrossValidation should not be created manually,"
361 // " only as part of using TMVA::Reader." << Endl;
362 // return;
363}
364
365////////////////////////////////////////////////////////////////////////////////
366///
/// NOTE(review): the declarator line of this definition is missing from this
/// listing; which member function it is cannot be read from here — confirm.
/// The body only logs a kWARNING that the CrossValidation method should not be
/// created manually, only as part of using TMVA::Reader.
367
369{
370 Log() << kWARNING
371 << "Method CrossValidation should not be created manually,"
372 " only as part of using TMVA::Reader."
373 << Endl;
374}
375
376////////////////////////////////////////////////////////////////////////////////
377///
/// NOTE(review): the declarator line is missing from this listing. The body
/// returns nullptr, i.e. no object is provided. The Ranking reference elsewhere
/// in this listing suggests this is the variable-ranking hook — confirm against
/// the header.
378
380{
381 return nullptr;
382}
383
384////////////////////////////////////////////////////////////////////////////////
385
387 UInt_t /*numberTargets*/)
388{
389 return kTRUE;
390 // if (fEncapsulatedMethods.size() == 0) {return kFALSE;}
391 // if (fEncapsulatedMethods.at(0) == nullptr) {return kFALSE;}
392 // return fEncapsulatedMethods.at(0)->HasAnalysisType(type, numberClasses, numberTargets);
393}
394
395////////////////////////////////////////////////////////////////////////////////
396/// Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
397
398void TMVA::MethodCrossValidation::MakeClassSpecific(std::ostream & /*fout*/, const TString & /*className*/) const
399{
400 Log() << kWARNING << "MakeClassSpecific not implemented for CrossValidation" << Endl;
401}
402
403////////////////////////////////////////////////////////////////////////////////
404/// Specific class header.
405
406void TMVA::MethodCrossValidation::MakeClassSpecificHeader(std::ostream & /*fout*/, const TString & /*className*/) const
407{
408 Log() << kWARNING << "MakeClassSpecificHeader not implemented for CrossValidation" << Endl;
409}
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition: RSha256.hxx:103
unsigned int UInt_t
Definition: RtypesCore.h:42
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:363
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition: TSystem.h:540
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
TString fWeightFileDir
Definition: Config.h:112
IONames & GetIONames()
Definition: Config.h:90
Class that contains all the data information.
Definition: DataSetInfo.h:60
Virtual base Class for all MVA method.
Definition: MethodBase.h:109
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:601
friend class MethodCrossValidation
Definition: MethodBase.h:115
void MakeClassSpecific(std::ostream &, const TString &) const
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
void AddWeightsXMLTo(void *parent) const
Write weights to XML.
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
void Init(void)
Common initialisation with defaults for the Method.
void MakeClassSpecificHeader(std::ostream &, const TString &) const
Specific class header.
void Reset(void)
Reset the method, as if it had just been instantiated (forget all training etc.).
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
const std::vector< Float_t > & GetRegressionValues()
Get the regression value generated by the containing methods.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file an instantiates the corresponding method.
void Train(void)
Call the Optimizer with the set of parameters and ranges that are meant to be tuned.
const std::vector< Float_t > & GetMulticlassValues()
Get the multiclass MVA response.
void ReadWeightsFromXML(void *parent)
Reads from the xml file.
void DeclareCompatibilityOptions()
Options that are used ONLY for the READER to ensure backward compatibility.
void WriteMonitoringHistosToFile(void) const
write special monitoring histograms to file dummy implementation here --------------—
virtual ~MethodCrossValidation(void)
Destructor.
void ReadWeightsFromStream(std::istream &istr)
Read the weights.
void ProcessOptions()
The option string is decoded, for available options see "DeclareOptions".
Ranking for variables in method (implementation)
Definition: Ranking.h:48
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:355
Singleton class for Global types used by TMVA.
Definition: Types.h:73
@ kCategory
Definition: Types.h:99
EAnalysisType
Definition: Types.h:127
Basic string class.
Definition: TString.h:131
const char * Data() const
Definition: TString.h:364
virtual const char * DirName(const char *pathname)
Return the directory name in pathname.
Definition: TSystem.cxx:1013
std::string GetName(const std::string &scope_name)
Definition: Cppyy.cxx:146
Abstract ClassifierFactory template that handles arbitrary types.
Config & gConfig()
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:748
auto * m
Definition: textangle.C:8