Logo ROOT   6.14/05
Reference Guide
MethodPyRandomForest.h
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodPyRandomForest *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * TMVA method wrapping the scikit-learn (Python) RandomForestClassifier *
12  * *
13  **********************************************************************************/
14 
15 #ifndef ROOT_TMVA_MethodPyRandomForest
16 #define ROOT_TMVA_MethodPyRandomForest
17 
18 //////////////////////////////////////////////////////////////////////////
19 // //
20 // MethodPyRandomForest //
21 // //
22 //////////////////////////////////////////////////////////////////////////
23 
24 #include "TMVA/PyMethodBase.h"
25 
26 namespace TMVA {
27 
28  class Factory; // DSMTEST
29  class Reader; // DSMTEST
30  class DataSetManager; // DSMTEST
31  class Types;
33 // NOTE(review): this rendered listing skips some source lines (e.g. 32, 41, 44); the class head and the first line of the weight-file constructor are not shown here
34  public :
35  // constructors: one taking job name, method title, dataset info and an option string; one taking a weight file
36  MethodPyRandomForest(const TString &jobName,
37  const TString &methodTitle,
38  DataSetInfo &theData,
39  const TString &theOption = "");
40 
42  const TString &theWeightFile);
43 
45  void Train();
46 
47  // options treatment
48  void Init();
49  void DeclareOptions();
50  void ProcessOptions();
51 
52  // create ranking
53  const Ranking *CreateRanking();
54 
55  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
56 
57  // performs classifier testing
58  virtual void TestClassification();
59 
60  // MVA response / class probabilities: for a single event, for an event range [firstEvt, lastEvt], and for multiclass output
61  Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
62  std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
63  std::vector<Float_t>& GetMulticlassValues();
64 
66  // the actual "weights": these XML/stream hooks have empty bodies here; NOTE(review): the model presumably lives in fFilenameClassifier (see ReadModelFromFile) — confirm
67  virtual void AddWeightsXMLTo(void * /* parent */) const {} // = 0;
68  virtual void ReadWeightsFromXML(void * /* wghtnode */) {} // = 0;
69  virtual void ReadWeightsFromStream(std::istream &) {} //= 0; // backward compatibility
70 
71  void ReadModelFromFile(); // presumably loads the serialized classifier named by fFilenameClassifier — confirm
72 
73  private :
75  friend class Factory; // DSMTEST
76  friend class Reader; // DSMTEST
77 
78  protected:
79  std::vector<Double_t> mvaValues; // MVA value buffer; type matches GetMvaValues' return (presumably its backing store — confirm)
80  std::vector<Float_t> classValues; // class value buffer; type matches GetMulticlassValues' return (presumably its backing store — confirm)
81 
82  UInt_t fNvars; // number of variables
83  UInt_t fNoutputs; // number of outputs
84  TString fFilenameClassifier; // Path to serialized classifier (default in `weights` folder)
85 
86  // RandomForest options; those that may take None or a string value are stored as TString
87 
89  Int_t fNestimators; //integer, optional (default=10)
90  //The number of trees in the forest.
91 
93  TString fCriterion; //string, optional (default="gini")
94  //The function to measure the quality of a split. Supported criteria are
95  //"gini" for the Gini impurity and "entropy" for the information gain.
96  //Note: this parameter is tree-specific.
97 
99  TString fMaxDepth; //integer or None, optional (default=None)
100  //The maximum depth of the tree. If None, then nodes are expanded until
101  //all leaves are pure or until all leaves contain fewer than `fMinSamplesSplit` samples.
102 
104  Int_t fMinSamplesSplit; //integer, optional (default=2)
105  //The minimum number of samples required to split an internal node.
106 
108  Int_t fMinSamplesLeaf; //integer, optional (default=1)
109  //The minimum number of samples in newly created leaves. A split is
110  //discarded if after the split, one of the leaves would contain fewer than
111  //``min_samples_leaf`` samples.
112  //Note: this parameter is tree-specific.
113 
115  Double_t fMinWeightFractionLeaf; //float, optional (default=0.)
116  //The minimum weighted fraction of the input samples required to be at a
117  //leaf node.
118  //Note: this parameter is tree-specific.
119 
121  TString fMaxFeatures; //int, float, string or None, optional (default="auto")
122  //The number of features to consider when looking for the best split:
123  //- If int, then consider `max_features` features at each split.
124  //- If float, then `max_features` is a percentage and
125  //`int(max_features * n_features)` features are considered at each split.
126  //- If "auto", then `max_features=sqrt(n_features)`.
127  //- If "sqrt", then `max_features=sqrt(n_features)`.
128  //- If "log2", then `max_features=log2(n_features)`.
129  //- If None, then `max_features=n_features`.
130  // Note: the search for a split does not stop until at least one
131  // valid partition of the node samples is found, even if it requires
132  // effectively inspecting more than ``max_features`` features.
133  // Note: this parameter is tree-specific.
134 
136  TString fMaxLeafNodes; //int or None, optional (default=None)
137  //Grow trees with ``max_leaf_nodes`` in best-first fashion.
138  //Best nodes are defined as relative reduction in impurity.
139  //If None then unlimited number of leaf nodes.
140  //If not None then ``max_depth`` will be ignored.
141 
143  Bool_t fBootstrap; //boolean, optional (default=True)
144  //Whether bootstrap samples are used when building trees.
145 
147  Bool_t fOobScore; //Whether to use out-of-bag samples to estimate
148  //the generalization error.
149 
151  Int_t fNjobs; // integer, optional (default=1)
152  //The number of jobs to run in parallel for both `fit` and `predict`.
153  //If -1, then the number of jobs is set to the number of cores.
154 
156  TString fRandomState; //int, RandomState instance or None, optional (default=None)
157  //If int, random_state is the seed used by the random number generator;
158  //If RandomState instance, random_state is the random number generator;
159  //If None, the random number generator is the RandomState instance used
160  //by `np.random`.
161 
163  Int_t fVerbose; //Controls the verbosity of the tree building process.
164 
166  Bool_t fWarmStart; //bool, optional (default=False)
167  //When set to ``True``, reuse the solution of the previous call to fit
168  //and add more estimators to the ensemble, otherwise, just fit a whole
169  //new forest.
170 
172  TString fClassWeight; //dict, list of dicts, "auto", "subsample" or None, optional
173  //Weights associated with classes in the form ``{class_label: weight}``.
174  //If not given, all classes are supposed to have weight one. For
175  //multi-output problems, a list of dicts can be provided in the same
176  //order as the columns of y.
177  //The "auto" mode uses the values of y to automatically adjust
178  //weights inversely proportional to class frequencies in the input data.
179  //The "subsample" mode is the same as "auto" except that weights are
180  //computed based on the bootstrap sample for every tree grown.
181  //For multi-output, the weights of each column of y will be multiplied.
182  //Note that these weights will be multiplied with sample_weight (passed
183  //through the fit method) if sample_weight is specified.
184 
185  // get help message text
186  void GetHelpMessage() const;
187 
189  };
190 
191 } // namespace TMVA
192 
193 #endif // ROOT_TMVA_MethodPyRandomForest
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
std::vector< Double_t > mvaValues
long long Long64_t
Definition: RtypesCore.h:69
virtual void AddWeightsXMLTo(void *) const
EAnalysisType
Definition: Types.h:127
Basic string class.
Definition: TString.h:131
Ranking for variables in method (implementation)
Definition: Ranking.h:48
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
#define ClassDef(name, id)
Definition: Rtypes.h:320
Class that contains all the data information.
Definition: DataSetInfo.h:60
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
MethodPyRandomForest(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
virtual void ReadWeightsFromXML(void *)
unsigned int UInt_t
Definition: RtypesCore.h:42
This is the main MVA steering class.
Definition: Factory.h:81
double Double_t
Definition: RtypesCore.h:55
Class that contains all the data information.
int type
Definition: TGX11.cxx:120
virtual void ReadWeightsFromStream(std::istream &)
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
Abstract ClassifierFactory template that handles arbitrary types.
virtual void TestClassification()
initialization
virtual void ReadWeightsFromStream(std::istream &)=0
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
std::vector< Float_t > classValues
_object PyObject
Definition: TPyArg.h:20
std::vector< Float_t > & GetMulticlassValues()