Logo ROOT   6.08/07
Reference Guide
PyMethodBase.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/pymva $Id$
2 // Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : PyMethodBase *
8  * *
9  * Description: *
10  * Virtual base class for all MVA method based on python *
11  * *
12  **********************************************************************************/
13 
14 #include <Python.h> // Needs to be included first to avoid redefinition of _POSIX_C_SOURCE
15 #include <TMVA/PyMethodBase.h>
16 
17 #pragma GCC diagnostic ignored "-Wunused-parameter"
18 #pragma GCC diagnostic ignored "-Wunused-function"
19 
20 #include "TMVA/DataSet.h"
21 #include "TMVA/DataSetInfo.h"
22 #include "TMVA/MsgLogger.h"
23 #include "TMVA/Results.h"
24 #include "TMVA/Timer.h"
25 
26 #include <TApplication.h>
27 
28 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
29 #include <numpy/arrayobject.h>
30 
31 #include <fstream>
32 #include <wchar.h>
33 
34 using namespace TMVA;
35 
37 
41 
45 
49 
50 class PyGILRAII {
51  PyGILState_STATE m_GILState;
52 public:
53  PyGILRAII():m_GILState(PyGILState_Ensure()){}
54  ~PyGILRAII(){PyGILState_Release(m_GILState);}
55 };
56 
57 //_______________________________________________________________________
58 PyMethodBase::PyMethodBase(const TString &jobName,
59  Types::EMVA methodType,
60  const TString &methodTitle,
61  DataSetInfo &dsi,
62  const TString &theOption ): MethodBase(jobName, methodType, methodTitle, dsi, theOption),
63  fClassifier(NULL)
64 {
65  if (!PyIsInitialized()) {
66  PyInitialize();
67  }
68 }
69 
70 //_______________________________________________________________________
72  DataSetInfo &dsi,
73  const TString &weightFile): MethodBase(methodType, dsi, weightFile),
75 {
   // Constructor taking a weight-file path, forwarded to
   // MethodBase(methodType, dsi, weightFile); presumably used when a trained
   // method is read back from file -- TODO confirm against callers.
   // NOTE(review): this constructor's first line and the initializer that
   // should follow the trailing comma above are not present in this listing.
   // Bring up the shared interpreter state on first use, like the other constructor.
76  if (!PyIsInitialized()) {
77  PyInitialize();
78  }
79 }
80 
81 //_______________________________________________________________________
83 {
   // Empty: no Python references are released per instance.
   // NOTE(review): fClassifier is not DECREF'd here -- confirm whether derived
   // classes own and release it.
84 }
85 
86 //_______________________________________________________________________
88 {
   // Evaluate `code` by calling the cached builtin eval (fEval) as
   // eval(code, fGlobalNS, fLocalNS). Returns whatever PyObject_CallObject
   // returns: a new reference owned by the caller, or NULL with a Python
   // error set when evaluation fails.
   // NOTE(review): Py_BuildValue may return NULL; Py_DECREF(pycode) below
   // would then dereference NULL. Py_XDECREF plus an explicit check is safer.
90  PyObject *pycode = Py_BuildValue("(sOO)", code.Data(), fGlobalNS, fLocalNS);
91  PyObject *result = PyObject_CallObject(fEval, pycode);
92  Py_DECREF(pycode);
93  return result;
94 }
95 
96 //_______________________________________________________________________
98 {
100 
101  bool pyIsInitialized = PyIsInitialized();
102  if (!pyIsInitialized) {
103  Py_Initialize();
104  }
105 
106  PyGILRAII thePyGILRAII;
107 
108  if (!pyIsInitialized) {
109  _import_array();
110  }
111 
112  fMain = PyImport_AddModule("__main__");
113  if (!fMain) {
114  Log << kFATAL << "Can't import __main__" << Endl;
115  Log << Endl;
116  }
117 
118  fGlobalNS = PyModule_GetDict(fMain);
119  if (!fGlobalNS) {
120  Log << kFATAL << "Can't init global namespace" << Endl;
121  Log << Endl;
122  }
123 
124  fLocalNS = PyDict_New();
125  if (!fMain) {
126  Log << kFATAL << "Can't init local namespace" << Endl;
127  Log << Endl;
128  }
129 
130  #if PY_MAJOR_VERSION < 3
131  //preparing objects for eval
132  PyObject *bName = PyUnicode_FromString("__builtin__");
133  // Import the file as a Python module.
134  fModuleBuiltin = PyImport_Import(bName);
135  if (!fModuleBuiltin) {
136  Log << kFATAL << "Can't import __builtin__" << Endl;
137  Log << Endl;
138  }
139  #else
140  //preparing objects for eval
141  PyObject *bName = PyUnicode_FromString("builtins");
142  // Import the file as a Python module.
143  fModuleBuiltin = PyImport_Import(bName);
144  if (!fModuleBuiltin) {
145  Log << kFATAL << "Can't import builtins" << Endl;
146  Log << Endl;
147  }
148  #endif
149 
150  PyObject *mDict = PyModule_GetDict(fModuleBuiltin);
151  fEval = PyDict_GetItemString(mDict, "eval");
152  fOpen = PyDict_GetItemString(mDict, "open");
153 
154  Py_DECREF(bName);
155  Py_DECREF(mDict);
156  //preparing objects for pickle
157  PyObject *pName = PyUnicode_FromString("pickle");
158  // Import the file as a Python module.
159  fModulePickle = PyImport_Import(pName);
160  if (!fModulePickle) {
161  Log << kFATAL << "Can't import pickle" << Endl;
162  Log << Endl;
163  }
164  PyObject *pDict = PyModule_GetDict(fModulePickle);
165  fPickleDumps = PyDict_GetItemString(pDict, "dump");
166  fPickleLoads = PyDict_GetItemString(pDict, "load");
167 
168  Py_DECREF(pName);
169  Py_DECREF(pDict);
170 
171 
172 }
173 
174 //_______________________________________________________________________
176 {
177  Py_Finalize();
178  if (fEval) Py_DECREF(fEval);
179  if (fModuleBuiltin) Py_DECREF(fModuleBuiltin);
180  if (fPickleDumps) Py_DECREF(fPickleDumps);
181  if (fPickleLoads) Py_DECREF(fPickleLoads);
182  if(fMain) Py_DECREF(fMain);//objects fGlobalNS and fLocalNS will be free here
183 }
185 {
186  #if PY_MAJOR_VERSION < 3
187  Py_SetProgramName(const_cast<char*>(name.Data()));
188  #else
189  Py_SetProgramName((wchar_t *)name.Data());
190  #endif
191 }
192 
/// Length of a NUL-terminated narrow string (used by Py_GetProgramName on Python 2).
size_t mystrlen(const char *str)
{
   return ::strlen(str);
}

