ROOT Reference Guide

GenerateModel.py
Go to the documentation of this file.
#!/usr/bin/env python
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to define and generate a keras model for use with
## TMVA.
##
## \macro_code
##
## \date 2017
## \author TMVA Team

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import plot_model

# Setup the model here
num_input_nodes = 4      # size of the input layer (passed as input_dim below)
num_output_nodes = 2     # 2 nodes + softmax, i.e. the binary-classification setup
num_hidden_layers = 1    # total number of hidden layers; the first one is always built
nodes_hidden_layer = 64  # number of nodes in every hidden layer
l2_val = 1e-5            # L2 kernel-regularization strength of the hidden layers

model = Sequential()

# Hidden layer 1
# NOTE: The number of input nodes needs to be defined in this first layer
model.add(Dense(nodes_hidden_layer, activation='relu',
                kernel_regularizer=l2(l2_val), input_dim=num_input_nodes))

# Hidden layers 2 to num_hidden_layers
# NOTE: Add or customize further hidden layers here as needed. Because the
# first hidden layer is added unconditionally above, values of
# num_hidden_layers below 1 still yield one hidden layer.
for _ in range(num_hidden_layers - 1):
    model.add(Dense(nodes_hidden_layer, activation='relu',
                    kernel_regularizer=l2(l2_val)))

# Output layer
# NOTE: Use the following output types for the different tasks
# Binary classification: 2 output nodes with 'softmax' activation
# Regression: 1 output with any activation ('linear' recommended)
# Multiclass classification: (number of classes) output nodes with 'softmax' activation
model.add(Dense(num_output_nodes, activation='softmax'))

# Compile model
# NOTE: Use the following settings for the different tasks
# Any classification: 'categorical_crossentropy' is the recommended loss function
# Regression: 'mean_squared_error' is the recommended loss function
# 'accuracy' is passed via weighted_metrics so that Keras applies any sample
# weights supplied at fit/evaluate time when computing it.
model.compile(loss='categorical_crossentropy', optimizer=SGD(learning_rate=0.01),
              weighted_metrics=['accuracy'])

# Save model
# The '.h5' extension makes Keras write the model in HDF5 format.
model.save('model.h5')

# Additional information about the model
# NOTE: This is not needed to run the model

# Print summary
model.summary()

# Visualize model as graph
# Best effort only: plot_model relies on optional graph-drawing packages
# (pydot/graphviz according to the Keras docs), so a failure is reported and
# the script carries on. Catching Exception instead of using a bare except
# keeps Ctrl-C (KeyboardInterrupt) and sys.exit (SystemExit) working.
try:
    plot_model(model, to_file='model.png', show_shapes=True)
except Exception as err:
    print(f'[INFO] Failed to make model plot: {err}')