Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_Keras.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook -nodraw
4/// This macro provides a simple example of parsing a Keras .h5 file
5/// into an RModel object and then generating a .hxx header file for inference.
6///
7/// \macro_code
8/// \macro_output
9/// \author Sanjiban Sengupta
10
11using namespace TMVA::Experimental;
12
13TString pythonSrc = "\
14import os\n\
15os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n\
16\n\
17import numpy as np\n\
18from tensorflow.keras.models import Model\n\
19from tensorflow.keras.layers import Input,Dense,Activation,ReLU\n\
20from tensorflow.keras.optimizers import SGD\n\
21\n\
22input=Input(shape=(64,),batch_size=4)\n\
23x=Dense(32)(input)\n\
24x=Activation('relu')(x)\n\
25x=Dense(16,activation='relu')(x)\n\
26x=Dense(8,activation='relu')(x)\n\
27x=Dense(4)(x)\n\
28output=ReLU()(x)\n\
29model=Model(inputs=input,outputs=output)\n\
30\n\
31randomGenerator=np.random.RandomState(0)\n\
32x_train=randomGenerator.rand(4,64)\n\
33y_train=randomGenerator.rand(4,4)\n\
34\n\
35model.compile(loss='mean_squared_error', optimizer=SGD(learning_rate=0.01))\n\
36model.fit(x_train, y_train, epochs=5, batch_size=4)\n\
37model.save('KerasModel.h5')\n";
38
39
40void TMVA_SOFIE_Keras(const char * modelFile = nullptr, bool printModelInfo = true){
41
42 //Running the Python script to generate Keras .h5 file
44
45 if (modelFile == nullptr) {
46 TMacro m;
47 m.AddLine(pythonSrc);
48 m.SaveSource("make_keras_model.py");
49 gSystem->Exec(TMVA::Python_Executable() + " make_keras_model.py");
50 modelFile = "KerasModel.h5";
51 }
52
53 //Parsing the saved Keras .h5 file into RModel object
54 SOFIE::RModel model = SOFIE::PyKeras::Parse(modelFile);
55
56
57 //Generating inference code
58 model.Generate();
59 // generate output header. By default it will be modelName.hxx
60 model.OutputGenerated();
61
62 if (!printModelInfo) return;
63
64 //Printing required input tensors
65 std::cout<<"\n\n";
67
68 //Printing initialized tensors (weights)
69 std::cout<<"\n\n";
71
72 //Printing intermediate tensors
73 std::cout<<"\n\n";
75
76 //Checking if tensor already exist in model
77 std::cout<<"\n\nTensor \"dense2bias0\" already exist: "<<std::boolalpha<<model.CheckIfTensorAlreadyExist("dense2bias0")<<"\n\n";
78 std::vector<size_t> tensorShape = model.GetTensorShape("dense2bias0");
79 std::cout<<"Shape of tensor \"dense2bias0\": ";
80 for(auto& it:tensorShape){
81 std::cout<<it<<",";
82 }
83 std::cout<<"\n\nData type of tensor \"dense2bias0\": ";
84 SOFIE::ETensorType tensorType = model.GetTensorType("dense2bias0");
85 std::cout<<SOFIE::ConvertTypeToString(tensorType);
86
87 //Printing generated inference code
88 std::cout<<"\n\n";
89 model.PrintGenerated();
90}
R__EXTERN TSystem * gSystem
Definition TSystem.h:555
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:91
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:116
void OutputGenerated(std::string filename="", bool append=false)
Definition RModel.cxx:933
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
void Generate(std::underlying_type_t< Options > options, int batchSize=-1, long pos=0)
Definition RModel.cxx:542
static void PyInitialize()
Initialize Python interpreter.
Class supporting a collection of lines with C++ code.
Definition TMacro.h:31
Basic string class.
Definition TString.h:139
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition TSystem.cxx:653
TString Python_Executable()
Function to find current Python executable used by ROOT If "Python3" is installed,...
TMarker m
Definition textangle.C:8