Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyGTB.h
Go to the documentation of this file.
// @(#)root/tmva/pymva $Id$
// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015

/***********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis        *
 * Package: TMVA                                                                   *
 * Class  : MethodPyGTB                                                            *
 * Web    : http://oproject.org                                                    *
 *                                                                                 *
 * Description:                                                                    *
 *      TMVA method wrapping the scikit-learn GradientBoostingClassifier (Python)  *
 *                                                                                 *
 **********************************************************************************/

15#ifndef ROOT_TMVA_MethodPyGTB
16#define ROOT_TMVA_MethodPyGTB
17
18//////////////////////////////////////////////////////////////////////////
19// //
20// MethodPyGTB //
21// //
22//////////////////////////////////////////////////////////////////////////
23
24#include "TMVA/PyMethodBase.h"
25#include <vector>
26
27namespace TMVA {
28
29 class Factory;
30 class Reader;
31 class DataSetManager;
32 class Types;
33 class MethodPyGTB : public PyMethodBase {
34
35 public :
36 MethodPyGTB(const TString &jobName,
37 const TString &methodTitle,
38 DataSetInfo &theData,
39 const TString &theOption = "");
41 const TString &theWeightFile);
42 ~MethodPyGTB(void);
43
44 void Train();
45 void Init();
46 void DeclareOptions();
47 void ProcessOptions();
48
49 const Ranking *CreateRanking();
50
51 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
52
53 virtual void TestClassification();
54
55 Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
56 std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
57 std::vector<Float_t>& GetMulticlassValues();
58
59 virtual void ReadModelFromFile();
60
62 // the actual "weights"
63 virtual void AddWeightsXMLTo(void * /* parent */ ) const {} // = 0;
64 virtual void ReadWeightsFromXML(void * /*wghtnode*/) {} // = 0;
65 virtual void ReadWeightsFromStream(std::istream &) {} //= 0; backward compatibility
66
67 private :
69 friend class Factory;
70 friend class Reader;
71
72 protected:
73 std::vector<Double_t> mvaValues;
74 std::vector<Float_t> classValues;
75
76 UInt_t fNvars; // number of variables
77 UInt_t fNoutputs; // number of outputs
78 TString fFilenameClassifier; // Path to serialized classifier (default in `weights` folder)
79
80 //GTB options
81
83 TString fLoss; // {'deviance', 'exponential'}, optional (default='deviance')
84 //loss function to be optimized. 'deviance' refers to
85 //deviance (= logistic regression) for classification
86 //with probabilistic outputs. For loss 'exponential' gradient
87 //boosting recovers the AdaBoost algorithm.
88
90 Double_t fLearningRate; //float, optional (default=0.1)
91 //learning rate shrinks the contribution of each tree by `learning_rate`.
92 //There is a trade-off between learning_rate and n_estimators.
93
95 Int_t fNestimators; //integer, optional (default=10)
96 //The number of trees in the forest.
97
99 Double_t fSubsample; //float, optional (default=1.0)
100 //The fraction of samples to be used for fitting the individual base
101 //learners. If smaller than 1.0 this results in Stochastic Gradient
102 //Boosting. `subsample` interacts with the parameter `n_estimators`.
103 //Choosing `subsample < 1.0` leads to a reduction of variance
104 //and an increase in bias.
105
107 Int_t fMinSamplesSplit; // integer, optional (default=2)
108 //The minimum number of samples required to split an internal node.
109
111 Int_t fMinSamplesLeaf; //integer, optional (default=1)
112 //The minimum number of samples required to be at a leaf node.
113
115 Double_t fMinWeightFractionLeaf; //float, optional (default=0.)
116 //The minimum weighted fraction of the input samples required to be at a leaf node.
117
119 Int_t fMaxDepth; //integer, optional (default=3)
120 //maximum depth of the individual regression estimators. The maximum
121 //depth limits the number of nodes in the tree. Tune this parameter
122 //for best performance; the best value depends on the interaction
123 //of the input variables.
124 //Ignored if ``max_leaf_nodes`` is not None.
125
127 TString fInit; //BaseEstimator, None, optional (default=None)
128 //An estimator object that is used to compute the initial
129 //predictions. ``init`` has to provide ``fit`` and ``predict``.
130 //If None it uses ``loss.init_estimator``.
131
133 TString fRandomState; //int, RandomState instance or None, optional (default=None)
134 //If int, random_state is the seed used by the random number generator;
135 //If RandomState instance, random_state is the random number generator;
136 //If None, the random number generator is the RandomState instance used
137 //by `np.random`.
138
140 TString fMaxFeatures; //int, float, string or None, optional (default="auto")
141 //The number of features to consider when looking for the best split:
142 //- If int, then consider `max_features` features at each split.
143 //- If float, then `max_features` is a percentage and
144 //`int(max_features * n_features)` features are considered at each split.
145 //- If "auto", then `max_features=sqrt(n_features)`.
146 //- If "sqrt", then `max_features=sqrt(n_features)`.
147 //- If "log2", then `max_features=log2(n_features)`.
148 //- If None, then `max_features=n_features`.
149 // Note: the search for a split does not stop until at least one
150 // valid partition of the node samples is found, even if it requires to
151 // effectively inspect more than ``max_features`` features.
152 // Note: this parameter is tree-specific.
153
155 Int_t fVerbose; //Controls the verbosity of the tree building process.
156
158 TString fMaxLeafNodes; //int or None, optional (default=None)
159 //Grow trees with ``max_leaf_nodes`` in best-first fashion.
160 //Best nodes are defined as relative reduction in impurity.
161 //If None then unlimited number of leaf nodes.
162 //If not None then ``max_depth`` will be ignored.
163
165 Bool_t fWarmStart; //bool, optional (default=False)
166 //When set to ``True``, reuse the solution of the previous call to fit
167 //and add more estimators to the ensemble, otherwise, just fit a whole
168 //new forest.
169
170 // get help message text
171 void GetHelpMessage() const;
172
174 };
175
176} // namespace TMVA
177
178#endif // ROOT_TMVA_PyMethodGTB
_object PyObject
double Double_t
Definition RtypesCore.h:59
long long Long64_t
Definition RtypesCore.h:73
#define ClassDef(name, id)
Definition Rtypes.h:325
int type
Definition TGX11.cxx:121
Class that contains all the data information.
Definition DataSetInfo.h:62
Class that contains all the data information.
This is the main MVA steering class.
Definition Factory.h:80
virtual void ReadWeightsFromStream(std::istream &)=0
Double_t fSubsample
Definition MethodPyGTB.h:99
DataSetManager * fDataSetManager
Definition MethodPyGTB.h:68
PyObject * pMinSamplesLeaf
Double_t fMinWeightFractionLeaf
std::vector< Double_t > mvaValues
Definition MethodPyGTB.h:73
PyObject * pMaxFeatures
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
PyObject * pMaxDepth
PyObject * pMaxLeafNodes
std::vector< Float_t > classValues
Definition MethodPyGTB.h:74
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
Double_t fLearningRate
Definition MethodPyGTB.h:90
void GetHelpMessage() const
PyObject * pLearningRate
Definition MethodPyGTB.h:89
const Ranking * CreateRanking()
virtual void TestClassification()
initialization
std::vector< Float_t > & GetMulticlassValues()
virtual void ReadWeightsFromStream(std::istream &)
Definition MethodPyGTB.h:65
virtual void ReadModelFromFile()
PyObject * pNestimators
Definition MethodPyGTB.h:94
PyObject * pVerbose
virtual void AddWeightsXMLTo(void *) const
Definition MethodPyGTB.h:63
virtual void ReadWeightsFromXML(void *)
Definition MethodPyGTB.h:64
TString fFilenameClassifier
Definition MethodPyGTB.h:78
PyObject * pMinSamplesSplit
PyObject * pLoss
Definition MethodPyGTB.h:82
PyObject * pSubsample
Definition MethodPyGTB.h:98
PyObject * pRandomState
PyObject * pWarmStart
PyObject * pMinWeightFractionLeaf
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
Ranking for variables in method (implementation)
Definition Ranking.h:48
The Reader class serves to use the MVAs in a specific analysis context.
Definition Reader.h:64
Basic string class.
Definition TString.h:136
create variable transformations