/// Length of a NUL-terminated wide string (used by Py_GetProgramName on Python 3).
size_t mystrlen(const wchar_t *str)
{
   return ::wcslen(str);
}
195 
196 //_______________________________________________________________________
198 {
// Return the interpreter's program name as a string. ::Py_GetProgramName()
// yields char* or wchar_t* depending on the Python version; the matching
// mystrlen overload measures it. Building std::string from the range converts
// each character to char one by one, so only ASCII names round-trip exactly.
199 auto progName = ::Py_GetProgramName();
200 return std::string(progName, progName + mystrlen(progName));
201 }
202 //_______________________________________________________________________
204 {
   // Returns kTRUE only when the interpreter is running and the cached eval,
   // builtins module and pickle dump/load handles are all non-NULL.
   // NOTE(review): fOpen, fModulePickle, fMain, fGlobalNS and fLocalNS are not
   // checked here.
205  if (!Py_IsInitialized()) return kFALSE;
206  if (!fEval) return kFALSE;
207  if (!fModuleBuiltin) return kFALSE;
208  if (!fPickleDumps) return kFALSE;
209  if (!fPickleLoads) return kFALSE;
210  return kTRUE;
211 }
212 
213 void PyMethodBase::Serialize(TString path,PyObject *obj)
214 {
215  if(!PyIsInitialized()) PyInitialize();
216  PyObject *file_arg = Py_BuildValue("(ss)", path.Data(),"wb");
217  PyObject *file = PyObject_CallObject(fOpen,file_arg);
218  PyObject *model_arg = Py_BuildValue("(OO)", obj,file);
219  PyObject *model_data = PyObject_CallObject(fPickleDumps , model_arg);
220 
221  Py_DECREF(file_arg);
222  Py_DECREF(file);
223  Py_DECREF(model_arg);
224  Py_DECREF(model_data);
225 }
226 
227 void PyMethodBase::UnSerialize(TString path,PyObject **obj)
228 {
229  PyObject *file_arg = Py_BuildValue("(ss)", path.Data(),"rb");
230  PyObject *file = PyObject_CallObject(fOpen,file_arg);
231 
232  PyObject *model_arg = Py_BuildValue("(O)", file);
233  *obj = PyObject_CallObject(fPickleLoads , model_arg);
234 
235  Py_DECREF(file_arg);
236  Py_DECREF(file);
237  Py_DECREF(model_arg);
238 }
239 
240 
241 ////////////////////////////////////////////////////////////////////////////////
242 /// get all the MVA values for the events of the current Data type
243 std::vector<Double_t> PyMethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
244 {
245 
247 
249  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
250  if (firstEvt < 0) firstEvt = 0;
251  std::vector<Double_t> values(lastEvt-firstEvt);
252 
253  nEvents = values.size();
254 
255  UInt_t nvars = Data()->GetNVariables();
256 
257  int dims[2];
258  dims[0] = nEvents;
259  dims[1] = nvars;
260  PyArrayObject *pEvent= (PyArrayObject *)PyArray_FromDims(2, dims, NPY_FLOAT);
261  float *pValue = (float *)(PyArray_DATA(pEvent));
262 
263 // int dims2[2];
264 // dims2[0] = 1;
265 // dims2[1] = nvars;
266 
267  // use timer
268  Timer timer( nEvents, GetName(), kTRUE );
269  if (logProgress)
270  Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
271  << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
272 
273 
274  // fill numpy array with events data
275  for (Int_t ievt=0; ievt<nEvents; ievt++) {
276  Data()->SetCurrentEvent(ievt);
277  const TMVA::Event *e = Data()->GetEvent();
278  assert(nvars == e->GetNVariables());
279  for (UInt_t i = 0; i < nvars; i++) {
280  pValue[ievt * nvars + i] = e->GetValue(i);
281  }
282  // if (ievt%100 == 0)
283  // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
284  }
285 
286  // pass all the events to Scikit and evaluate the probabilities
287  PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
288  double *proba = (double *)(PyArray_DATA(result));
289 
290  // the return probabilities is a vector of pairs of (p_sig,p_backg)
291  // we ar einterested only in the signal probability
292  std::vector<double> mvaValues(nEvents);
293  for (int i = 0; i < nEvents; ++i)
294  mvaValues[i] = proba[2*i];
295 
296  if (logProgress) {
297  Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
298  << timer.GetElapsedTime() << " " << Endl;
299  }
300 
301  Py_DECREF(result);
302  Py_DECREF(pEvent);
303 
304  return mvaValues;
305 }
306 
307 //_______________________________________________________________________
308 // Helper function to run python code from string in local namespace with
309 // error handling
310 // `start` defines the start symbol defined in PyRun_String (Py_eval_input,
311 // Py_single_input, Py_file_input)
312 void PyMethodBase::PyRunString(TString code, TString errorMessage, int start) {
313  fPyReturn = PyRun_String(code, start, fGlobalNS, fLocalNS);
314  if (!fPyReturn) {
315  Log() << kWARNING << "Failed to run python code: " << code << Endl;
316  Log() << kWARNING << "Python error message:" << Endl;
317  PyErr_Print();
318  Log() << kFATAL << errorMessage << Endl;
319  }
320 }
PyMethodBase(const TString &jobName, Types::EMVA methodType, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:113
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
long long Long64_t
Definition: RtypesCore.h:69
PyObject * fClassifier
Definition: PyMethodBase.h:121
PyObject * fPyReturn
Definition: PyMethodBase.h:127
static PyObject * fModulePickle
Definition: PyMethodBase.h:138
MsgLogger & Log() const
Definition: Configurable.h:128
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition: DataSet.cxx:225
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
static void Serialize(TString file, PyObject *classifier)
TString GetElapsedTime(Bool_t Scientific=kTRUE)
Definition: Timer.cxx:129
static int PyIsInitialized()
static void PyInitialize()
void PyRunString(TString code, TString errorMessage="Failed to run python code", int start=Py_single_input)
TStopwatch timer
Definition: pirndm.C:37
virtual void ReadModelFromFile()=0
DataSet * Data() const
Definition: MethodBase.h:405
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:217
static PyObject * fEval
Definition: PyMethodBase.h:134
static PyObject * Eval(TString code)
DataSetInfo & DataInfo() const
Definition: MethodBase.h:406
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
const int nEvents
Definition: testRooFit.cxx:42
static void PyFinalize()
static TString Py_GetProgramName()
const char * GetName() const
Definition: MethodBase.h:330
static void PySetProgramName(TString name)
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
const TString & GetMethodName() const
Definition: MethodBase.h:327
static PyObject * fOpen
Definition: PyMethodBase.h:135
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:305
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:233
size_t mystrlen(const char *s)
static PyObject * fLocalNS
Definition: PyMethodBase.h:144
static PyObject * fModuleBuiltin
Definition: PyMethodBase.h:133
#define ClassImp(name)
Definition: Rtypes.h:279
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
Abstract ClassifierFactory template that handles arbitrary types.
static PyObject * fPickleLoads
Definition: PyMethodBase.h:140
static PyObject * fPickleDumps
Definition: PyMethodBase.h:139
Definition: file.py:1
virtual ~PyMethodBase()
#define NULL
Definition: Rtypes.h:82
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:229
double result[121]
static PyObject * fMain
Definition: PyMethodBase.h:142
static void UnSerialize(TString file, PyObject **obj)
const Bool_t kTRUE
Definition: Rtypes.h:91
const Event * GetEvent() const
Definition: DataSet.cxx:211
char name[80]
Definition: TGX11.cxx:109
static PyObject * fGlobalNS
Definition: PyMethodBase.h:143
_object PyObject
Definition: TPyArg.h:22