Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_RNN_Classification.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_tmva
## \notebook
## TMVA Classification Example Using a Recurrent Neural Network
##
## This is an example of using a RNN in TMVA. We do classification using a toy time-dependent data set
## that is generated when running this example macro
##
## \macro_image
## \macro_output
## \macro_code
##
## \author Harshal Shende


# TMVA Classification Example Using a Recurrent Neural Network

# This is an example of using a RNN in TMVA.
# We do the classification using a toy data set containing time series of ntime samples,
# each of dimension ndim, that is generated when running the provided function `MakeTimeData(nevents, ntime, ndim)`

# importlib.util is a submodule: a plain "import importlib" does not guarantee it is loaded
import importlib.util
import os

import ROOT

num_threads = 4  # use max 4 threads
# do enable MT running
if "imt" in ROOT.gROOT.GetConfigFeatures():
    ROOT.EnableImplicitMT(num_threads)
    # switch off MT in OpenBLAS to avoid conflict with tbb
    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")
    print("Running with nthreads = {}".format(ROOT.GetThreadPoolSize()))
else:
    print("Running in serial mode since ROOT does not support MT")


TMVA = ROOT.TMVA
TFile = ROOT.TFile

# create the TMVA singletons up front
TMVA.Tools.Instance()
TMVA.Config.Instance()
46
## Helper function to generate the toy time data set.
## Every event consists of a fixed number (ntime) of time steps, each holding ndim values.


def MakeTimeData(n, ntime, ndim):
    """Generate a toy time-series data set and write it to a ROOT file.

    For each of the ``n`` events and each time step ``j`` (0 <= j < ntime), a
    histogram with ``ndim`` bins on [0, 10) is filled with 1000 random numbers
    drawn from a Gaussian whose mean and sigma vary with ``j``. The background
    tree uses a sine modulation and the signal tree a cosine modulation. The
    ``ndim`` bin contents, each smeared with Gaussian noise of sigma 10, are
    stored in the ``std::vector<float>`` branch ``vars_time<j>``.

    The trees ``sgn`` and ``bkg`` are written to
    ``time_data_t<ntime>_d<ndim>.root``. When ``n == 1`` nothing is written.
    Instead, the histograms of the single generated event are drawn on a
    canvas.

    Args:
        n: number of events to generate (same number for signal and background).
        ntime: number of time steps per event.
        ndim: number of values (histogram bins) per time step.
    """
    fname = "time_data_t" + str(ntime) + "_d" + str(ndim) + ".root"

    # per-time-step histograms, reset and refilled for every event
    v1 = []
    v2 = []
    for i in range(ntime):
        v1.append(ROOT.TH1D("h1_" + str(i), "h1", ndim, 0, 10))
        v2.append(ROOT.TH1D("h2_" + str(i), "h2", ndim, 0, 10))

    # FillRandom below refers to these functions by name ("f1"/"f2")
    f1 = ROOT.TF1("f1", "gaus")
    f2 = ROOT.TF1("f2", "gaus")

    sgn = ROOT.TTree("sgn", "sgn")
    bkg = ROOT.TTree("bkg", "bkg")
    f = TFile(fname, "RECREATE")

    # one vector of ndim values per time step; the branches read these buffers on Fill()
    x1 = [ROOT.std.vector["float"](ndim) for _ in range(ntime)]
    x2 = [ROOT.std.vector["float"](ndim) for _ in range(ntime)]

    for i in range(ntime):
        bkg.Branch("vars_time" + str(i), "std::vector<float>", x1[i])
        sgn.Branch("vars_time" + str(i), "std::vector<float>", x2[i])

    sgn.SetDirectory(f)
    bkg.SetDirectory(f)
    ROOT.gRandom.SetSeed(0)

    # time-dependent Gaussian parameters: index 1 -> background, index 2 -> signal
    mean1 = ROOT.std.vector["double"](ntime)
    mean2 = ROOT.std.vector["double"](ntime)
    sigma1 = ROOT.std.vector["double"](ntime)
    sigma2 = ROOT.std.vector["double"](ntime)

    for j in range(ntime):
        phase = ROOT.TMath.Pi() * j / float(ntime)
        mean1[j] = 5.0 + 0.2 * ROOT.TMath.Sin(phase)
        mean2[j] = 5.0 + 0.2 * ROOT.TMath.Cos(phase)
        sigma1[j] = 4 + 0.3 * ROOT.TMath.Sin(phase)
        sigma2[j] = 4 + 0.3 * ROOT.TMath.Cos(phase)

    for i in range(n):
        if i % 1000 == 0:
            print("Generating event ... %d" % i)

        for j in range(ntime):
            h1 = v1[j]
            h2 = v2[j]
            h1.Reset()
            h2.Reset()

            f1.SetParameters(1, mean1[j], sigma1[j])
            f2.SetParameters(1, mean2[j], sigma2[j])

            h1.FillRandom("f1", 1000)
            h2.FillRandom("f2", 1000)

            # fill all ndim entries of the vector: one per histogram bin
            for k in range(ndim):
                x1[j][k] = h1.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)
                x2[j][k] = h2.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)

        # one tree entry per event, containing all time steps
        sgn.Fill()
        bkg.Fill()

    if n == 1:
        c1 = ROOT.TCanvas()
        c1.Divide(ntime, 2)
        for j in range(ntime):
            c1.cd(j + 1)
            v1[j].Draw()
        for j in range(ntime):
            c1.cd(ntime + j + 1)
            v2[j].Draw()

        ROOT.gPad.Update()

    if n > 1:
        sgn.Write()
        bkg.Write()
        sgn.Print()
        bkg.Print()
        f.Close()
