Logo ROOT  
Reference Guide
GenerateModel.py
Go to the documentation of this file.
#!/usr/bin/env python
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to define and generate a Keras model for use with
## TMVA.
##
## \macro_code
##
## \date 2017
## \author TMVA Team

13from keras.models import Sequential
14from keras.layers.core import Dense, Activation
15from keras.regularizers import l2
16from keras.optimizers import SGD
17
# Setup the model here
num_input_nodes = 4      # number of input variables fed to the network
num_output_nodes = 2     # see the output-layer notes below for how to choose this
num_hidden_layers = 1    # total number of hidden Dense layers (>= 1)
nodes_hidden_layer = 64  # width of every hidden layer
l2_val = 1e-5            # L2 weight-regularization strength

model = Sequential()

# Hidden layer 1
# NOTE: The number of input nodes needs to be defined in this layer.
# 'kernel_regularizer' is the Keras >= 2.0 name of the Keras 1 argument
# 'W_regularizer'; tf.keras and Keras 3 no longer accept the old name.
model.add(Dense(nodes_hidden_layer, activation='relu',
                kernel_regularizer=l2(l2_val), input_dim=num_input_nodes))

# Hidden layer 2 to num_hidden_layers
# NOTE: Here, you can do what you want
for _ in range(num_hidden_layers - 1):
    model.add(Dense(nodes_hidden_layer, activation='relu',
                    kernel_regularizer=l2(l2_val)))

# Output layer
# NOTE: Use the following output types for the different tasks
# Binary classification: 2 output nodes with 'softmax' activation
# Regression: 1 output with any activation ('linear' recommended)
# Multiclass classification: (number of classes) output nodes with 'softmax' activation
model.add(Dense(num_output_nodes, activation='softmax'))

# Compile model
# NOTE: Use the following settings for the different tasks
# Any classification: 'categorical_crossentropy' is the recommended loss function
# Regression: 'mean_squared_error' is the recommended loss function
# 'learning_rate' replaced the deprecated 'lr' keyword of the Keras optimizers
# (Keras >= 2.3); Keras 3 rejects 'lr' outright.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01),
              metrics=['accuracy'])

# Save model
# Writes the compiled model to 'model.h5' in the working directory
# (presumably the file the TMVA Keras tutorials load afterwards - confirm).
model.save(filepath='model.h5')

# Additional information about the model
# NOTE: This is not needed to run the model

# Print summary
model.summary()

# Visualize model as graph.
# This is best effort: plotting needs the optional pydot/graphviz packages,
# so a failure is reported instead of aborting the script.
try:
    try:
        # Keras >= 2.0 location of the plotting helper
        from keras.utils import plot_model
    except ImportError:
        # Keras 1.x location of the same helper
        from keras.utils.visualize_util import plot as plot_model
    plot_model(model, to_file='model.png', show_shapes=True)
except Exception as err:
    # Catch Exception rather than using a bare except, so Ctrl-C
    # (KeyboardInterrupt) and SystemExit still propagate; show the reason.
    print('[INFO] Failed to make model plot: {}'.format(err))