Logo ROOT  
Reference Guide
RModelParser_Keras.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Author: Sanjiban Sengupta 2021
3
4/**********************************************************************************
5 * Project : TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package : TMVA *
7 * Function: TMVA::Experimental::SOFIE::PyKeras::Parse *
8 * *
9 * Description: *
10 * Parser function for translating Keras .h5 model to RModel object *
11 * *
12 * Example Usage: *
13 * ~~~ {.cpp} *
14 * using namespace TMVA::Experimental::SOFIE; *
15 * RModel model = PyKeras::Parse("trained_model_dense.h5"); *
16 * ~~~ *
17 * *
18 **********************************************************************************/
19
21
22#include <Python.h>
23
24#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
25#include <numpy/arrayobject.h>
26
27namespace TMVA{
28namespace Experimental{
29namespace SOFIE{
30namespace PyKeras{
31
32// Referencing Python utility functions present in PyMethodBase
35static std::vector<size_t>(& GetDataFromTuple)(PyObject*) = PyMethodBase::GetDataFromTuple;
36
37namespace INTERNAL{
38
39// For adding Keras layer into RModel object
40void AddKerasLayer(RModel &rmodel, PyObject *fLayer);
41
42// Declaring Internal Functions for Keras layers which don't have activation as an additional attribute
43std::unique_ptr<ROperator> MakeKerasActivation(PyObject *fLayer); // For instantiating ROperator for Keras Activation Layer
44std::unique_ptr<ROperator> MakeKerasReLU(PyObject *fLayer); // For instantiating ROperator for Keras ReLU layer
45std::unique_ptr<ROperator> MakeKerasSelu(PyObject *fLayer); // For instantiating ROperator for Keras Selu layer
46std::unique_ptr<ROperator> MakeKerasSigmoid(PyObject *fLayer); // For instantiating ROperator for Keras Sigmoid layer
47std::unique_ptr<ROperator> MakeKerasPermute(PyObject *fLayer); // For instantiating ROperator for Keras Permute Layer
48
49// Declaring Internal function for Keras layers which have additional activation attribute
50std::unique_ptr<ROperator> MakeKerasDense(PyObject *fLayer); // For instantiating ROperator for Keras Dense Layer
51
52// For mapping Keras layer with the preparatory functions for ROperators
53using KerasMethodMap = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(PyObject *fLayer)>;
54using KerasMethodMapWithActivation = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(PyObject *fLayer)>;
55
57 {"Activation", &MakeKerasActivation},
58 {"Permute", &MakeKerasPermute},
59
60 // For activation layers
61 {"ReLU", &MakeKerasReLU},
62
63 // For layers with activation attributes
64 {"relu", &MakeKerasReLU},
65 {"selu", &MakeKerasSelu},
66 {"sigmoid", &MakeKerasSigmoid}
67};
68
70 {"Dense", &MakeKerasDense},
71 };
72
73
74//////////////////////////////////////////////////////////////////////////////////
75/// \brief Adds equivalent ROperator with respect to Keras model layer
76/// into the referenced RModel object
77///
78/// \param[inout] rmodel RModel object, by reference, returned ith the added ROperator
79/// \param[in] fLayer Python Keras layer as a Dictionary object
80///
81/// Function adds equivalent ROperator into the referenced RModel object.
82/// Keras models can have layers like Dense and Conv which have activation
83/// function as an attribute. Function first searches if layer object is among
84/// the ones which don't have activation attribute and then calls the respective
85/// preparation function to get the ROperator object, which is then added
86/// into the RModel object. If passed layer is among the ones which may have activation
87/// attribute, then it checks for the activation attribute, if present then first adds
88/// the primary operator into the RModel object, and then adds the operator for the
89/// activation function with appropriate changes in the names of input and output
90/// tensors for both of them.
91/// Example of such layers is the Dense Layer. For a dense layer with input tensor name
92/// dense2BiasAdd0 and output tensor name dense3Relu0 with relu as activation attribute
93/// will be transformed into a ROperator_Gemm with input tensor name dense2BiasAdd0
94/// & output tensor name dense3Dense (layerName+layerType), and a subsequent
95/// ROperator_Relu with input tensor name as dense3Dense and output tensor name
96/// as dense3Relu0.
97///
98/// For developing new preparatory functions for supporting Keras layers in future,
99/// all one needs is to extract the required properties and attributes from the fLayer
100/// dictionary which contains all the information about any Keras layer and after
101/// any required transformations, these are passed for instantiating the ROperator
102/// object.
103///
104/// The fLayer dictionary which holds all the information about a Keras layer has
105/// following structure:-
106///
107/// dict fLayer { 'layerType' : Type of the Keras layer
108/// 'layerAttributes' : Attributes of the keras layer as returned by layer.get_config()
109/// 'layerInput' : List of names of input tensors
110/// 'layerOutput' : List of names of output tensors
111/// 'layerDType' : Data-type of the Keras layer
112/// 'layerWeight' : List of weight tensor names of Keras layers
113/// }
114void AddKerasLayer(RModel& rmodel, PyObject* fLayer){
115 std::string fLayerType = PyStringAsString(PyDict_GetItemString(fLayer,"layerType"));
116
117 //For layers without additional activation attribute
118 auto findLayer = mapKerasLayer.find(fLayerType);
119 if(findLayer != mapKerasLayer.end()){
120 rmodel.AddOperator((findLayer->second)(fLayer));
121 return;
122 }
123
124 //For layers like Dense & Conv which has additional activation attribute
125 else if(mapKerasLayerWithActivation.find(fLayerType) != mapKerasLayerWithActivation.end()){
126 findLayer = mapKerasLayerWithActivation.find(fLayerType);
127 PyObject* fAttributes=PyDict_GetItemString(fLayer,"layerAttributes");
128
129 std::string fLayerName = PyStringAsString(PyDict_GetItemString(fAttributes,"name"));
130 std::string fLayerActivation = PyStringAsString(PyDict_GetItemString(fAttributes,"activation"));
131
132 if(fLayerActivation == "selu" || fLayerActivation == "sigmoid")
133 rmodel.AddNeededStdLib("cmath");
134
135 //Checking if additional attribute exixts
136 if(fLayerActivation != "linear"){
137 PyObject* fOutputs = PyDict_GetItemString(fLayer,"layerOutput");
138 PyObject* fInputs = PyDict_GetItemString(fLayer,"layerInput");
139 std::string fActivationLayerOutput = PyStringAsString(PyList_GetItem(fOutputs,0));
140
141 // Making changes in the names of the input and output tensor names
142 PyList_SetItem(fOutputs,0,PyUnicode_FromString((fLayerName+fLayerType).c_str()));
143 PyDict_SetItemString(fLayer,"layerOutput",fOutputs);
144 rmodel.AddOperator((findLayer->second)(fLayer));
145
146 std::string fActivationLayerInput = PyStringAsString(PyList_GetItem(fOutputs,0));
147 PyList_SetItem(fInputs,0,PyUnicode_FromString(fActivationLayerInput.c_str()));
148 PyList_SetItem(fOutputs,0,PyUnicode_FromString(fActivationLayerOutput.c_str()));
149 PyDict_SetItemString(fLayer,"layerInput",fInputs);
150 PyDict_SetItemString(fLayer,"layerOutput",fOutputs);
151
152 auto findActivationLayer = mapKerasLayer.find(fLayerActivation);
153 if(findActivationLayer == mapKerasLayer.end()){
154 throw std::runtime_error("TMVA::SOFIE - Parsing Keras Activation layer " + fLayerActivation + " is not yet supported");
155 }
156 rmodel.AddOperator((findActivationLayer->second)(fLayer));
157 }
158 else{
159 rmodel.AddOperator((findLayer->second)(fLayer));
160 }
161 return;
162 }
163
164 else{
165 throw std::runtime_error("TMVA::SOFIE - Parsing Keras layer " + fLayerType + " is not yet supported");
166 }
167
168}
169
170//////////////////////////////////////////////////////////////////////////////////
171/// \brief Prepares a ROperator object for Keras Dense Layer
172///
173/// \param[in] fLayer Python Keras layer as a Dictionary object
174/// \return Unique pointer to ROperator object
175///
176/// For Keras's Dense layer, the names of the input tensor, output tensor, and
177/// weight tensors are extracted, and then are passed to instantiate a
178/// ROperator_Gemm object using the required attributes.
179std::unique_ptr<ROperator> MakeKerasDense(PyObject* fLayer){
180 PyObject* fInputs = PyDict_GetItemString(fLayer,"layerInput");
181 PyObject* fOutputs = PyDict_GetItemString(fLayer,"layerOutput");
182 std::string fLayerDType = PyStringAsString(PyDict_GetItemString(fLayer,"layerDType"));
183
184 std::string fLayerInputName = PyStringAsString(PyList_GetItem(fInputs,0));
185 std::string fLayerOutputName = PyStringAsString(PyList_GetItem(fOutputs,0));
186
187 // Extracting names of weight tensors
188 // The names of Kernel weights and bias weights are found in the list
189 // of weight tensors from fLayer.
190 PyObject* fWeightNames = PyDict_GetItemString(fLayer,"layerWeight");
191 std::string fKernelName = PyStringAsString(PyList_GetItem(fWeightNames,0));
192 std::string fBiasName = PyStringAsString(PyList_GetItem(fWeightNames,1));
193
194 std::unique_ptr<ROperator> op;
195
196 float attr_alpha = 1.0;
197 float attr_beta = 1.0;
198 int_t attr_transA = 0;
199 int_t attr_transB = 0;
200
201 switch(ConvertStringToType(fLayerDType)){
203 op.reset(new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, fLayerInputName, fKernelName, fBiasName, fLayerOutputName));
204 break;
205
206 default:
207 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Gemm does not yet support input type " + fLayerDType);
208 }
209 return op;
210}
211
212
213//////////////////////////////////////////////////////////////////////////////////
214/// \brief Prepares a ROperator object for Keras activation layer
215///
216/// \param[in] fLayer Python Keras layer as a Dictionary object
217/// \return Unique pointer to ROperator object
218///
219/// For Keras's keras.layers.Activation layer, the activation attribute is
220/// extracted and appropriate function for adding the function is called.
221std::unique_ptr<ROperator> MakeKerasActivation(PyObject* fLayer){
222 PyObject* fAttributes=PyDict_GetItemString(fLayer,"layerAttributes");
223 std::string fLayerActivation = PyStringAsString(PyDict_GetItemString(fAttributes,"activation"));
224
225 auto findLayer = mapKerasLayer.find(fLayerActivation);
226 if(findLayer == mapKerasLayer.end()){
227 throw std::runtime_error("TMVA::SOFIE - Parsing Keras Activation layer " + fLayerActivation + " is not yet supported");
228 }
229 return (findLayer->second)(fLayer);
230}
231
232
233//////////////////////////////////////////////////////////////////////////////////
234/// \brief Prepares a ROperator object for Keras ReLU activation
235///
236/// \param[in] fLayer Python Keras layer as a Dictionary object
237/// \return Unique pointer to ROperator object
238///
239/// For instantiating a ROperator_Relu object, the names of
240/// input & output tensors and the deta-type of the layer are extracted.
241std::unique_ptr<ROperator> MakeKerasReLU(PyObject* fLayer)
242{
243 PyObject* fInputs=PyDict_GetItemString(fLayer,"layerInput");
244 PyObject* fOutputs=PyDict_GetItemString(fLayer,"layerOutput");
245
246 std::string fLayerDType = PyStringAsString(PyDict_GetItemString(fLayer,"layerDType"));
247 std::string fLayerInputName = PyStringAsString(PyList_GetItem(fInputs,0));
248 std::string fLayerOutputName = PyStringAsString(PyList_GetItem(fOutputs,0));
249
250 std::unique_ptr<ROperator> op;
251 switch(ConvertStringToType(fLayerDType)){
253 op.reset(new ROperator_Relu<float>(fLayerInputName, fLayerOutputName));
254 break;
255 default:
256 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Relu does not yet support input type " + fLayerDType);
257 }
258 return op;
259}
260
261
262//////////////////////////////////////////////////////////////////////////////////
263/// \brief Prepares a ROperator object for Keras Selu activation
264///
265/// \param[in] fLayer Python Keras layer as a Dictionary object
266/// \return Unique pointer to ROperator object
267///
268/// For instantiating a ROperator_Selu object, the names of
269/// input & output tensors and the deta-type of the layer are extracted.
270std::unique_ptr<ROperator> MakeKerasSelu(PyObject* fLayer){
271 PyObject* fInputs = PyDict_GetItemString(fLayer,"layerInput");
272 PyObject* fOutputs = PyDict_GetItemString(fLayer,"layerOutput");
273
274 std::string fLayerDType = PyStringAsString(PyDict_GetItemString(fLayer,"layerDType"));
275 std::string fLayerInputName = PyStringAsString(PyList_GetItem(fInputs,0));
276 std::string fLayerOutputName = PyStringAsString(PyList_GetItem(fOutputs,0));
277
278 std::unique_ptr<ROperator> op;
279 switch(ConvertStringToType(fLayerDType)){
281 op.reset(new ROperator_Selu<float>(fLayerInputName, fLayerOutputName));
282 break;
283 default:
284 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Selu does not yet support input type " + fLayerDType);
285 }
286 return op;
287}
288
289
290//////////////////////////////////////////////////////////////////////////////////
291/// \brief Prepares a ROperator object for Keras Sigmoid activation
292///
293/// \param[in] fLayer Python Keras layer as a Dictionary object
294/// \return Unique pointer to ROperator object
295///
296/// For instantiating a ROperator_Sigmoid object, the names of
297/// input & output tensors and the deta-type of the layer are extracted.
298std::unique_ptr<ROperator> MakeKerasSigmoid(PyObject* fLayer){
299 PyObject* fInputs = PyDict_GetItemString(fLayer,"layerInput");
300 PyObject* fOutputs = PyDict_GetItemString(fLayer,"layerOutput");
301
302 std::string fLayerDType = PyStringAsString(PyDict_GetItemString(fLayer,"layerDType"));
303 std::string fLayerInputName = PyStringAsString(PyList_GetItem(fInputs,0));
304 std::string fLayerOutputName = PyStringAsString(PyList_GetItem(fOutputs,0));
305
306 std::unique_ptr<ROperator> op;
307 switch(ConvertStringToType(fLayerDType)){
309 op.reset(new ROperator_Sigmoid<float>(fLayerInputName, fLayerOutputName));
310 break;
311 default:
312 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Sigmoid does not yet support input type " + fLayerDType);
313 }
314 return op;
315}
316
317
318//////////////////////////////////////////////////////////////////////////////////
319/// \brief Prepares a ROperator object for Keras Permute layer
320///
321/// \param[in] fLayer Python Keras layer as a Dictionary object
322/// \return Unique pointer to ROperator object
323///
324/// The Permute layer in Keras has an equivalent Tranpose operator in ONNX.
325/// For adding a Transpose operator, the permute dimensions are found, if they
326/// exist are passed in instantiating the ROperator, else default values are used.
327std::unique_ptr<ROperator> MakeKerasPermute(PyObject* fLayer)
328{
329 // Extracting required layer information
330 PyObject* fAttributes=PyDict_GetItemString(fLayer,"layerAttributes");
331 PyObject* fInputs=PyDict_GetItemString(fLayer,"layerInput");
332 PyObject* fOutputs=PyDict_GetItemString(fLayer,"layerOutput");
333
334 std::string fLayerDType = PyStringAsString(PyDict_GetItemString(fLayer,"layerDType"));
335 std::string fLayerInputName = PyStringAsString(PyList_GetItem(fInputs,0));
336 std::string fLayerOutputName = PyStringAsString(PyList_GetItem(fOutputs,0));
337
338 // Extracting the permute dimensions present in Attributes of the Keras layer
339 PyObject* fAttributePermute=PyDict_GetItemString(fAttributes,"dims");
340 std::vector<int_t>fPermuteDims;
341
342 // Building vector of permute dimensions from the Tuple object.
343 for(Py_ssize_t tupleIter=0;tupleIter<PyTuple_Size(fAttributePermute);++tupleIter){
344
345 fPermuteDims.push_back((int_t)PyLong_AsLong(PyTuple_GetItem(fAttributePermute,tupleIter)));
346 }
347 std::unique_ptr<ROperator> op;
348 switch(ConvertStringToType(fLayerDType)){
349 case ETensorType::FLOAT:{
350
351 // Adding the permute dimensions if present, else are avoided to use default values.
352 if (!fPermuteDims.empty()){
353 op.reset(new ROperator_Transpose<float>(fPermuteDims, fLayerInputName, fLayerOutputName));
354 }
355 else{
356 op.reset(new ROperator_Transpose<float> (fLayerInputName, fLayerOutputName));
357 }
358 break;
359 }
360 default:
361 throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Transpose does not yet support input type " + fLayerDType);
362 }
363 return op;
364 }
365
366}//INTERNAL
367
368
369//////////////////////////////////////////////////////////////////////////////////
370/// \param[in] filename file location of Keras .h5
371/// \return Parsed RModel object
372///
373/// The `Parse()` function defined in `TMVA::Experimental::SOFIE::PyKeras` will
374/// parse a trained Keras .h5 model into a RModel Object. After loading the model
375/// in a Python Session, the included layers are extracted with properties
376/// like Layer type, Attributes, Input tensor names, Output tensor names, data-type
377/// and names of the weight/initialized tensors.
378/// The extracted layers from the model are then passed into `AddKerasLayer()`
379/// which prepares the specific ROperator and adds them into the RModel object.
380/// The layers are also checked for adding any required routines for executing
381/// the generated Inference code.
382///
383/// For adding the Initialized tensors into the RModel object, the weights are
384/// extracted from the Keras model in the form of NumPy arrays, which are then
385/// passed into `AddInitializedTensor()` after appropriate casting.
386///
387/// Input tensor infos are required to be added which will contain their names,
388/// shapes and data-types. For keras models with single input tensors, the tensor
389/// shape is returned as a Tuple object, whereas for multi-input models,
390/// the tensor shape is returned as a List of Tuple object containing the shape
391/// of the individual input tensors. SOFIE's RModel also requires that the Keras
392/// models are initialized with Batch Size. The `GetDataFromTuple()` are called
393/// on the Tuple objects, which then returns the shape vector required to call
394/// the `AddInputTensorInfo()`.
395///
396/// For adding the Output Tensor infos, only the names of the model's output
397/// tensors are extracted and are then passed into `AddOutputTensorNameList()`.
398///
399/// Example Usage:
400/// ~~~ {.cpp}
401/// using TMVA::Experimental::SOFIE;
402/// RModel model = PyKeras::Parse("trained_model_dense.h5");
403/// ~~~
404RModel Parse(std::string filename){
405
406 char sep = '/';
407 #ifdef _WIN32
408 sep = '\\';
409 #endif
410
411 size_t isep = filename.rfind(sep, filename.length());
412 std::string filename_nodir = filename;
413 if (isep != std::string::npos){
414 filename_nodir = (filename.substr(isep+1, filename.length() - isep));
415 }
416
417 //Check on whether the Keras .h5 file exists
418 if(!std::ifstream(filename).good()){
419 throw std::runtime_error("Model file "+filename_nodir+" not found!");
420 }
421
422
423 std::time_t ttime = std::time(0);
424 std::tm* gmt_time = std::gmtime(&ttime);
425 std::string parsetime (std::asctime(gmt_time));
426
427 RModel rmodel(filename_nodir, parsetime);
428
429 //Intializing Python Interpreter and scope dictionaries
430 Py_Initialize();
431 PyObject* main = PyImport_AddModule("__main__");
432 PyObject* fGlobalNS = PyModule_GetDict(main);
433 PyObject* fLocalNS = PyDict_New();
434 if (!fGlobalNS) {
435 throw std::runtime_error("Can't init global namespace for Python");
436 }
437 if (!fLocalNS) {
438 throw std::runtime_error("Can't init local namespace for Python");
439 }
440
441 // Extracting model information
442 // For each layer: type,name,activation,dtype,input tensor's name,
443 // output tensor's name, kernel's name, bias's name
444 // None object is returned for if property doesn't belong to layer
445 PyRunString("import tensorflow.keras as keras",fGlobalNS,fLocalNS);
446 PyRunString("from tensorflow.keras.models import load_model",fGlobalNS,fLocalNS);
447 PyRunString("print('Keras Version: '+ keras.__version__)",fGlobalNS,fLocalNS);
448 PyRunString(TString::Format("model=load_model('%s')",filename.c_str()),fGlobalNS,fLocalNS);
449 PyRunString(TString::Format("model.load_weights('%s')",filename.c_str()),fGlobalNS,fLocalNS);
450 PyRunString("globals().update(locals())",fGlobalNS,fLocalNS);
451 PyRunString("modelData=[]",fGlobalNS,fLocalNS);
452 PyRunString("for idx in range(len(model.layers)):\n"
453 " layer=model.get_layer(index=idx)\n"
454 " globals().update(locals())\n"
455 " layerData={}\n"
456 " layerData['layerType']=layer.__class__.__name__\n"
457 " layerData['layerAttributes']=layer.get_config()\n"
458 " layerData['layerInput']=[x.name for x in layer.input] if isinstance(layer.input,list) else [layer.input.name]\n"
459 " layerData['layerOutput']=[x.name for x in layer.output] if isinstance(layer.output,list) else [layer.output.name]\n"
460 " layerData['layerDType']=layer.dtype\n"
461 " layerData['layerWeight']=[x.name for x in layer.weights]\n"
462 " modelData.append(layerData)",fGlobalNS,fLocalNS);
463
464
465 PyObject* fPModel = PyDict_GetItemString(fLocalNS,"modelData");
466 PyObject *fLayer;
467 Py_ssize_t fModelSize = PyList_Size(fPModel);
468 std::string fLayerType;
469
470 // Traversing through all the layers and passing the Layer object to `AddKerasLayer()`
471 // for adding the equivalent ROperators into the RModel object.
472 for(Py_ssize_t fModelIterator=0;fModelIterator<fModelSize;++fModelIterator){
473 fLayer = PyList_GetItem(fPModel,fModelIterator);
474 fLayerType = PyStringAsString(PyDict_GetItemString(fLayer,"layerType"));
475
476 // Ignoring the input layer for models built using Keras Functional API
477 if(fLayerType == "InputLayer")
478 continue;
479
480 // Adding any required routines depending on the Layer types for generating
481 // inference code.
482 else if(fLayerType == "Dense")
483 rmodel.AddBlasRoutines({"Gemm", "Gemv"});
484 INTERNAL::AddKerasLayer(rmodel,fLayer);
485
486 }
487
488 //Extracting model's weights
489 //For every initialized tensor, weightProp will have its name and dtype in string
490 //and value in numpy array
491 PyRunString("globals().update(locals())",fGlobalNS,fLocalNS);
492 PyRunString("weight=[]",fGlobalNS,fLocalNS);
493 PyRunString("for idx in range(len(model.get_weights())):\n"
494 " weightProp={}\n"
495 " weightProp['name']=model.weights[idx].name\n"
496 " weightProp['dtype']=(model.get_weights())[idx].dtype.name\n"
497 " weightProp['value']=(model.get_weights())[idx]\n"
498 " weight.append(weightProp)",fGlobalNS,fLocalNS);
499
500 PyObject *fWeightTensor, *fPWeight;
501 PyArrayObject *fWeightTensorValue;
502 std::string fWeightName;
503 ETensorType fWeightDType;
504 fPWeight = PyDict_GetItemString(fLocalNS,"weight");
505 std::vector<std::size_t> fWeightTensorShape;
506 std::size_t fWeightTensorSize;
507
508 // Traversing through all the Weight tensors
509 for (Py_ssize_t weightIter = 0; weightIter < PyList_Size(fPWeight); weightIter++){
510 fWeightTensor = PyList_GetItem(fPWeight, weightIter);
511 fWeightName = PyStringAsString(PyDict_GetItemString(fWeightTensor,"name"));
512 fWeightDType = ConvertStringToType(PyStringAsString(PyDict_GetItemString(fWeightTensor,"dtype")));
513
514 fWeightTensorValue = (PyArrayObject*)PyDict_GetItemString(fWeightTensor,"value");
515 fWeightTensorSize=1;
516 fWeightTensorShape.clear();
517
518 // Building the shape vector and finding the tensor size
519 for(int j=0; j<PyArray_NDIM(fWeightTensorValue); ++j){
520 fWeightTensorShape.push_back((std::size_t)(PyArray_DIM(fWeightTensorValue,j)));
521 fWeightTensorSize*=(std::size_t)(PyArray_DIM(fWeightTensorValue,j));
522 }
523
524 switch(fWeightDType){
525 case ETensorType::FLOAT : {
526 float* fWeightArray = (float*)PyArray_DATA(fWeightTensorValue);
527 std::shared_ptr<void> fData(malloc(fWeightTensorSize * sizeof(float)), free);
528 std::memcpy(fData.get(),fWeightArray, fWeightTensorSize * sizeof(float));
529 rmodel.AddInitializedTensor(fWeightName,ETensorType::FLOAT,fWeightTensorShape,fData);
530 break;
531 }
532 default:
533 throw std::runtime_error("Type error: TMVA SOFIE does not yet weight data layer type"+ConvertTypeToString(fWeightDType));
534 }
535 }
536
537
538 // Extracting input tensor info
539 // For every input tensor inputNames will have their names as string,
540 // inputShapes will have their shape as Python Tuple, and inputTypes
541 // will have their dtype as string
542 PyRunString("inputNames=model.input_names",fGlobalNS,fLocalNS);
543 PyRunString("inputShapes=model.input_shape",fGlobalNS,fLocalNS);
544 PyRunString("inputTypes=[]",fGlobalNS,fLocalNS);
545 PyRunString("for idx in range(len(model.inputs)):\n"
546 " inputTypes.append(model.inputs[idx].dtype.__str__()[9:-2])",fGlobalNS,fLocalNS);
547
548 PyObject* fPInputs = PyDict_GetItemString(fLocalNS,"inputNames");
549 PyObject* fPInputShapes = PyDict_GetItemString(fLocalNS,"inputShapes");
550 PyObject* fPInputTypes = PyDict_GetItemString(fLocalNS,"inputTypes");
551
552 std::string fInputName;
553 ETensorType fInputDType;
554
555 // For single input models, the model.input_shape will return a tuple
556 // describing the input tensor shape. For multiple inputs models,
557 // the model.input_shape will return a list of tuple, each describing
558 // the input tensor shape.
559 if(PyTuple_Check(fPInputShapes)){
560 fInputName = PyStringAsString(PyList_GetItem(fPInputs,0));
561 fInputDType = ConvertStringToType(PyStringAsString(PyList_GetItem(fPInputTypes,0)));
562
563 switch(fInputDType){
564
565 case ETensorType::FLOAT : {
566 if (PyTuple_GetItem(fPInputShapes,0) == Py_None){
567 throw std::runtime_error("None error: Models not initialized with batch-size are not yet supported in TMVA SOFIE");
568 }
569
570 // Getting the shape vector from the Tuple object
571 std::vector<size_t>fInputShape = GetDataFromTuple(fPInputShapes);
572 rmodel.AddInputTensorInfo(fInputName, ETensorType::FLOAT, fInputShape);
573 break;
574 }
575
576 default:
577 throw std::runtime_error("Type error: TMVA SOFIE does not yet suppport data type"+ConvertTypeToString(fInputDType));
578 }
579
580 }
581
582 else{
583
584 // Iterating through multiple input tensors
585 for(Py_ssize_t inputIter = 0; inputIter < PyList_Size(fPInputs);++inputIter){
586
587 fInputName = PyStringAsString(PyList_GetItem(fPInputs,inputIter));
588 fInputDType = ConvertStringToType(PyStringAsString(PyList_GetItem(fPInputTypes,inputIter)));
589
590 switch(fInputDType){
591 case ETensorType::FLOAT : {
592 PyObject* fInputShapeTuple=PyList_GetItem(fPInputShapes,inputIter);
593
594 if (PyTuple_GetItem(fInputShapeTuple,0) == Py_None){
595 throw std::runtime_error("None error: Models not initialized with batch-size are not yet supported in TMVA SOFIE");
596 }
597
598 std::vector<size_t>fInputShape = GetDataFromTuple(fInputShapeTuple);
599 rmodel.AddInputTensorInfo(fInputName, ETensorType::FLOAT, fInputShape);
600 break;
601 }
602
603 default:
604 throw std::runtime_error("Type error: TMVA SOFIE does not yet suppport data type"+ConvertTypeToString(fInputDType));
605
606 }
607 }
608 }
609
610 // For adding OutputTensorInfos, the names of the output
611 // tensors are extracted from the Keras model
612 PyRunString("outputNames=[]",fGlobalNS,fLocalNS);
613 PyRunString("for layerName in model.output_names:\n"
614 " outputNames.append(model.get_layer(layerName).output.name)",fGlobalNS,fLocalNS);
615 PyObject* fPOutputs = PyDict_GetItemString(fLocalNS,"outputNames");
616 std::vector<std::string> fOutputNames;
617 for(Py_ssize_t outputIter = 0; outputIter < PyList_Size(fPOutputs);++outputIter){
618 fOutputNames.push_back(PyStringAsString(PyList_GetItem(fPOutputs,outputIter)));
619 }
620 rmodel.AddOutputTensorNameList(fOutputNames);
621
622 return rmodel;
623}
624}//PyKeras
625}//SOFIE
626}//Experimental
627}//TMVA
typedef void(GLAPIENTRYP _GLUfuncptr)(void)
_object PyObject
Definition: PyMethodBase.h:42
int main(int argc, char *argv[])
Definition: cef_main.cxx:54
#define free
Definition: civetweb.c:1539
#define malloc
Definition: civetweb.c:1536
void AddOutputTensorNameList(std::vector< std::string > outputtensornames)
Definition: RModel.cxx:143
void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector< Dim > shape)
Definition: RModel.cxx:96
void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector< std::size_t > shape, std::shared_ptr< void > data)
Definition: RModel.cxx:123
void AddBlasRoutines(std::vector< std::string > routines)
Definition: RModel.hxx:72
void AddNeededStdLib(std::string libname)
Definition: RModel.hxx:77
void AddOperator(std::unique_ptr< ROperator > op, int order_execution=-1)
Definition: RModel.cxx:115
static std::vector< size_t > GetDataFromTuple(PyObject *tupleObject)
Utility function which retrieves and returns the values of the Tuple object as a vector of size_t.
static const char * PyStringAsString(PyObject *string)
Returns const char* from Python string in PyObject.
void PyRunString(TString code, TString errorMessage="Failed to run python code", int start=Py_single_input)
Execute Python code from string.
Basic string class.
Definition: TString.h:136
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2336
std::unique_ptr< ROperator > MakeKerasPermute(PyObject *fLayer)
Prepares a ROperator object for Keras Permute layer.
std::unordered_map< std::string, std::unique_ptr< ROperator >(*)(PyObject *fLayer)> KerasMethodMap
void AddKerasLayer(RModel &rmodel, PyObject *fLayer)
Adds equivalent ROperator with respect to Keras model layer into the referenced RModel object.
std::unique_ptr< ROperator > MakeKerasDense(PyObject *fLayer)
Prepares a ROperator object for Keras Dense Layer.
std::unique_ptr< ROperator > MakeKerasReLU(PyObject *fLayer)
Prepares a ROperator object for Keras ReLU activation.
const KerasMethodMapWithActivation mapKerasLayerWithActivation
std::unordered_map< std::string, std::unique_ptr< ROperator >(*)(PyObject *fLayer)> KerasMethodMapWithActivation
std::unique_ptr< ROperator > MakeKerasSigmoid(PyObject *fLayer)
Prepares a ROperator object for Keras Sigmoid activation.
std::unique_ptr< ROperator > MakeKerasSelu(PyObject *fLayer)
Prepares a ROperator object for Keras Selu activation.
std::unique_ptr< ROperator > MakeKerasActivation(PyObject *fLayer)
Prepares a ROperator object for Keras activation layer.
static void(&) PyRunString(TString, PyObject *, PyObject *)
static const char *(&) PyStringAsString(PyObject *)
RModel Parse(std::string filename)
Parser function for translating Keras .h5 model into a RModel object.
std::string ConvertTypeToString(ETensorType type)
ETensorType ConvertStringToType(std::string type)
create variable transformations