ROOT 6.12/07
Reference Guide
MethodPyAdaBoost.h
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodPyAdaBoost *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * scikit-learn package AdaBoostClassifier method based on python *
12  * *
13  **********************************************************************************/
14 
15 #ifndef ROOT_TMVA_MethodPyAdaBoost
16 #define ROOT_TMVA_MethodPyAdaBoost
17 
18 //////////////////////////////////////////////////////////////////////////
19 // //
20 // MethodPyAdaBoost //
21 // //
22 //////////////////////////////////////////////////////////////////////////
23 
24 #include "TMVA/PyMethodBase.h"
25 
26 #include "TString.h"
27 
28 namespace TMVA {
29 
30  class Factory;
31  class Reader;
32  class DataSetManager;
33  class Types;
34  class MethodPyAdaBoost : public PyMethodBase {
35 
36  public :
37  MethodPyAdaBoost(const TString &jobName,
38  const TString &methodTitle,
39  DataSetInfo &theData,
40  const TString &theOption = "");
41 
43  const TString &theWeightFile);
44 
46 
47  void Train();
48 
49  void Init();
50  void DeclareOptions();
51  void ProcessOptions();
52 
53  // create ranking
54  const Ranking *CreateRanking();
55 
56  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
57 
58  // performs classifier testing
59  virtual void TestClassification();
60 
61  Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
62  std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
63  std::vector<Float_t>& GetMulticlassValues();
64 
65  virtual void ReadModelFromFile();
66 
68  // the actual "weights"
69  virtual void AddWeightsXMLTo(void * /*parent */ ) const {} // = 0;
70  virtual void ReadWeightsFromXML(void * /*wghtnode*/ ) {} // = 0;
71  virtual void ReadWeightsFromStream(std::istream &) {} //= 0; backward compatibility
72 
73  private :
75  friend class Factory;
76  friend class Reader;
77 
78  protected:
79  std::vector<Double_t> mvaValues;
80  std::vector<Float_t> classValues;
81 
82  UInt_t fNvars; // number of variables
83  UInt_t fNoutputs; // number of outputs
84  TString fFilenameClassifier; // Path to serialized classifier (default in `weights` folder)
85 
86  //AdaBoost options
87 
89  TString fBaseEstimator; //object, optional (default=DecisionTreeClassifier)
90  //The base estimator from which the boosted ensemble is built.
91  //Support for sample weighting is required, as well as proper `classes_`
92  //and `n_classes_` attributes.
93 
95  Int_t fNestimators; //integer, optional (default=10)
96  //The number of trees in the forest.
97 
99  Double_t fLearningRate; //loat, optional (default=1.)
100  //Learning rate shrinks the contribution of each classifier by
101  //``learning_rate``. There is a trade-off between ``learning_rate`` and ``n_estimators``.
102 
104  TString fAlgorithm; //{'SAMME', 'SAMME.R'}, optional (default='SAMME.R')
105  //If 'SAMME.R' then use the SAMME.R real boosting algorithm.
106  //``base_estimator`` must support calculation of class probabilities.
107  //If 'SAMME' then use the SAMME discrete boosting algorithm.
108  //The SAMME.R algorithm typically converges faster than SAMME,
109  //achieving a lower test error with fewer boosting iterations.
110 
112  TString fRandomState; //int, RandomState instance or None, optional (default=None)
113  //If int, random_state is the seed used by the random number generator;
114  //If RandomState instance, random_state is the random number generator;
115  //If None, the random number generator is the RandomState instance used by `np.random`.
116 
117  // get help message text
118  void GetHelpMessage() const;
119 
121  };
122 
123 } // namespace TMVA
124 
125 #endif // ROOT_TMVA_MethodPyAdaBoost
long long Long64_t
Definition: RtypesCore.h:69
EAnalysisType
Definition: Types.h:125
Basic string class.
Definition: TString.h:125
Ranking for variables in method (implementation)
Definition: Ranking.h:48
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
virtual void TestClassification()
initialization
#define ClassDef(name, id)
Definition: Rtypes.h:320
DataSetManager * fDataSetManager
std::vector< Float_t > & GetMulticlassValues()
Class that contains all the data information.
Definition: DataSetInfo.h:60
Double_t GetMvaValue(Double_t *errLower=0, Double_t *errUpper=0)
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
virtual void ReadWeightsFromXML(void *)
virtual void ReadWeightsFromStream(std::istream &)
unsigned int UInt_t
Definition: RtypesCore.h:42
This is the main MVA steering class.
Definition: Factory.h:81
virtual void AddWeightsXMLTo(void *) const
double Double_t
Definition: RtypesCore.h:55
Class that contains all the data information.
int type
Definition: TGX11.cxx:120
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
Abstract ClassifierFactory template that handles arbitrary types.
std::vector< Float_t > classValues
virtual void ReadModelFromFile()
virtual void ReadWeightsFromStream(std::istream &)=0
std::vector< Double_t > mvaValues
MethodPyAdaBoost(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
const Ranking * CreateRanking()
_object PyObject
Definition: TPyArg.h:20