139
140
## macro for performing a classification using a Recurrent Neural Network
## @param use_type
## use_type = 0 use Simple RNN network
## use_type = 1 use LSTM network
## use_type = 2 use GRU
## use_type = 3 build 3 different networks with RNN, LSTM and GRU


use_type = 1
ninput = 30  # number of input values per time step (ndim of MakeTimeData)
ntime = 10  # number of time steps
batchSize = 100
maxepochs = 10

nTotEvts = 2000  # total events to be generated for signal or background

useKeras = True

useTMVA_RNN = True
useTMVA_DNN = True
useTMVA_BDT = False

# importlib.util is a submodule: a plain "import importlib" does not guarantee it is loaded
import importlib.util

tf_spec = importlib.util.find_spec("tensorflow")
if tf_spec is None:
    useKeras = False
    ROOT.Warning("TMVA_RNN_Classification", "Skip using Keras since tensorflow is not installed")


rnn_types = ["RNN", "LSTM", "GRU"]
use_rnn_type = [1, 1, 1]  # any use_type outside 0..2 (e.g. 3) books all three types

if 0 <= use_type < 3:
    use_rnn_type = [0, 0, 0]
    use_rnn_type[use_type] = 1

useGPU = True  # use GPU for TMVA if available

# the TMVA deep-learning backends that can be used depend on how ROOT was built
configFeatures = ROOT.gROOT.GetConfigFeatures()
hasTMVA_CPU = "tmva-cpu" in configFeatures
hasTMVA_GPU = "tmva-gpu" in configFeatures
useGPU = useGPU and hasTMVA_GPU

if not (hasTMVA_CPU or hasTMVA_GPU):
    useTMVA_RNN = False
    ROOT.Warning(
        "TMVA_RNN_Classification",
        "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN",
    )

archString = "GPU" if useGPU else "CPU"

writeOutputFile = True

rnn_type = "RNN"

if "tmva-pymva" in configFeatures:
    # the PyMVA (PyKeras) methods need the embedded Python interpreter to be initialized
    TMVA.PyMethodBase.PyInitialize()
else:
    useKeras = False
197
198
199
# input file name, built with the same convention MakeTimeData uses for its output
inputFileName = "time_data_t" + str(ntime) + "_d" + str(ninput) + ".root"

# ROOT's AccessPathName returns True when the path can NOT be accessed
fileDoesNotExist = ROOT.gSystem.AccessPathName(inputFileName)

# if the file does not exist, create it
if fileDoesNotExist:
    MakeTimeData(nTotEvts, ntime, ninput)


inputFile = TFile.Open(inputFileName)
# On failure TFile.Open returns a null pointer. PyROOT exposes it as a falsy
# proxy rather than None, so test truthiness instead of "is None".
if not inputFile:
    raise RuntimeError("Error opening input file {} - exit".format(inputFileName))


