Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MulticlassKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to do multiclass classification in TMVA with neural
## networks trained with keras.
##
## The script books two methods (Fisher and PyKeras) on the same dataset made
## of one signal tree and three background trees, then trains, tests and
## evaluates them. Results are written to TMVA.root.
##
## \macro_code
##
## \date 2017
## \author TMVA Team

from ROOT import TMVA, TFile, TTree, TCut, gROOT
from os.path import isfile

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import SGD

# Setup TMVA
# The Python interpreter bridge has to be initialised before a PyKeras method
# is booked on the factory.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory(
    'TMVAClassification', output,
    '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')

# Load data
# Generate the example input file with the tutorial macro if it is not
# already present in the working directory.
if not isfile('tmva_example_multiple_background.root'):
    createDataMacro = str(gROOT.GetTutorialDir()) + '/tmva/createData.C'
    print(createDataMacro)
    gROOT.ProcessLine('.L {}'.format(createDataMacro))
    gROOT.ProcessLine('create_MultipleBackground(4000)')

data = TFile.Open('tmva_example_multiple_background.root')
signal = data.Get('TreeS')
background0 = data.Get('TreeB0')
background1 = data.Get('TreeB1')
background2 = data.Get('TreeB2')

# Every branch of the signal tree becomes one input variable.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# One class per tree: four classes in total.
dataloader.AddTree(signal, 'Signal')
dataloader.AddTree(background0, 'Background_0')
dataloader.AddTree(background1, 'Background_1')
dataloader.AddTree(background2, 'Background_2')
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

# Define model
# NOTE(review): input_dim=4 assumes TreeS has exactly four branches, since one
# variable is added per branch above -- confirm against the input file.
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=4))
# Four softmax outputs, one per class registered with AddTree above.
model.add(Dense(4, activation='softmax'))

# Set loss and optimizer
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# releases no longer accept.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01),
              metrics=['accuracy', ])

# Store model to file
# The PyKeras method below loads the network from this file (FilenameModel).
model.save('model.h5')
model.summary()

# Book methods
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, "PyKeras",
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
A specialized string object used for TTree selections.
Definition TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3997
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:75