Logo ROOT   6.14/05
Reference Guide
ClassificationKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
"""Compare a Fisher discriminant and a Keras MLP with TMVA.

Downloads the TMVA example dataset with curl if it is not present locally.
Every branch of the signal tree becomes an input variable. The script books
a Fisher discriminant and a PyKeras neural network (model stored in
``model.h5``), then trains, tests and evaluates both. TMVA output is written
to ``TMVA.root``.
"""

from ROOT import TMVA, TFile, TTree, TCut
from subprocess import call
from os.path import isfile

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD

# Setup TMVA
TMVA.Tools.Instance()
# Initialize the Python interpreter used by TMVA's PyMVA methods (PyKeras),
# before any of them is booked below.
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')

# Load data
if not isfile('tmva_class_example.root'):
    # -f: exit non-zero on HTTP errors instead of saving the error page as a
    #     bogus ROOT file; -L: follow redirects (e.g. http -> https) instead
    #     of saving the redirect response.
    status = call(['curl', '-f', '-L', '-O',
                   'http://root.cern.ch/files/tmva_class_example.root'])
    if status != 0:
        raise RuntimeError('failed to download tmva_class_example.root '
                           '(curl exit code {})'.format(status))

data = TFile.Open('tmva_class_example.root')
# A failed open yields a null/zombie file; fail here with a clear message
# rather than on the first data.Get() below.
if not data or data.IsZombie():
    raise RuntimeError('cannot open tmva_class_example.root')
signal = data.Get('TreeS')
background = data.Get('TreeB')
if not signal or not background:
    raise RuntimeError("'TreeS' or 'TreeB' not found in tmva_class_example.root")

dataloader = TMVA.DataLoader('dataset')
# Every branch of the signal tree is an input variable. The list is kept so
# the network's input layer can be sized from it below; a hard-coded width
# would silently break if the tree gained or lost branches.
variables = [branch.GetName() for branch in signal.GetListOfBranches()]
for name in variables:
    dataloader.AddVariable(name)

dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

# Define model: one hidden ReLU layer with a small L2 weight penalty, and a
# two-unit softmax output (signal / background).
model = Sequential()
# `kernel_regularizer` is the Keras 2 name for the Keras 1 `W_regularizer`.
model.add(Dense(64, activation='relu', kernel_regularizer=l2(1e-5),
                input_dim=len(variables)))
model.add(Dense(2, activation='softmax'))

# Set loss and optimizer
# NOTE: `lr` is the Keras 2.x keyword; newer Keras releases rename it to
# `learning_rate`.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(lr=0.01), metrics=['accuracy'])

# Store model to file; the PyKeras method below loads it via FilenameModel.
model.save('model.h5')
model.summary()

# Book methods
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# Run training, test and evaluation
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Flush and close the TMVA output explicitly instead of relying on
# interpreter-exit cleanup.
output.Close()
static Tools & Instance()
Definition: Tools.cxx:75
static void PyInitialize()
Initialize Python interpreter.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3976
This is the main MVA steering class.
Definition: Factory.h:81