print("--- RNNClassification : Using input file: {}".format(inputFile.GetName()))

# Create a ROOT output file where TMVA will store ntuples, histograms, etc.
outfileName = "data_RNN_" + archString + ".root"
outputFile = None


if writeOutputFile:
    outputFile = TFile.Open(outfileName, "RECREATE")
223
224
## Declare Factory

# The Factory is the main TMVA steering object. Methods are booked on it, and
# it drives training, testing and evaluation. Its constructor receives:
#  - the job name, used as the base name of all weight files written by the
#    booked methods;
#  - the output file that receives the training results;
#  - the general session options, given here as keyword arguments. For
#    instance, Silent=True would suppress all TMVA output.

factoryOptions = dict(
    V=False,
    Silent=False,
    Color=True,
    DrawProgressBar=True,
    Transformations=None,
    Correlations=False,
    AnalysisType="Classification",
    ModelPersistence=True,
)
factory = TMVA.Factory("TMVAClassification", outputFile, **factoryOptions)

dataloader = TMVA.DataLoader("dataset")

# signal and background trees read from the input file
signalTree, background = inputFile.Get("sgn"), inputFile.Get("bkg")

nvar = ninput * ntime

## register one array variable of ninput entries per time step (AddVariablesArray)
for t in range(ntime):
    dataloader.AddVariablesArray(f"vars_time{t}", ninput)
267
268
dataloader.AddSignalTree(signalTree, 1.0)
dataloader.AddBackgroundTree(background, 1.0)

# check given input: list the variables registered in the data set
datainfo = dataloader.GetDataSetInfo()
variables = datainfo.GetListOfVariables()  # named to avoid shadowing the builtin vars()
print("number of variables is {}".format(variables.size()))


for v in variables:
    print(v)

# number of training events per class (the rest is used for testing);
# TMVA expects integer event counts
nTrainSig = int(0.8 * nTotEvts)
nTrainBkg = int(0.8 * nTotEvts)

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: mycuts = ROOT.TCut("abs(var1)<0.5 && abs(var2-0.5)<1")
mycutb = ""

# build the string options for DataLoader::PrepareTrainingAndTestTree
dataloader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)

print("prepared DATA LOADER ")
302
303
## Book TMVA recurrent models

# Book the different types of recurrent models in TMVA (SimpleRNN, LSTM or GRU)


if useTMVA_RNN:
    for i in range(3):
        if not use_rnn_type[i]:
            continue

        rnn_type = rnn_types[i]

        ## Define RNN layer layout:
        ## LayerType (RNN, LSTM or GRU) | number of units | number of inputs | time steps |
        ## remember output (typically no = 0) | return full sequence (1 = yes).
        ## After the recurrent layer, a RESHAPE|FLAT layer flattens the output sequence.
        ## A dense layer with 64 units and a final linear layer follow.
        rnnLayout = str(rnn_type) + "|10|" + str(ninput) + "|" + str(ntime) + "|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"

        ## Training strategy. Different strategies can be concatenated with "|"; only one is used here.
        trainingString1 = "LearningRate=1e-3,Momentum=0.0,Repetitions=1,ConvergenceSteps=5,BatchSize=" + str(batchSize)
        trainingString1 += ",TestRepetitions=1,WeightDecay=1e-2,Regularization=None,MaxEpochs=" + str(maxepochs)
        # the leading comma is required, otherwise the text reads "MaxEpochs=<n>Optimizer=ADAM"
        trainingString1 += ",Optimizer=ADAM,DropConfig=0.0+0.+0.+0."

        ## The input layout for the RNN is: time steps | values per time step.
        ## The last layer is linear because with CROSSENTROPY a sigmoid is already applied.
        rnnName = "TMVA_" + str(rnn_type)
        factory.BookMethod(
            dataloader,
            TMVA.Types.kDL,
            rnnName,
            H=False,
            V=True,
            ErrorStrategy="CROSSENTROPY",
            VarTransform=None,
            WeightInitialization="XAVIERUNIFORM",
            ValidationSize=0.2,
            RandomSeed=1234,
            InputLayout=str(ntime) + "|" + str(ninput),
            Layout=rnnLayout,
            TrainingStrategy=trainingString1,
            Architecture=archString,
        )
