1// @(#)root/tmva/pymva $Id$
2// Authors: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer 2015, Stefan Wunsch 2017
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : PyMethodBase *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
 *      Virtual base class for all MVA methods based on Python                    *
12 * *
13 **********************************************************************************/
15#ifndef ROOT_TMVA_PyMethodBase
16#define ROOT_TMVA_PyMethodBase
19// //
20// PyMethodBase //
21// //
// Virtual base class for all TMVA methods based on Python               //
23// //
26#include "TMVA/MethodBase.h"
27#include "TMVA/Types.h"
29#include "Rtypes.h"
30#include "TString.h"
31#include <vector>
33class TFile;
34class TGraph;
35class TTree;
36class TDirectory;
37class TSpline;
38class TH1F;
39class TH1D;
// Minimal stand-ins for the CPython declarations this header needs, so it can
// be included without pulling in <Python.h>. PyObject_HEAD is a macro defined
// by the Python headers, so when they are already included the real
// declarations are used and this fallback is skipped.
#ifndef PyObject_HEAD
struct _object;
typedef _object PyObject;
// Mirrors CPython's start-symbol value for a single interactive statement;
// used as the default `start` argument of PyMethodBase::PyRunString.
#define Py_single_input 256
#endif // PyObject_HEAD
47namespace TMVA {
49 class Ranking;
50 class PDF;
51 class TSpline1;
52 class MethodCuts;
53 class MethodBoost;
54 class DataSetInfo;
56 /// Function to find current Python executable
57 /// used by ROOT
58 /// If Python2 is installed return "python"
59 /// Instead if "Python3" return "python3"
62 class PyMethodBase : public MethodBase {
64 friend class Factory;
65 public:
67 // default constructur
68 PyMethodBase(const TString &jobName,
69 Types::EMVA methodType,
70 const TString &methodTitle,
71 DataSetInfo &dsi,
72 const TString &theOption = "");
74 // constructor used for Testing + Application of the MVA, only (no training),
75 // using given weight file
76 PyMethodBase(Types::EMVA methodType,
77 DataSetInfo &dsi,
78 const TString &weightFile);
80 // default destructur
81 virtual ~PyMethodBase();
82 //basic python related function
83 static void PyInitialize();
84 static int PyIsInitialized();
85 static void PyFinalize();
86 static void PySetProgramName(TString name);
89 PyObject *Eval(TString code); // required to parse booking options from string to pyobjects
90 static void Serialize(TString file,PyObject *classifier);
93 virtual void Train() = 0;
94 // options treatment
95 virtual void Init() = 0;
96 virtual void DeclareOptions() = 0;
97 virtual void ProcessOptions() = 0;
98 // create ranking
99 virtual const Ranking *CreateRanking() = 0;
101 virtual Double_t GetMvaValue(Double_t *errLower = nullptr, Double_t *errUpper = nullptr) = 0;
103 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets) = 0;
104 protected:
105 // the actual "weights"
106 virtual void AddWeightsXMLTo(void *parent) const = 0;
107 virtual void ReadWeightsFromXML(void *wghtnode) = 0;
108 virtual void ReadWeightsFromStream(std::istream &) = 0; // backward compatibility
109 virtual void ReadWeightsFromStream(TFile &) {} // backward compatibility
111 virtual void ReadModelFromFile() = 0;
113 // signal/background classification response for all current set of data
114 virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false) = 0;
116 protected:
117 PyObject *fModule; // Module to load
118 PyObject *fClassifier; // Classifier object
120 PyObject *fPyReturn; // python return data
122 protected:
123 void PyRunString(TString code, TString errorMessage="Failed to run python code", int start=Py_single_input); // runs python code from string in local namespace with error handling
125 private:
127 static PyObject *fEval; // eval funtion from python
128 static PyObject *fOpen; // open function for files
130 protected:
131 static PyObject *fModulePickle; // Module for model persistence
132 static PyObject *fPickleDumps; // Function to dumps PyObject information into string
133 static PyObject *fPickleLoads; // Function to load PyObject information from string
135 static PyObject *fMain; // module __main__ to get namespace local and global
136 static PyObject *fGlobalNS; // global namesapace
137 PyObject *fLocalNS; // local namesapace
139 public:
140 static void PyRunString(TString code, PyObject *globalNS, PyObject* localNS); // Overloaded static Python utlity function for running Python code
141 static const char* PyStringAsString(PyObject *string); // Python Utility function for converting a Python String object to const char*
142 static std::vector<size_t> GetDataFromTuple(PyObject *tupleObject); // Function casts Python Tuple object into vector of size_t
143 static std::vector<size_t> GetDataFromList(PyObject *listObject); // Function casts Python List object into vector of size_t
144 static PyObject* GetValueFromDict(PyObject* dict, const char* key); // Function to check for a key in dict and return the associated value if present
145 ClassDef(PyMethodBase, 0) // Virtual base class for all TMVA method
147 };
149} // namespace TMVA
