Logo ROOT   6.14/05
Reference Guide
MethodCrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #ifndef ROOT_TMVA_MethodCrossValidation
13 #define ROOT_TMVA_MethodCrossValidation
14 
15 //////////////////////////////////////////////////////////////////////////
16 // //
17 // MethodCrossValidation //
18 // //
19 //////////////////////////////////////////////////////////////////////////
20 
21 #include "TMVA/CvSplit.h"
22 #include "TMVA/DataSetInfo.h"
23 #include "TMVA/MethodBase.h"
24 
25 #include "TString.h"
26 
27 #include <iostream>
28 #include <memory>
29 
30 namespace TMVA {
31 
32 class CrossValidation;
33 class Ranking;
34 
35 // Looks for serialised methods of the form methodTitle + "_fold" + iFold;
37 
39 
40 public:
41  // constructor for training and reading
42  MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData,
43  const TString &theOption = "");
44 
45  // constructor from an existing weight file (theWeightFile), e.g. for applying an already trained method -- NOTE(review): confirm against the .cxx implementation
46  MethodCrossValidation(DataSetInfo &theData, const TString &theWeightFile);
47 
48  virtual ~MethodCrossValidation(void);
49 
50  // optimize tuning parameters
51  // virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString
52  // fitType="FitGA"); virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
53 
54  // training method
55  void Train(void);
56 
57  // revoke training
58  void Reset(void);
59 
61 
62  // write weights to file
63  void AddWeightsXMLTo(void *parent) const;
64 
65  // read weights from file
66  void ReadWeightsFromStream(std::istream &istr);
67  void ReadWeightsFromXML(void *parent);
68 
69  // write method specific histos to target file
70  void WriteMonitoringHistosToFile(void) const;
71 
72  // calculate the MVA value
73  Double_t GetMvaValue(Double_t *err = 0, Double_t *errUpper = 0);
74  const std::vector<Float_t> &GetMulticlassValues();
75  const std::vector<Float_t> &GetRegressionValues();
76 
77  // the option handling methods
78  void DeclareOptions();
79  void ProcessOptions();
80 
81  // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
82  void MakeClassSpecific(std::ostream &, const TString &) const;
83  void MakeClassSpecificHeader(std::ostream &, const TString &) const;
84 
85  void GetHelpMessage() const;
86 
87  const Ranking *CreateRanking();
88  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
89 
90 protected:
91  void Init(void);
93 
94 private:
96  MethodBase *InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const;
97 
98 private:
103 
105  std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
106 
107  std::vector<Float_t> fMulticlassValues;
108  std::vector<Float_t> fRegressionValues;
109 
110  std::vector<MethodBase *> fEncapsulatedMethods;
111 
112  // Used for CrossValidation with random splits (not using the
113  // split-expression functionality, i.e. fSplitExpr / CvSplitKFoldsExpr) to communicate Event to fold
114  // mapping.
115  std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
116 
117  // for backward compatibility
119 };
120 
121 } // namespace TMVA
122 
123 #endif
void ReadWeightsFromStream(std::istream &istr)
Read the weights.
MethodBase * InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const
Reads in a weight file and instantiates the corresponding method.
void WriteMonitoringHistosToFile(void) const
Write special monitoring histograms to file (dummy implementation here). ...
void ReadWeightsFromXML(void *parent)
Reads from the xml file.
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
EAnalysisType
Definition: Types.h:127
Virtual base class for all MVA methods.
Definition: MethodBase.h:109
void DeclareCompatibilityOptions()
Options that are used ONLY for the READER to ensure backward compatibility.
Basic string class.
Definition: TString.h:131
Ranking for variables in method (implementation)
Definition: Ranking.h:48
bool Bool_t
Definition: RtypesCore.h:59
void Reset(void)
Reset the method, as if it had just been instantiated (forget all training etc.). ...
const std::vector< Float_t > & GetRegressionValues()
Get the regression values generated by the encapsulated (per-fold) methods.
#define ClassDef(name, id)
Definition: Rtypes.h:320
MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
std::vector< Float_t > fMulticlassValues
Class that contains all the data information.
Definition: DataSetInfo.h:60
std::unique_ptr< CvSplitKFoldsExpr > fSplitExpr
virtual ~MethodCrossValidation(void)
Destructor.
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
unsigned int UInt_t
Definition: RtypesCore.h:42
void ProcessOptions()
The option string is decoded, for available options see "DeclareOptions".
void Train(void)
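Training method. NOTE(review): the upstream brief for Train reads "Call the Optimizer with the set of parameters and ranges that are meant to be tuned.", which appears to describe the commented-out OptimizeTuningParameters rather than training — confirm against the implementation.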
void MakeClassSpecificHeader(std::ostream &, const TString &) const
Specific class header.
void Init(void)
Common initialisation with defaults for the Method.
double Double_t
Definition: RtypesCore.h:55
int type
Definition: TGX11.cxx:120
std::vector< Float_t > fRegressionValues
TString GetWeightFileNameForFold(UInt_t iFold) const
Returns filename of weight file for a given fold.
Abstract ClassifierFactory template that handles arbitrary types.
void AddWeightsXMLTo(void *parent) const
Write weights to XML.
const std::vector< Float_t > & GetMulticlassValues()
Get the multiclass MVA response.
std::vector< MethodBase * > fEncapsulatedMethods
virtual void ReadWeightsFromStream(std::istream &)=0
std::map< const TMVA::Event *, UInt_t > fEventToFoldMapping
void MakeClassSpecific(std::ostream &, const TString &) const
Make ROOT-independent C++ class for classifier response (classifier-specific implementation).