348
349
## Book TMVA fully connected dense layer models
if useTMVA_DNN:
    # Method DL with dense layers only, using the flattened (ntime * ninput) input.
    # Training strategy (a second one could be appended with "|").
    trainingString1 = (
        "LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
        "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,MaxEpochs=20,"
        "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM"
    )
    dnnName = "TMVA_DNN"
    factory.BookMethod(
        dataloader,
        TMVA.Types.kDL,
        dnnName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        RandomSeed=0,
        InputLayout="1|1|" + str(ntime * ninput),
        Layout="DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,LINEAR",
        TrainingStrategy=trainingString1,
        # Select CPU/GPU through the dedicated option. Appending ":<arch>" to
        # the training string would create a stray, unrecognized option token.
        Architecture=archString,
    )
377
378
## Book Keras recurrent models

# Book the different types of recurrent models in Keras (SimpleRNN, LSTM or GRU)


if useKeras:
    # imported here so that tensorflow is only loaded when Keras is actually used
    from tensorflow.keras.layers import GRU, LSTM, Dense, Flatten, Reshape, SimpleRNN
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    # Keras recurrent layer class for each entry of rnn_types
    keras_rnn_layers = {"RNN": SimpleRNN, "LSTM": LSTM, "GRU": GRU}

    for i in range(3):
        if not use_rnn_type[i]:
            continue

        modelName = "model_" + rnn_types[i] + ".h5"
        trainedModelName = "trained_" + modelName
        print("Building recurrent keras model using a", rnn_types[i], "layer")

        model = Sequential()
        # the input arrives flattened (ntime * ninput values); restore the (time, values) layout
        model.add(Reshape((ntime, ninput), input_shape=(ntime * ninput,)))
        # recurrent layer returning the full output sequence
        model.add(keras_rnn_layers[rnn_types[i]](units=10, return_sequences=True))
        model.add(Flatten())  # needed because the full time output sequence is returned
        model.add(Dense(64, activation="tanh"))
        model.add(Dense(2, activation="sigmoid"))
        model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
        model.save(modelName)
        model.summary()
        print("saved recurrent model", modelName)

        if not os.path.exists(modelName):
            useKeras = False
            print("Error creating Keras recurrent model file - Skip using Keras")
        else:
            # book PyKeras method only if Keras model could be created
            print("Booking Keras model ", rnn_types[i])
            factory.BookMethod(
                dataloader,
                TMVA.Types.kPyKeras,
                "PyKeras_" + rnn_types[i],
                H=True,
                V=False,
                VarTransform=None,
                FilenameModel=modelName,
                FilenameTrainedModel=trainedModelName,
                NumEpochs=maxepochs,
                BatchSize=batchSize,
                GpuOptions="allow_growth=True",
            )
436
437
# Use a BDT as a fallback when either Keras or the TMVA recurrent networks are not used.
# The condition must not test useTMVA_BDT itself: that would always enable the BDT.
if not useKeras or not useTMVA_RNN:
    useTMVA_BDT = True


## Book TMVA BDT (gradient boosting with bagging)


if useTMVA_BDT:
    factory.BookMethod(
        dataloader,
        TMVA.Types.kBDT,
        "BDTG",
        H=True,
        V=False,
        NTrees=100,
        MinNodeSize="2.5%",
        BoostType="Grad",
        Shrinkage=0.10,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        nCuts=20,
        MaxDepth=2,
    )
462
463
## Train all booked methods, then report the thread-pool size used
factory.TrainAllMethods()

print(f"nthreads = {ROOT.GetThreadPoolSize()}")

# Apply every trained MVA to the test events, then compare the performance
# of all configured MVAs.
factory.TestAllMethods()
factory.EvaluateAllMethods()

# draw the ROC curves of all methods on a single canvas
c1 = factory.GetROCCurve(dataloader)
c1.Draw()

# the output file exists only when writeOutputFile was set
if outputFile:
    outputFile.Close()
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4082
static Config & Instance()
static function: returns TMVA instance
Definition Config.cxx:98
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:71
void EnableImplicitMT(UInt_t numthreads=0)
Enable ROOT's implicit multi-threading for all objects and methods that provide an internal paralleli...
Definition TROOT.cxx:539
UInt_t GetThreadPoolSize()
Returns the size of ROOT's thread pool.
Definition TROOT.cxx:577
th1 Draw()