Logo ROOT   6.16/01
Reference Guide
RegressionKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
"""TMVA regression example that books a Keras model next to a gradient BDT.

The script reads the tree ``TreeR`` from ``tmva_reg_example.root`` (fetched
with ``curl`` when the file is not present locally), registers every branch
except ``fvalue`` as an input variable and ``fvalue`` as the regression
target, writes a small Keras network to ``model.h5``, and lets a TMVA factory
train, test and evaluate the PyKeras and BDTG methods.  TMVA output goes to
``TMVA.root``.
"""

from ROOT import TMVA, TFile, TTree, TCut
from subprocess import call
from os.path import isfile

from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD

# Setup TMVA
# Both calls must run before any Python-based method is booked: the first
# instantiates the TMVA tools singleton, the second initializes the Python
# interpreter used by the PyMVA methods.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory(
    'TMVARegression', output,
    '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')

# Load data
# Download the example input only when it is not already in the working
# directory.
if not isfile('tmva_reg_example.root'):
    call(['curl', '-O', 'http://root.cern.ch/files/tmva_reg_example.root'])

data = TFile.Open('tmva_reg_example.root')
tree = data.Get('TreeR')

dataloader = TMVA.DataLoader('dataset')
# Every branch other than the target becomes an input variable.
for branch in tree.GetListOfBranches():
    name = branch.GetName()
    if name != 'fvalue':
        dataloader.AddVariable(name)
dataloader.AddTarget('fvalue')

dataloader.AddRegressionTree(tree, 1.0)
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

# Define model
# NOTE(review): input_dim=2 assumes TreeR has exactly two non-target
# branches -- confirm against the input file.
# NOTE(review): `W_regularizer` is the Keras 1 argument name; Keras 2 calls
# it `kernel_regularizer` -- confirm against the installed Keras version.
model = Sequential()
model.add(Dense(64, activation='tanh', W_regularizer=l2(1e-5), input_dim=2))
model.add(Dense(1, activation='linear'))

# Set loss and optimizer
model.compile(loss='mean_squared_error', optimizer=SGD(lr=0.01))

# Store model to file
# The file name must match the FilenameModel option passed to PyKeras below.
model.save('model.h5')
model.summary()

# Book methods
factory.BookMethod(
    dataloader, TMVA.Types.kPyKeras, 'PyKeras',
    'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')
factory.BookMethod(
    dataloader, TMVA.Types.kBDT, 'BDTG',
    '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseGeneralPurpose, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3975
This is the main MVA steering class.
Definition: Factory.h:81
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition: Tools.cxx:75