Loading [MathJax]/extensions/tex2jax.js
ROOT 6.10/09
Reference Guide
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
RegressionKeras.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 from ROOT import TMVA, TFile, TTree, TCut
4 from subprocess import call
5 from os.path import isfile
6 
7 from keras.models import Sequential
8 from keras.layers.core import Dense, Activation
9 from keras.regularizers import l2
10 from keras import initializations
11 from keras.optimizers import SGD
12 
# Setup TMVA
# Initialise the TMVA tools singleton and the embedded Python interpreter.
# The PyMVA method booked below (TMVA.Types.kPyKeras) runs inside that
# interpreter, so it must be initialised before the factory is used.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# All factory output (training/test trees, evaluation results) is written to
# TMVA.root, which is recreated (overwritten) on every run.
output = TFile.Open('TMVA.root', 'RECREATE')
if not output or output.IsZombie():
    raise RuntimeError('Cannot create output file TMVA.root')
factory = TMVA.Factory('TMVARegression', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')
20 
# Load data: download the example regression dataset once and cache it in the
# current working directory.
from os import remove

data_file = 'tmva_reg_example.root'
if not isfile(data_file):
    # -f: exit non-zero on HTTP errors instead of saving the server's error
    #     page under the .root name (isfile() would then accept that broken
    #     file on every later run);
    # -L: follow redirects (e.g. http -> https) instead of saving the
    #     redirect response.
    status = call(['curl', '-f', '-L', '-O',
                   'https://root.cern.ch/files/tmva_reg_example.root'])
    if status != 0:
        # Do not leave a partial download behind for the next run to trust.
        if isfile(data_file):
            remove(data_file)
        raise RuntimeError('Download of %s failed (curl exit code %d)'
                           % (data_file, status))

data = TFile.Open(data_file)
if not data or data.IsZombie():
    raise RuntimeError('Cannot open %s' % data_file)
tree = data.Get('TreeR')
if not tree:
    raise RuntimeError("Tree 'TreeR' not found in %s" % data_file)

# Every branch except the regression target 'fvalue' becomes an input variable.
dataloader = TMVA.DataLoader('dataset')
for branch in tree.GetListOfBranches():
    name = branch.GetName()
    if name != 'fvalue':
        dataloader.AddVariable(name)
dataloader.AddTarget('fvalue')

# Register the tree with global event weight 1.0 and split it randomly,
# using 4000 events for training.
dataloader.AddRegressionTree(tree, 1.0)
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')
38 
39 # Generate model
40 
# Define initialization
def normal(shape, name=None):
    """Weight initializer for the Dense layers (Keras 1.x ``init=`` callable).

    Delegates to ``keras.initializations.normal`` with the scale pinned to
    0.05, forwarding the requested weight ``shape`` and variable ``name``.
    """
    std = 0.05
    return initializations.normal(shape, name=name, scale=std)
44 
# Define model: one hidden layer of 64 tanh units with L2 weight penalty
# (1e-5), followed by a single linear output unit for the regression target.
# The input width must equal the number of variables registered with the
# dataloader (every non-target branch of the tree). Derive it from the
# dataloader instead of hard-coding it, so a tree with a different number of
# input branches still produces a matching network.
n_inputs = dataloader.GetDataSetInfo().GetNVariables()
model = Sequential()
model.add(Dense(64, init=normal, activation='tanh', W_regularizer=l2(1e-5),
                input_dim=n_inputs))
# Optional variant: uncomment to add a second hidden layer.
#model.add(Dense(32, init=normal, activation='tanh', W_regularizer=l2(1e-5)))
model.add(Dense(1, init=normal, activation='linear'))

# Set loss and optimizer: mean squared error minimised with SGD (lr=0.01).
model.compile(loss='mean_squared_error', optimizer=SGD(lr=0.01))

# Store the compiled, still untrained model. TMVA's PyKeras method loads it
# from model.h5 (FilenameModel option) and trains it itself.
model.save('model.h5')
model.summary()
57 
# Book methods: the Keras network stored in model.h5 (trained by TMVA for
# 20 epochs with batch size 32) and, for comparison, gradient-boosted
# decision trees. Both apply the D,G variable transformations to their inputs.
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')
factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',
                   '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')

# Run TMVA: train, test and evaluate every booked method.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the factory's output file explicitly so everything written during
# training/evaluation is flushed to TMVA.root, rather than relying on
# interpreter-exit cleanup.
output.Close()
static Tools & Instance()
Definition: Tools.cxx:75
static void PyInitialize()
Initialize Python interpreter.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3909
def normal(shape, name=None)
This is the main MVA steering class.
Definition: Factory.h:81