Logo ROOT  
Reference Guide
MethodCrossValidation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 /*! \class TMVA::MethodCrossValidation
13 \ingroup TMVA
14 */
16 
17 #include "TMVA/ClassifierFactory.h"
18 #include "TMVA/Config.h"
19 #include "TMVA/CvSplit.h"
20 #include "TMVA/MethodCategory.h"
21 #include "TMVA/Tools.h"
22 #include "TMVA/Types.h"
23 
24 #include "TSystem.h"
25 
26 REGISTER_METHOD(CrossValidation)
27 
29 
30 ////////////////////////////////////////////////////////////////////////////////
31 ///
32 
34  DataSetInfo &theData, const TString &theOption)
35  : TMVA::MethodBase(jobName, Types::kCrossValidation, methodTitle, theData, theOption), fSplitExpr(nullptr)
36 {
37 }
38 
39 ////////////////////////////////////////////////////////////////////////////////
40 
42  : TMVA::MethodBase(Types::kCrossValidation, theData, theWeightFile), fSplitExpr(nullptr)
43 {
44 }
45 
46 ////////////////////////////////////////////////////////////////////////////////
47 /// Destructor.
48 ///
49 
51 
52 ////////////////////////////////////////////////////////////////////////////////
53 
55 {
56  DeclareOptionRef(fEncapsulatedMethodName, "EncapsulatedMethodName", "");
57  DeclareOptionRef(fEncapsulatedMethodTypeName, "EncapsulatedMethodTypeName", "");
58  DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
59  DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
60  "Combines output from contained methods. If None, no combination is performed. (default None)");
61  AddPreDefVal(TString("None"));
62  AddPreDefVal(TString("Avg"));
63  DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
64 }
65 
66 ////////////////////////////////////////////////////////////////////////////////
67 /// Options that are used ONLY for the READER to ensure backward compatibility.
68 
70 {
72 }
73 
74 ////////////////////////////////////////////////////////////////////////////////
75 /// The option string is decoded, for available options see "DeclareOptions".
76 
78 {
79  Log() << kDEBUG << "ProcessOptions -- fNumFolds: " << fNumFolds << Endl;
80  Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodName: " << fEncapsulatedMethodName << Endl;
81  Log() << kDEBUG << "ProcessOptions -- fEncapsulatedMethodTypeName: " << fEncapsulatedMethodTypeName << Endl;
82 
83  if (fSplitExprString != TString("")) {
84  fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
85  }
86 
87  for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
88  TString weightfile = GetWeightFileNameForFold(iFold);
89 
90  Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
91  MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
92  fEncapsulatedMethods.push_back(fold_method);
93  }
94 }
95 
96 ////////////////////////////////////////////////////////////////////////////////
97 /// Common initialisation with defaults for the Method.
98 
100 {
101  fMulticlassValues = std::vector<Float_t>(DataInfo().GetNClasses());
102  fRegressionValues = std::vector<Float_t>(DataInfo().GetNTargets());
103 }
104 
105 ////////////////////////////////////////////////////////////////////////////////
106 /// Reset the method, as if it had just been instantiated (forget all training etc.).
107 
109 
110 ////////////////////////////////////////////////////////////////////////////////
111 /// \brief Returns filename of weight file for a given fold.
112 /// \param[in] iFold Ordinal of the fold. Range: 0 to NumFolds exclusive.
113 ///
115 {
116  if (iFold >= fNumFolds) {
117  Log() << kFATAL << iFold << " out of range. "
118  << "Should be < " << fNumFolds << "." << Endl;
119  }
120 
121  TString foldStr = Form("fold%i", iFold + 1);
122  TString fileDir = gSystem->GetDirName(GetWeightFileName());
123  TString weightfile = fileDir + "/" + fJobName + "_" + fEncapsulatedMethodName + "_" + foldStr + ".weights.xml";
124 
125  return weightfile;
126 }
127 
128 ////////////////////////////////////////////////////////////////////////////////
129 /// Call the Optimizer with the set of parameters and ranges that
130 /// are meant to be tuned.
131 
132 // std::map<TString,Double_t> TMVA::MethodCrossValidation::OptimizeTuningParameters(TString fomType, TString fitType)
133 // {
134 // }
135 
136 ////////////////////////////////////////////////////////////////////////////////
137 /// Set the tuning parameters according to the argument.
138 
139 // void TMVA::MethodCrossValidation::SetTuneParameters(std::map<TString,Double_t> tuneParameters)
140 // {
141 // }
142 
143 ////////////////////////////////////////////////////////////////////////////////
144 /// training.
145 
147 
148 ////////////////////////////////////////////////////////////////////////////////
149 /// \brief Reads in a weight file an instantiates the corresponding method
150 /// \param[in] methodTypeName Canonical name of the method type. E.g. `"BDT"`
151 /// for Boosted Decision Trees.
152 /// \param[in] weightfile File to read method parameters from
155 {
156  TMVA::MethodBase *m = dynamic_cast<MethodBase *>(
157  ClassifierFactory::Instance().Create(std::string(methodTypeName.Data()), DataInfo(), weightfile));
158 
159  if (m->GetMethodType() == Types::kCategory) {
160  Log() << kFATAL << "MethodCategory not supported for the moment." << Endl;
161  }
162 
163  TString fileDir = TString(DataInfo().GetName()) + "/" + gConfig().GetIONames().fWeightFileDir;
164  m->SetWeightFileDir(fileDir);
165  // m->SetModelPersistence(fModelPersistence);
166  // m->SetSilentFile(IsSilentFile());
167  m->SetAnalysisType(fAnalysisType);
168  m->SetupMethod();
169  m->ReadStateFromFile();
170  // m->SetTestvarName(testvarName);
171 
172  return m;
173 }
174 
175 ////////////////////////////////////////////////////////////////////////////////
176 /// Write weights to XML.
177 
179 {
180  void *wght = gTools().AddChild(parent, "Weights");
181 
182  gTools().AddAttr(wght, "JobName", fJobName);
183  gTools().AddAttr(wght, "SplitExpr", fSplitExprString);
184  gTools().AddAttr(wght, "NumFolds", fNumFolds);
185  gTools().AddAttr(wght, "EncapsulatedMethodName", fEncapsulatedMethodName);
186  gTools().AddAttr(wght, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
187  gTools().AddAttr(wght, "OutputEnsembling", fOutputEnsembling);
188 
189  for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
190  TString weightfile = GetWeightFileNameForFold(iFold);
191 
192  // TODO: Add a swithch in options for using either split files or only one.
193  // TODO: This would store the method inside MethodCrossValidation
194  // Another option is to store the folds as separate files.
195  // // Retrieve encap. method for fold n
196  // MethodBase * method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
197  //
198  // // Serialise encapsulated method for fold n
199  // void* foldNode = gTools().AddChild(parent, foldStr);
200  // method->WriteStateToXML(foldNode);
201  }
202 }
203 
204 ////////////////////////////////////////////////////////////////////////////////
205 /// Reads from the xml file.
206 ///
207 
209 {
210  gTools().ReadAttr(parent, "JobName", fJobName);
211  gTools().ReadAttr(parent, "SplitExpr", fSplitExprString);
212  gTools().ReadAttr(parent, "NumFolds", fNumFolds);
213  gTools().ReadAttr(parent, "EncapsulatedMethodName", fEncapsulatedMethodName);
214  gTools().ReadAttr(parent, "EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
215  gTools().ReadAttr(parent, "OutputEnsembling", fOutputEnsembling);
216 
217  // Read in methods for all folds
218  for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
219  TString weightfile = GetWeightFileNameForFold(iFold);
220 
221  Log() << kINFO << "Reading weightfile: " << weightfile << Endl;
222  MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
223  fEncapsulatedMethods.push_back(fold_method);
224  }
225 
226  // SplitExpr
227  if (fSplitExprString != TString("")) {
228  fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
229  } else {
230  Log() << kFATAL << "MethodCrossValidation supports XML reading only for deterministic splitting !" << Endl;
231  }
232 }
233 
234 ////////////////////////////////////////////////////////////////////////////////
235 /// Read the weights
236 ///
237 
239 {
240  Log() << kFATAL << "CrossValidation currently supports only reading from XML." << Endl;
241 }
242 
243 ////////////////////////////////////////////////////////////////////////////////
244 ///
245 
247 {
248  const Event *ev = GetEvent();
249 
250  if (fOutputEnsembling == "None") {
251  if (fSplitExpr != nullptr) {
252  // K-folds with a deterministic split
253  UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
254  return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
255  } else {
256  // K-folds with a random split was used
257  UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
258  return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
259  }
260  } else if (fOutputEnsembling == "Avg") {
261  Double_t val = 0.0;
262  for (auto &m : fEncapsulatedMethods) {
263  val += m->GetMvaValue(err, errUpper);
264  }
265  return val / fEncapsulatedMethods.size();
266  } else {
267  Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
268  return 0; // Cannot happen
269  }
270 }
271 
272 ////////////////////////////////////////////////////////////////////////////////
273 /// Get the multiclass MVA response.
274 
276 {
277  const Event *ev = GetEvent();
278 
279  if (fOutputEnsembling == "None") {
280  if (fSplitExpr != nullptr) {
281  // K-folds with a deterministic split
282  UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
283  return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
284  } else {
285  // K-folds with a random split was used
286  UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
287  return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
288  }
289  } else if (fOutputEnsembling == "Avg") {
290 
291  for (auto &e : fMulticlassValues) {
292  e = 0;
293  }
294 
295  for (auto &m : fEncapsulatedMethods) {
296  auto methodValues = m->GetMulticlassValues();
297  for (size_t i = 0; i < methodValues.size(); ++i) {
298  fMulticlassValues[i] += methodValues[i];
299  }
300  }
301 
302  for (auto &e : fMulticlassValues) {
303  e /= fEncapsulatedMethods.size();
304  }
305 
306  return fMulticlassValues;
307 
308  } else {
309  Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
310  return fMulticlassValues; // Cannot happen
311  }
312 }
313 
314 ////////////////////////////////////////////////////////////////////////////////
315 /// Get the regression value generated by the containing methods.
316 
318 {
319  const Event *ev = GetEvent();
320 
321  if (fOutputEnsembling == "None") {
322  if (fSplitExpr != nullptr) {
323  // K-folds with a deterministic split
324  UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
325  return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
326  } else {
327  // K-folds with a random split was used
328  UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
329  return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
330  }
331  } else if (fOutputEnsembling == "Avg") {
332 
333  for (auto &e : fRegressionValues) {
334  e = 0;
335  }
336 
337  for (auto &m : fEncapsulatedMethods) {
338  auto methodValues = m->GetRegressionValues();
339  for (size_t i = 0; i < methodValues.size(); ++i) {
340  fRegressionValues[i] += methodValues[i];
341  }
342  }
343 
344  for (auto &e : fRegressionValues) {
345  e /= fEncapsulatedMethods.size();
346  }
347 
348  return fRegressionValues;
349 
350  } else {
351  Log() << kFATAL << "Ensembling type " << fOutputEnsembling << " unknown" << Endl;
352  return fRegressionValues; // Cannot happen
353  }
354 }
355 
356 ////////////////////////////////////////////////////////////////////////////////
357 ///
358 
360 {
361  // // Used for evaluation, which is outside the life time of MethodCrossEval.
362  // Log() << kFATAL << "Method CrossValidation should not be created manually,"
363  // " only as part of using TMVA::Reader." << Endl;
364  // return;
365 }
366 
367 ////////////////////////////////////////////////////////////////////////////////
368 ///
369 
371 {
372  Log() << kWARNING
373  << "Method CrossValidation should not be created manually,"
374  " only as part of using TMVA::Reader."
375  << Endl;
376 }
377 
378 ////////////////////////////////////////////////////////////////////////////////
379 ///
380 
382 {
383  return nullptr;
384 }
385 
386 ////////////////////////////////////////////////////////////////////////////////
387 
389  UInt_t /*numberTargets*/)
390 {
391  return kTRUE;
392  // if (fEncapsulatedMethods.size() == 0) {return kFALSE;}
393  // if (fEncapsulatedMethods.at(0) == nullptr) {return kFALSE;}
394  // return fEncapsulatedMethods.at(0)->HasAnalysisType(type, numberClasses, numberTargets);
395 }
396 
397 ////////////////////////////////////////////////////////////////////////////////
398 /// Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
399 
400 void TMVA::MethodCrossValidation::MakeClassSpecific(std::ostream & /*fout*/, const TString & /*className*/) const
401 {
402  Log() << kWARNING << "MakeClassSpecific not implemented for CrossValidation" << Endl;
403 }
404 
405 ////////////////////////////////////////////////////////////////////////////////
406 /// Specific class header.
407 
408 void TMVA::MethodCrossValidation::MakeClassSpecificHeader(std::ostream & /*fout*/, const TString & /*className*/) const
409 {
410  Log() << kWARNING << "MakeClassSpecificHeader not implemented for CrossValidation" << Endl;
411 }
m
auto * m
Definition: textangle.C:8
TMVA::MethodCrossValidation::MakeClassSpecific
void MakeClassSpecific(std::ostream &, const TString &) const
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).
Definition: MethodCrossValidation.cxx:400
kTRUE
const Bool_t kTRUE
Definition: RtypesCore.h:91
e
#define e(i)
Definition: RSha256.hxx:103
TMVA::MethodCrossValidation::WriteMonitoringHistosToFile
void WriteMonitoringHistosToFile(void) const
Write special monitoring histograms to file (dummy implementation here).
Definition: MethodCrossValidation.cxx:359
TMVA::MethodCrossValidation::DeclareCompatibilityOptions
void DeclareCompatibilityOptions()
Options that are used ONLY for the READER to ensure backward compatibility.
Definition: MethodCrossValidation.cxx:69
TMVA::MethodCrossValidation::InstantiateMethodFromXML
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file an instantiates the corresponding method.
Definition: MethodCrossValidation.cxx:154
TString::Data
const char * Data() const
Definition: TString.h:369
ClassImp
#define ClassImp(name)
Definition: Rtypes.h:364
Form
char * Form(const char *fmt,...)
TMVA::Ranking
Ranking for variables in method (implementation)
Definition: Ranking.h:48
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TMVA::MethodBase::DeclareCompatibilityOptions
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:596
TMVA::Tools::AddChild
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
TMVA::MethodCrossValidation
Definition: MethodCrossValidation.h:38
TMVA::MethodCrossValidation::GetHelpMessage
void GetHelpMessage() const
Definition: MethodCrossValidation.cxx:370
MethodCrossValidation.h
TMVA::MethodCrossValidation::ReadWeightsFromXML
void ReadWeightsFromXML(void *parent)
Reads from the xml file.
Definition: MethodCrossValidation.cxx:208
CvSplit.h
TString
Basic string class.
Definition: TString.h:136
TMVA::MethodCrossValidation::Reset
void Reset(void)
Reset the method, as if it had just been instantiated (forget all training etc.).
Definition: MethodCrossValidation.cxx:108
TMVA::MethodCrossValidation::~MethodCrossValidation
virtual ~MethodCrossValidation(void)
Destructor.
Definition: MethodCrossValidation.cxx:50
REGISTER_METHOD
#define REGISTER_METHOD(CLASS)
for example
Definition: ClassifierFactory.h:124
TSystem::GetDirName
virtual TString GetDirName(const char *pathname)
Return the directory name in pathname.
Definition: TSystem.cxx:1030
bool
TMVA::MethodCrossValidation::ReadWeightsFromStream
virtual void ReadWeightsFromStream(std::istream &)=0
TMVA::ClassifierFactory::Instance
static ClassifierFactory & Instance()
access to the ClassifierFactory singleton creates the instance if needed
Definition: ClassifierFactory.cxx:48
TMVA::Tools::AddAttr
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
TMVA::CvSplitKFoldsExpr
Definition: CvSplit.h:64
TMVA::DataSetInfo
Class that contains all the data information.
Definition: DataSetInfo.h:62
TMVA::MethodCrossValidation::GetRegressionValues
const std::vector< Float_t > & GetRegressionValues()
Get the regression value generated by the containing methods.
Definition: MethodCrossValidation.cxx:317
TSystem.h
TMVA::Types::EAnalysisType
EAnalysisType
Definition: Types.h:128
TMVA::MethodCrossValidation::Train
void Train(void)
Call the Optimizer with the set of parameters and ranges that are meant to be tuned.
Definition: MethodCrossValidation.cxx:146
TMVA::gConfig
Config & gConfig()
TMVA::MethodBase::MethodCrossValidation
friend class MethodCrossValidation
Definition: MethodBase.h:117
TMVA::Tools::ReadAttr
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
TMVA::Types::kCategory
@ kCategory
Definition: Types.h:99
TMVA::Config::IONames::fWeightFileDir
TString fWeightFileDir
Definition: Config.h:126
TMVA::MethodCrossValidation::ProcessOptions
void ProcessOptions()
The option string is decoded, for available options see "DeclareOptions".
Definition: MethodCrossValidation.cxx:77
TMVA::MethodCrossValidation::GetWeightFileNameForFold
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
Definition: MethodCrossValidation.cxx:114
TMVA::MethodBase
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
TMVA::Types
Singleton class for Global types used by TMVA.
Definition: Types.h:73
Types.h
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Config.h
unsigned int
TMVA::MethodCrossValidation::CreateRanking
const Ranking * CreateRanking()
Definition: MethodCrossValidation.cxx:381
TMVA::MethodCrossValidation::MakeClassSpecificHeader
void MakeClassSpecificHeader(std::ostream &, const TString &) const
Specific class header.
Definition: MethodCrossValidation.cxx:408
gSystem
R__EXTERN TSystem * gSystem
Definition: TSystem.h:559
TMVA::MethodCrossValidation::Init
void Init(void)
Common initialisation with defaults for the Method.
Definition: MethodCrossValidation.cxx:99
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::MethodCrossValidation::HasAnalysisType
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Definition: MethodCrossValidation.cxx:388
TMVA::MethodCrossValidation::AddWeightsXMLTo
void AddWeightsXMLTo(void *parent) const
Write weights to XML.
Definition: MethodCrossValidation.cxx:178
TMVA::Config::GetIONames
IONames & GetIONames()
Definition: Config.h:100
TMVA::Event
Definition: Event.h:51
TMVA::MethodCrossValidation::DeclareOptions
void DeclareOptions()
Definition: MethodCrossValidation.cxx:54
Tools.h
ClassifierFactory.h
TMVA::gTools
Tools & gTools()
MethodCategory.h
TMVA::ClassifierFactory::Create
IMethod * Create(const std::string &name, const TString &job, const TString &title, DataSetInfo &dsi, const TString &option)
creates the method if needed based on the method name using the creator function the factory has stor...
Definition: ClassifierFactory.cxx:89
TMVA
create variable transformations
Definition: GeneticMinimizer.h:22
TMVA::MethodCrossValidation::GetMulticlassValues
const std::vector< Float_t > & GetMulticlassValues()
Get the multiclass MVA response.
Definition: MethodCrossValidation.cxx:275
TMVA::MethodCrossValidation::GetMvaValue
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
Definition: MethodCrossValidation.cxx:246