Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodPyAdaBoost.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodPyAdaBoost *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * AdaBoost Classifier from Scikit learn *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (see tmva/doc/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <Python.h> // Needs to be included first to avoid redefinition of _POSIX_C_SOURCE
22
23#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
24#include <numpy/arrayobject.h>
25
26#include "TMVA/Config.h"
27#include "TMVA/Configurable.h"
29#include "TMVA/DataSet.h"
30#include "TMVA/Event.h"
31#include "TMVA/IMethod.h"
32#include "TMVA/MsgLogger.h"
33#include "TMVA/PDF.h"
34#include "TMVA/Ranking.h"
35#include "TMVA/Tools.h"
36#include "TMVA/Types.h"
37#include "TMVA/Timer.h"
39#include "TMVA/Results.h"
40
41#include "TMatrix.h"
42
43using namespace TMVA;
44
45namespace TMVA {
46namespace Internal {
47class PyGILRAII {
48 PyGILState_STATE m_GILState;
49
50public:
51 PyGILRAII() : m_GILState(PyGILState_Ensure()) {}
52 ~PyGILRAII() { PyGILState_Release(m_GILState); }
53};
54} // namespace Internal
55} // namespace TMVA
56
// Register this method with TMVA via the REGISTER_METHOD macro (argument PyAdaBoost);
// presumably this makes the method bookable by name — confirm against the macro definition.
57REGISTER_METHOD(PyAdaBoost)
58
60
61//_______________________________________________________________________
/// Constructor used when booking the method: forwards the job name, title, data-set
/// info and option string to PyMethodBase (method type Types::kPyAdaBoost) and sets the
/// option defaults BaseEstimator="None", NEstimators=50, LearningRate=1.0,
/// Algorithm="SAMME", RandomState="None".
/// NOTE(review): the first signature line (source line 62) is missing from this listing.
63 const TString &methodTitle,
64 DataSetInfo &dsi,
65 const TString &theOption) :
66 PyMethodBase(jobName, Types::kPyAdaBoost, methodTitle, dsi, theOption),
67 fBaseEstimator("None"),
68 fNestimators(50),
69 fLearningRate(1.0),
70 fAlgorithm("SAMME"),
71 fRandomState("None")
72{
73}
74
75//_______________________________________________________________________
/// Constructor taking a data set and a weight file (theWeightFile), presumably used when
/// the method is reconstructed from a weight file. It applies the same option defaults
/// as the booking constructor.
/// NOTE(review): the first signature line (source line 76) is missing from this listing.
77 const TString &theWeightFile) :
78 PyMethodBase(Types::kPyAdaBoost, theData, theWeightFile),
79 fBaseEstimator("None"),
80 fNestimators(50),
81 fLearningRate(1.0),
82 fAlgorithm("SAMME"),
83 fRandomState("None")
84{
85}
86
87//_______________________________________________________________________
/// Destructor; its body is empty in this listing.
/// NOTE(review): the signature line (source line 88) is missing from this listing.
89{
90}
91
92//_______________________________________________________________________
94{
95 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
96 if (type == Types::kMulticlass && numberClasses >= 2) return kTRUE;
97 return kFALSE;
98}
99
100//_______________________________________________________________________
/// Declare the user options of this method and bind them to members:
/// BaseEstimator, NEstimators, LearningRate, Algorithm and RandomState (turned into
/// Python objects by ProcessOptions and passed to AdaBoostClassifier in Train) and
/// FilenameClassifier (file in which the trained classifier is stored).
/// NOTE(review): the signature line (source line 101) and source line 103 are missing
/// from this listing.
102{
104
// NOTE(review): the help text refers to sklearn's old ``base_estimator`` name, while
// Train passes the object as ``estimator=``.
105 DeclareOptionRef(fBaseEstimator, "BaseEstimator", "object, optional (default=DecisionTreeClassifier)\
106 The base estimator from which the boosted ensemble is built.\
107 Support for sample weighting is required, as well as proper `classes_`\
108 and `n_classes_` attributes.");
109
110 DeclareOptionRef(fNestimators, "NEstimators", "integer, optional (default=50)\
111 The maximum number of estimators at which boosting is terminated.\
112 In case of perfect fit, the learning procedure is stopped early.");
113
114 DeclareOptionRef(fLearningRate, "LearningRate", "float, optional (default=1.)\
115 Learning rate shrinks the contribution of each classifier by\
116 ``learning_rate``. There is a trade-off between ``learning_rate`` and\
117 ``n_estimators``.");
118
// NOTE(review): the last line of this help text says 'SAME.R'; 'SAMME.R' was presumably meant.
119 DeclareOptionRef(fAlgorithm, "Algorithm", "{'SAMME', 'SAMME.R'}, optional (default='SAMME')\
120 If 'SAMME.R' then use the SAMME.R real boosting algorithm.\
121 ``base_estimator`` must support calculation of class probabilities.\
122 If 'SAMME' then use the SAMME discrete boosting algorithm.\
123 The SAMME.R algorithm typically converges faster than SAMME,\
124 achieving a lower test error with fewer boosting iterations.\
125 'SAME.R' is deprecated since version 1.4 and removed since 1.6");
126
127 DeclareOptionRef(fRandomState, "RandomState", "int, RandomState instance or None, optional (default=None)\
128 If int, random_state is the seed used by the random number generator;\
129 If RandomState instance, random_state is the random number generator;\
130 If None, the random number generator is the RandomState instance used\
131 by `np.random`.");
132
133 DeclareOptionRef(fFilenameClassifier, "FilenameClassifier",
134 "Store trained classifier in this file");
135}
136
137//_______________________________________________________________________
138// Check options and load them to local python namespace
/// Each option is validated (invalid values are reported with kFATAL) and the resulting
/// Python object is stored in fLocalNS as baseEstimator, nEstimators, learningRate,
/// algorithm and randomState; Train() refers to these names when it builds the
/// classifier. If no FilenameClassifier was given, a default file in the weight-file
/// directory is used.
/// NOTE(review): source lines 139, 141, 151, 157, 167 and 177 are missing from this
/// listing (the signature, presumably the creation of pBaseEstimator, pNestimators,
/// pLearningRate and pRandomState, and the condition opening the default-filename branch).
/// NOTE(review): PyDict_SetItemString does not steal references; if the objects stored
/// below are new references, they are never released here — confirm.
140{
142 if (!pBaseEstimator) {
143 Log() << kFATAL << Form("BaseEstimator = %s ... that does not work!", fBaseEstimator.Data())
144 << " The options are Object or None." << Endl;
145 }
146 PyDict_SetItemString(fLocalNS, "baseEstimator", pBaseEstimator);
147
148 if (fNestimators <= 0) {
149 Log() << kFATAL << "NEstimators <=0 ... that does not work!" << Endl;
150 }
152 PyDict_SetItemString(fLocalNS, "nEstimators", pNestimators);
153
154 if (fLearningRate <= 0) {
155 Log() << kFATAL << "LearningRate <=0 ... that does not work!" << Endl;
156 }
158 PyDict_SetItemString(fLocalNS, "learningRate", pLearningRate);
159
// NOTE(review): the message below says "SAMME of SAMME.R"; "or" was presumably intended.
// pAlgorithm has no visible declaration here and, unlike pBaseEstimator and
// pRandomState, its Eval() result is not checked for NULL.
160 if (fAlgorithm != "SAMME" && fAlgorithm != "SAMME.R") {
161 Log() << kFATAL << Form("Algorithm = %s ... that does not work!", fAlgorithm.Data())
162 << " The options are SAMME of SAMME.R." << Endl;
163 }
164 pAlgorithm = Eval(Form("'%s'", fAlgorithm.Data()));
165 PyDict_SetItemString(fLocalNS, "algorithm", pAlgorithm);
166
168 if (!pRandomState) {
169 Log() << kFATAL << Form(" RandomState = %s... that does not work !! ", fRandomState.Data())
170 << "If int, random_state is the seed used by the random number generator;"
171 << "If RandomState instance, random_state is the random number generator;"
172 << "If None, the random number generator is the RandomState instance used by `np.random`." << Endl;
173 }
174 PyDict_SetItemString(fLocalNS, "randomState", pRandomState);
175
176 // If no filename is given, set default
178 fFilenameClassifier = GetWeightFileDir() + "/PyAdaBoostModel_" + GetName() + ".PyData";
179 }
180}
181
182//_______________________________________________________________________
/// Initialise the method: set up numpy's C API (_import_array), load the options into
/// the local Python namespace (per the comment below) and import sklearn.ensemble.
/// NOTE(review): source lines 183 (signature), 185, 189, 195 and 196 are missing from this
/// listing; judging by the surrounding comments they hold the option processing and the
/// data-property setup — confirm against the source.
184{
186 _import_array(); //require to use numpy arrays
187
188 // Check options and load them to local python namespace
190
191 // Import module for ada boost classifier
192 PyRunString("import sklearn.ensemble");
193
194 // Get data properties
197}
198
199//_______________________________________________________________________
/// Train the AdaBoost classifier.
/// Every training event's input variables, class index and weight are copied into float32
/// numpy arrays stored in fLocalNS as trainData, trainDataClasses and trainDataWeights.
/// Then sklearn.ensemble.AdaBoostClassifier is built from the option objects that
/// ProcessOptions placed in fLocalNS (baseEstimator, nEstimators, learningRate, algorithm,
/// randomState), fitted, and kept in fClassifier.
/// NOTE(review): the signature line (source line 200) and source line 253 are missing from
/// this listing; line 253 sits in the IsModelPersistence() branch after the "Saving state
/// file" message, presumably the call writing fClassifier to fFilenameClassifier — confirm.
201{
202 // Load training data (data, classes, weights) to python arrays
// NOTE(review): GetNTrainingEvents() returns Long64_t, so storing it in an int narrows for
// very large samples; the f-prefixed names below are locals, not members.
203 int fNrowsTraining = Data()->GetNTrainingEvents(); //every row is an event, a class type and a weight
204 npy_intp dimsData[2];
205 dimsData[0] = fNrowsTraining;
206 dimsData[1] = fNvars;
// NOTE(review): PyArray_SimpleNew returns a new reference and PyDict_SetItemString does not
// steal it, so each of the three arrays below keeps one reference that is never released.
207 PyArrayObject * fTrainData = (PyArrayObject *)PyArray_SimpleNew(2, dimsData, NPY_FLOAT);
208 PyDict_SetItemString(fLocalNS, "trainData", (PyObject*)fTrainData);
209 float *TrainData = (float *)(PyArray_DATA(fTrainData));
210
211 npy_intp dimsClasses = (npy_intp) fNrowsTraining;
212 PyArrayObject * fTrainDataClasses = (PyArrayObject *)PyArray_SimpleNew(1, &dimsClasses, NPY_FLOAT);
213 PyDict_SetItemString(fLocalNS, "trainDataClasses", (PyObject*)fTrainDataClasses);
214 float *TrainDataClasses = (float *)(PyArray_DATA(fTrainDataClasses));
215
216 PyArrayObject * fTrainDataWeights = (PyArrayObject *)PyArray_SimpleNew(1, &dimsClasses, NPY_FLOAT);
217 PyDict_SetItemString(fLocalNS, "trainDataWeights", (PyObject*)fTrainDataWeights);
218 float *TrainDataWeights = (float *)(PyArray_DATA(fTrainDataWeights));
219
220 for (int i = 0; i < fNrowsTraining; i++) {
221 // Fill training data matrix
222 const TMVA::Event *e = Data()->GetTrainingEvent(i);
223 for (UInt_t j = 0; j < fNvars; j++) {
224 TrainData[j + i * fNvars] = e->GetValue(j);
225 }
226
227 // Fill target classes
228 TrainDataClasses[i] = e->GetClass();
229
230 // Get event weight
231 TrainDataWeights[i] = e->GetWeight();
232 }
233
234 // Create classifier object
235 PyRunString("classifier = sklearn.ensemble.AdaBoostClassifier(estimator=baseEstimator, n_estimators=nEstimators, learning_rate=learningRate, algorithm=algorithm, random_state=randomState)",
236 "Failed to setup classifier");
237
238 // Fit classifier
239 // NOTE: We dump the output to a variable so that the call does not pollute stdout
240 PyRunString("dump = classifier.fit(trainData, trainDataClasses, trainDataWeights)", "Failed to train classifier");
241
242 // Store classifier
// NOTE(review): PyDict_GetItemString returns a borrowed reference; fClassifier stays valid
// only while fLocalNS keeps its "classifier" entry.
243 fClassifier = PyDict_GetItemString(fLocalNS, "classifier");
244 if(fClassifier == 0) {
245 Log() << kFATAL << "Can't create classifier object from AdaBoostClassifier" << Endl;
246 Log() << Endl;
247 }
248
249 if (IsModelPersistence()) {
250 Log() << Endl;
251 Log() << gTools().Color("bold") << "Saving state file: " << gTools().Color("reset") << fFilenameClassifier << Endl;
252 Log() << Endl;
254 }
255}
256
257//_______________________________________________________________________
/// Override of TestClassification (declared virtual in the method base class).
/// NOTE(review): the signature line (source line 258) and the only body line (source
/// line 260) are missing from this listing, so the body looks empty here; presumably it
/// delegates to MethodBase::TestClassification() — confirm against the source.
259{
261}
262
263//_______________________________________________________________________
/// Compute the signal probability for the events [firstEvt, lastEvt) of the current data set.
/// lastEvt is clamped to the number of events (also when firstEvt > lastEvt) and a negative
/// firstEvt is raised to 0. All selected events are evaluated with a single predict_proba()
/// call on an nEvents x fNvars float32 array. With logProgress set, the start of the
/// evaluation and the elapsed time are logged.
264std::vector<Double_t> MethodPyAdaBoost::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
265{
266 // Load model if not already done
267 if (fClassifier == 0) ReadModelFromFile();
268
269 // Determine number of events
270 Long64_t nEvents = Data()->GetNEvents();
271 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
272 if (firstEvt < 0) firstEvt = 0;
273 nEvents = lastEvt-firstEvt;
274
275 // use timer
276 Timer timer( nEvents, GetName(), kTRUE );
277
278 if (logProgress)
279 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
280 << "Evaluation of " << GetMethodName() << " on "
281 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
282 << " sample (" << nEvents << " events)" << Endl;
283
284 // Get data
285 npy_intp dims[2];
286 dims[0] = nEvents;
287 dims[1] = fNvars;
288 PyArrayObject *pEvent= (PyArrayObject *)PyArray_SimpleNew(2, dims, NPY_FLOAT);
289 float *pValue = (float *)(PyArray_DATA(pEvent));
290
// NOTE(review): events are read from index ievt, not firstEvt + ievt, so for firstEvt > 0
// this evaluates events [0, nEvents) instead of the requested range. Int_t ievt is also
// narrower than the Long64_t nEvents it is compared with.
291 for (Int_t ievt=0; ievt<nEvents; ievt++) {
292 Data()->SetCurrentEvent(ievt);
293 const TMVA::Event *e = Data()->GetEvent();
294 for (UInt_t i = 0; i < fNvars; i++) {
295 pValue[ievt * fNvars + i] = e->GetValue(i);
296 }
297 }
298
299 // Get prediction from classifier
// NOTE(review): result is not checked for NULL; if predict_proba raises, PyArray_DATA
// dereferences a null pointer.
300 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
301 double *proba = (double *)(PyArray_DATA(result));
302
303 // Return signal probabilities
304 if(Long64_t(mvaValues.size()) != nEvents) mvaValues.resize(nEvents);
// NOTE(review): the loop body (source line 306, presumably assigning mvaValues[i] from
// proba) is missing from this listing.
305 for (int i = 0; i < nEvents; ++i) {
307 }
308
309 Py_DECREF(pEvent);
310 Py_DECREF(result);
311
312 if (logProgress) {
313 Log() << kINFO
314 << "Elapsed time for evaluation of " << nEvents << " events: "
315 << timer.GetElapsedTime() << " " << Endl;
316 }
317
318 return mvaValues;
319}
320
321//_______________________________________________________________________
323{
324 // cannot determine error
325 NoErrorCalc(errLower, errUpper);
326
327 // Load model if not already done
328 if (fClassifier == 0) ReadModelFromFile();
329
330 // Get current event and load to python array
331 const TMVA::Event *e = Data()->GetEvent();
332 npy_intp dims[2];
333 dims[0] = 1;
334 dims[1] = fNvars;
335 PyArrayObject *pEvent= (PyArrayObject *)PyArray_SimpleNew(2, dims, NPY_FLOAT);
336 float *pValue = (float *)(PyArray_DATA(pEvent));
337 for (UInt_t i = 0; i < fNvars; i++) pValue[i] = e->GetValue(i);
338
339 // Get prediction from classifier
340 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
341 double *proba = (double *)(PyArray_DATA(result));
342
343 // Return MVA value
344 Double_t mvaValue;
345 mvaValue = proba[TMVA::Types::kSignal]; // getting signal probability
346
347 Py_DECREF(result);
348 Py_DECREF(pEvent);
349
350 return mvaValue;
351}
352
353//_______________________________________________________________________
355{
356 // Load model if not already done
357 if (fClassifier == 0) ReadModelFromFile();
358
359 // Get current event and load to python array
360 const TMVA::Event *e = Data()->GetEvent();
361 npy_intp dims[2];
362 dims[0] = 1;
363 dims[1] = fNvars;
364 PyArrayObject *pEvent= (PyArrayObject *)PyArray_SimpleNew(2, dims, NPY_FLOAT);
365 float *pValue = (float *)(PyArray_DATA(pEvent));
366 for (UInt_t i = 0; i < fNvars; i++) pValue[i] = e->GetValue(i);
367
368 // Get prediction from classifier
369 PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
370 double *proba = (double *)(PyArray_DATA(result));
371
372 // Return MVA values
373 if(UInt_t(classValues.size()) != fNoutputs) classValues.resize(fNoutputs);
374 for(UInt_t i = 0; i < fNoutputs; i++) classValues[i] = proba[i];
375
376 return classValues;
377}
378
379//_______________________________________________________________________
/// Load the classifier from the state file fFilenameClassifier (per the log and error
/// messages), initialising the Python interpreter first if needed, and store it in
/// fLocalNS as "classifier". A non-zero error code from loading is reported with kFATAL.
/// NOTE(review): source lines 380 (signature), 391 (presumably `err` obtained from
/// UnSerialize(fFilenameClassifier, &fClassifier)) and 402-403 (the data-property reload
/// the NOTE below refers to) are missing from this listing — confirm against the source.
381{
382 if (!PyIsInitialized()) {
383 PyInitialize();
384 }
385
386 Log() << Endl;
387 Log() << gTools().Color("bold") << "Loading state file: " << gTools().Color("reset") << fFilenameClassifier << Endl;
388 Log() << Endl;
389
390 // Load classifier from file
392 if(err != 0)
393 {
394 Log() << kFATAL << Form("Failed to load classifier from file (error code: %i): %s", err, fFilenameClassifier.Data()) << Endl;
395 }
396
397 // Book classifier object in python dict
398 PyDict_SetItemString(fLocalNS, "classifier", fClassifier);
399
400 // Load data properties
401 // NOTE: This has to be repeated here for the reader application
404}
405
406//_______________________________________________________________________
408{
409 // Get feature importance from classifier as an array with length equal
410 // number of variables, higher value signals a higher importance
411 PyArrayObject* pRanking = (PyArrayObject*) PyObject_GetAttrString(fClassifier, "feature_importances_");
412 // The python object is null if the base estimator does not support
413 // variable ranking. Then, return NULL, which disables ranking.
414 if(pRanking == 0) return NULL;
415
416 // Fill ranking object and return it
417 fRanking = new Ranking(GetName(), "Variable Importance");
418 Double_t* rankingData = (Double_t*) PyArray_DATA(pRanking);
419 for(UInt_t iVar=0; iVar<fNvars; iVar++){
420 fRanking->AddRank(Rank(GetInputLabel(iVar), rankingData[iVar]));
421 }
422
423 Py_DECREF(pRanking);
424
425 return fRanking;
426}
427
428//_______________________________________________________________________
/// Print a short description of the AdaBoost classifier through the logger and refer
/// the user to the scikit-learn documentation.
/// NOTE(review): the signature line (source line 429) is missing from this listing.
430{
431 // typical length of text line:
432 // "|--------------------------------------------------------------|"
433 Log() << "An AdaBoost classifier is a meta-estimator that begins by fitting" << Endl;
434 Log() << "a classifier on the original dataset and then fits additional copies" << Endl;
435 Log() << "of the classifier on the same dataset but where the weights of incorrectly" << Endl;
436 Log() << "classified instances are adjusted such that subsequent classifiers focus" << Endl;
437 Log() << "more on difficult cases." << Endl;
438 Log() << Endl;
439 Log() << "Check out the scikit-learn documentation for more information." << Endl;
440}
#define REGISTER_METHOD(CLASS)
for example
_object PyObject
#define e(i)
Definition RSha256.hxx:103
unsigned int UInt_t
Definition RtypesCore.h:46
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
long long Long64_t
Definition RtypesCore.h:69
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
#define ClassImp(name)
Definition Rtypes.h:382
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNClasses() const
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
Long64_t GetNTrainingEvents() const
Definition DataSet.h:68
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
const Event * GetTrainingEvent(Long64_t ievt) const
Definition DataSet.h:74
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
const char * GetName() const
Definition MethodBase.h:334
Bool_t IsModelPersistence() const
Definition MethodBase.h:383
const TString & GetWeightFileDir() const
Definition MethodBase.h:492
const TString & GetMethodName() const
Definition MethodBase.h:331
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void TestClassification()
initialization
UInt_t GetNVariables() const
Definition MethodBase.h:345
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
const TString & GetInputLabel(Int_t i) const
Definition MethodBase.h:350
Ranking * fRanking
Definition MethodBase.h:587
DataSet * Data() const
Definition MethodBase.h:409
std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
std::vector< Double_t > mvaValues
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr)
const Ranking * CreateRanking()
std::vector< Float_t > classValues
MethodPyAdaBoost(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
virtual void TestClassification()
initialization
std::vector< Float_t > & GetMulticlassValues()
static int PyIsInitialized()
Check Python interpreter initialization status.
PyObject * Eval(TString code)
Evaluate Python code.
static void PyInitialize()
Initialize Python interpreter.
static void Serialize(TString file, PyObject *classifier)
Serialize Python object.
static Int_t UnSerialize(TString file, PyObject **obj)
Unserialize Python object.
PyObject * fClassifier
void PyRunString(TString code, TString errorMessage="Failed to run python code", int start=256)
Execute Python code from string.
Ranking for variables in method (implementation)
Definition Ranking.h:48
virtual void AddRank(const Rank &rank)
Add a new rank take ownership of it.
Definition Ranking.cxx:86
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition Timer.cxx:146
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kSignal
Never change this number - it is elsewhere assumed to be zero !
Definition Types.h:135
@ kMulticlass
Definition Types.h:129
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:376
Bool_t IsNull() const
Definition TString.h:414
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148