Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RegressionKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to do regression in TMVA with neural networks
## trained with keras.
##
## \macro_code
##
## \date 2017
## \author TMVA Team

from ROOT import TMVA, TFile, TTree, TCut
from subprocess import check_call
from os.path import isfile

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import SGD

# Setup TMVA: create the TMVA tools instance and initialize the embedded
# Python interpreter used by the PyKeras method booked below.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVARegression', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')

# Load data. The example file is downloaded once. curl's -f makes HTTP errors
# fail instead of saving an error page, -L follows redirects, and check_call
# raises if the download fails, so a bad file never reaches TFile.Open.
data_file = 'tmva_reg_example.root'
if not isfile(data_file):
    check_call(['curl', '-fLO', 'https://root.cern.ch/files/' + data_file])

data = TFile.Open(data_file)
tree = data.Get('TreeR')

# Every branch except the regression target is used as an input variable.
target = 'fvalue'
variables = [branch.GetName() for branch in tree.GetListOfBranches()
             if branch.GetName() != target]

dataloader = TMVA.DataLoader('dataset')
for name in variables:
    dataloader.AddVariable(name)
dataloader.AddTarget(target)

dataloader.AddRegressionTree(tree, 1.0)
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')

# Generate model

# Define model: one tanh hidden layer and a single linear output for the
# target. The input width follows the number of variables registered above
# instead of being hard-coded.
model = Sequential()
model.add(Dense(64, activation='tanh', input_dim=len(variables)))
model.add(Dense(1, activation='linear'))

# Set loss and optimizer. `learning_rate` replaces the deprecated `lr`
# keyword, which newer Keras versions ignore or reject.
model.compile(loss='mean_squared_error', optimizer=SGD(learning_rate=0.01))

# Store model to file; the PyKeras method below loads it via FilenameModel.
model.save('model.h5')
model.summary()

# Book methods
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')
factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',
                   '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Flush and close the results file written by the factory.
output.Close()
A specialized string object used for TTree selections.
Definition TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3997
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:75