Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_RNN_Classification.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_ml
## \notebook
## TMVA Classification Example Using a Recurrent Neural Network
##
## This is an example of using an RNN in TMVA. We do classification using a toy time-dependent data set
## that is generated when running this example macro.
##
## \macro_code
##
## \author Harshal Shende
12
13
# TMVA Classification Example Using a Recurrent Neural Network

# This is an example of using an RNN in TMVA.
# The classification uses a toy data set in which every event is a time series of ntime steps,
# each step holding ndim values. It is generated when running the provided function `MakeTimeData(nevents, ntime, ndim)`.
19
20
import os

import ROOT

num_threads = 4  # use max 4 threads

# Enable implicit multi-threading only when this ROOT build supports it.
if "imt" in ROOT.gROOT.GetConfigFeatures():
    ROOT.EnableImplicitMT(num_threads)
    # switch off MT in OpenBLAS to avoid conflict with tbb
    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")
    print("Running with nthreads = {}".format(ROOT.GetThreadPoolSize()))
else:
    print("Running in serial mode since ROOT does not support MT")

TMVA = ROOT.TMVA
TFile = ROOT.TFile
38
41
42
## Helper function to generate the toy time-series data set.
## Every event has a fixed number of time steps (ntime), each holding ndim values.


def MakeTimeData(n, ntime, ndim):
    """Generate the toy time-series data set and store it in a ROOT file.

    For each of the ``n`` events one background entry (tree ``bkg``) and one
    signal entry (tree ``sgn``) is filled. An entry has ``ntime`` branches
    ``vars_time<j>``, each a ``std::vector<float>`` of length ``ndim``. The
    values of time step ``j`` are the bin contents of a histogram (``ndim``
    bins in [0, 10]) filled with 1000 draws from a Gaussian, plus Gaussian
    noise of width 10 added to every value. The Gaussian mean and sigma vary
    with ``j``: sine-shaped for background, cosine-shaped for signal.

    The file is named ``time_data_t<ntime>_d<ndim>.root``. The trees are
    written only when ``n > 1``; for ``n == 1`` the histograms of the single
    event are drawn on a canvas instead.

    Args:
        n: number of events to generate for each of signal and background.
        ntime: number of time steps per event.
        ndim: number of values per time step.
    """
    fname = "time_data_t" + str(ntime) + "_d" + str(ndim) + ".root"

    # One histogram per time step for background (v1) and signal (v2); they
    # are reset and refilled for every event.
    v1 = []
    v2 = []
    for i in range(ntime):
        v1.append(ROOT.TH1D("h1_" + str(i), "h1", ndim, 0, 10))
        v2.append(ROOT.TH1D("h2_" + str(i), "h2", ndim, 0, 10))

    f1 = ROOT.TF1("f1", "gaus")
    f2 = ROOT.TF1("f2", "gaus")

    sgn = ROOT.TTree("sgn", "sgn")
    bkg = ROOT.TTree("bkg", "bkg")
    f = TFile(fname, "RECREATE")
    try:
        # The trees were created before the file was opened, so attach them
        # to it explicitly; their data are then stored in the output file.
        sgn.SetDirectory(f)
        bkg.SetDirectory(f)

        # Branch buffers: one vector of ndim values per time step. The lists
        # keep them alive while the trees are being filled.
        x1 = [ROOT.std.vector["float"](ndim) for _ in range(ntime)]
        x2 = [ROOT.std.vector["float"](ndim) for _ in range(ntime)]
        for i in range(ntime):
            bkg.Branch("vars_time" + str(i), "std::vector<float>", x1[i])
            sgn.Branch("vars_time" + str(i), "std::vector<float>", x2[i])

        # Time-dependent Gaussian parameters (background: sine, signal: cosine).
        mean1 = ROOT.std.vector["double"](ntime)
        mean2 = ROOT.std.vector["double"](ntime)
        sigma1 = ROOT.std.vector["double"](ntime)
        sigma2 = ROOT.std.vector["double"](ntime)
        for j in range(ntime):
            phase = ROOT.TMath.Pi() * j / float(ntime)
            mean1[j] = 5.0 + 0.2 * ROOT.TMath.Sin(phase)
            mean2[j] = 5.0 + 0.2 * ROOT.TMath.Cos(phase)
            sigma1[j] = 4 + 0.3 * ROOT.TMath.Sin(phase)
            sigma2[j] = 4 + 0.3 * ROOT.TMath.Cos(phase)

        for i in range(n):
            if i % 1000 == 0:
                print("Generating event ... %d" % i)

            for j in range(ntime):
                h1 = v1[j]
                h2 = v2[j]
                h1.Reset()
                h2.Reset()

                f1.SetParameters(1, mean1[j], sigma1[j])
                f2.SetParameters(1, mean2[j], sigma2[j])

                h1.FillRandom(f1, 1000)
                h2.FillRandom(f2, 1000)

                # Fill all ndim values of this time step, one per histogram bin.
                for k in range(ndim):
                    x1[j][k] = h1.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)
                    x2[j][k] = h2.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)

            sgn.Fill()
            bkg.Fill()

        if n == 1:
            # Show the per-time-step histograms of the single generated event.
            # NOTE(review): c1 is a local; confirm the canvas stays displayed
            # after this function returns under PyROOT's memory policy.
            c1 = ROOT.TCanvas()
            c1.Divide(ntime, 2)
            for j in range(ntime):
                c1.cd(j + 1)
                v1[j].Draw()
            for j in range(ntime):
                c1.cd(ntime + j + 1)
                v2[j].Draw()

        if n > 1:
            sgn.Write()
            bkg.Write()
            sgn.Print()
            bkg.Print()
    finally:
        # Close the file even if generation fails part-way.
        f.Close()
135
136
## Configuration of the classification macro.
## use_type selects the recurrent network(s) booked:
##   use_type = 0  use a simple RNN network
##   use_type = 1  use an LSTM network
##   use_type = 2  use a GRU network
##   use_type = 3  build 3 different networks with RNN, LSTM and GRU

use_type = 1
ninput = 30  # values per time step
ntime = 10  # time steps per event
batchSize = 100
maxepochs = 10

nTotEvts = 2000  # total events to be generated for signal or background

useTMVA_RNN = True
useTMVA_DNN = True
useTMVA_BDT = False

configFeatures = ROOT.gROOT.GetConfigFeatures()

# Keras models are booked through TMVA's PyMVA interface, which requires ROOT
# built with tmva-pymva and an importable tensorflow.
useKeras = "tmva-pymva" in configFeatures

if useKeras:
    try:
        import tensorflow  # noqa: F401  (imported only to check availability)
    except Exception:
        ROOT.Warning("TMVA_RNN_Classification", "Skip using Keras since tensorflow cannot be imported")
        useKeras = False

if useKeras:
    # Initialize TMVA's Python interface before any PyKeras method is booked.
    TMVA.PyMethodBase.PyInitialize()

rnn_types = ["RNN", "LSTM", "GRU"]
# use_type 0..2 selects a single recurrent type; any other value books all three.
use_rnn_type = [1, 1, 1]
if 0 <= use_type < 3:
    use_rnn_type = [0, 0, 0]
    use_rnn_type[use_type] = 1

# TMVA deep learning runs on the GPU when ROOT was built with tmva-gpu,
# otherwise on the multi-threaded CPU backend (tmva-cpu).
useGPU = "tmva-gpu" in configFeatures
hasTMVA_DL = ("tmva-cpu" in configFeatures) or useGPU

if useTMVA_RNN and not hasTMVA_DL:
    ROOT.Warning(
        "TMVA_RNN_Classification",
        "TMVA is not build with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN",
    )
# Respect the user's choice above, but only book TMVA RNNs when a backend exists.
useTMVA_RNN = useTMVA_RNN and hasTMVA_DL

archString = "GPU" if useGPU else "CPU"

writeOutputFile = True

rnn_type = "RNN"
198
199
200
# Input file written by MakeTimeData(nTotEvts, ntime, ninput). Its name encodes
# ntime and ninput, so build it from the same values instead of hard-coding it.
inputFileName = "time_data_t" + str(ntime) + "_d" + str(ninput) + ".root"

# AccessPathName returns True when the file can NOT be accessed.
fileDoesNotExist = ROOT.gSystem.AccessPathName(inputFileName)

# if file does not exist create it
if fileDoesNotExist:
    MakeTimeData(nTotEvts, ntime, ninput)

inputFile = TFile.Open(inputFileName)
# A failed TFile.Open yields a null (falsy) object rather than None, so test its
# truth value; also reject a file that was opened but is unusable (zombie).
if not inputFile or inputFile.IsZombie():
    raise RuntimeError("Error opening input file {} - exit".format(inputFileName))

print("--- RNNClassification : Using input file: {}".format(inputFile.GetName()))

# Create a ROOT output file where TMVA will store ntuples, histograms, etc.
outfileName = "data_RNN_" + archString + ".root"
outputFile = TFile.Open(outfileName, "RECREATE") if writeOutputFile else None
224
225
## Declare Factory

# The Factory is the main TMVA object you interact with; the methods whose
# performance you want to investigate are booked on it afterwards.
# Constructor arguments:
#  - "TMVAClassification": base of the names of the weight files written for
#    every booked method;
#  - outputFile: the file receiving the training results (None when no output
#    file is written);
#  - keyword options: general configuration of the TMVA session. For example,
#    Silent=True would suppress all TMVA output.
factoryOptions = dict(
    V=False,
    Silent=False,
    Color=True,
    DrawProgressBar=True,
    Transformations=None,
    Correlations=False,
    AnalysisType="Classification",
    ModelPersistence=True,
)
factory = TMVA.Factory("TMVAClassification", outputFile, **factoryOptions)
dataloader = TMVA.DataLoader("dataset")

signalTree = inputFile.Get("sgn")
background = inputFile.Get("bkg")

nvar = ninput * ntime  # expected number of input variables

## add variables - use new AddVariablesArray function:
## one array of ninput values per time step, read from branch vars_time<i>
for i in range(ntime):
    dataloader.AddVariablesArray("vars_time" + str(i), ninput)

dataloader.AddSignalTree(signalTree, 1.0)
dataloader.AddBackgroundTree(background, 1.0)

# check given input
datainfo = dataloader.GetDataSetInfo()
variables = datainfo.GetListOfVariables()
print("number of variables is {}".format(variables.size()))
if variables.size() != nvar:
    ROOT.Warning("TMVA_RNN_Classification", "expected {} input variables".format(nvar))

for v in variables:
    print(v)
281
# Use 80% of the generated events of each class for training; the remaining
# events are used for testing. TMVA expects integer event counts.
nTrainSig = int(0.8 * nTotEvts)
nTrainBkg = int(0.8 * nTotEvts)

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: "abs(var1)<0.5 && abs(var2-0.5)<1"
mycutb = ""

# Split the input into training and test samples.
dataloader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)

print("prepared DATA LOADER ")
303
304
## Book TMVA recurrent models

# Book the different types of recurrent models in TMVA (SimpleRNN, LSTM or GRU)
if useTMVA_RNN:
    for i in range(3):
        if not use_rnn_type[i]:
            continue

        rnn_type = rnn_types[i]

        ## Define RNN layer layout:
        ## LayerType (RNN, LSTM or GRU) | number of units | number of inputs | time steps |
        ## remember output (typically no = 0) | return full sequence.
        ## After the RNN a reshape layer flattens the output, followed by a dense layer
        ## with 64 units and a last linear layer (linear because, when using
        ## cross-entropy, a sigmoid is already applied).
        rnnLayout = str(rnn_type) + "|10|" + str(ninput) + "|" + str(ntime) + "|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"

        ## Training strategy. Several strings can be concatenated, but only one is used here.
        trainingString1 = "LearningRate=1e-3,Momentum=0.0,Repetitions=1,ConvergenceSteps=5,BatchSize=" + str(batchSize)
        trainingString1 += ",TestRepetitions=1,WeightDecay=1e-2,Regularization=None,MaxEpochs=" + str(maxepochs)
        # The leading comma separates Optimizer from the MaxEpochs value above.
        trainingString1 += ",Optimizer=ADAM,DropConfig=0.0+0.+0.+0."

        ## The RNN input layout is time steps x values per step.
        rnnName = "TMVA_" + str(rnn_type)
        factory.BookMethod(
            dataloader,
            TMVA.Types.kDL,
            rnnName,
            H=False,
            V=True,
            ErrorStrategy="CROSSENTROPY",
            VarTransform=None,
            WeightInitialization="XAVIERUNIFORM",
            ValidationSize=0.2,
            RandomSeed=1234,
            InputLayout=str(ntime) + "|" + str(ninput),
            Layout=rnnLayout,
            TrainingStrategy=trainingString1,
            Architecture=archString,
        )
349
350
## Book TMVA fully connected dense layer models
if useTMVA_DNN:
    # Method DL with dense layers only: the ntime x ninput values form one flat input vector.
    # Training strategy (another strategy could be appended with "|").
    trainingString1 = (
        "LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
        "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,MaxEpochs=20,"
        "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM"
    )
    dnnName = "TMVA_DNN"
    # The architecture is a method option of its own, not part of the training string.
    factory.BookMethod(
        dataloader,
        TMVA.Types.kDL,
        dnnName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        RandomSeed=0,
        InputLayout="1|1|" + str(ntime * ninput),
        Layout="DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,LINEAR",
        TrainingStrategy=trainingString1,
        Architecture=archString,
    )
378
379
## Book Keras recurrent models

# Book the different types of recurrent models in Keras (SimpleRNN, LSTM or GRU)
if useKeras:
    from tensorflow.keras.layers import (
        GRU,
        LSTM,
        Dense,
        Flatten,
        Reshape,
        SimpleRNN,
    )
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    for i in range(3):
        if not use_rnn_type[i]:
            continue

        modelName = "model_" + rnn_types[i] + ".keras"
        trainedModelName = "trained_" + modelName
        print("Building recurrent keras model using a", rnn_types[i], "layer")

        # Reshape the flat input into (time steps, values per step), apply the
        # recurrent layer returning the full output sequence, flatten it, then
        # a dense layer with 64 units and a 2-unit sigmoid output layer.
        model = Sequential()
        model.add(Reshape((ntime, ninput), input_shape=(ntime * ninput,)))
        if rnn_types[i] == "LSTM":
            model.add(LSTM(units=10, return_sequences=True))
        elif rnn_types[i] == "GRU":
            model.add(GRU(units=10, return_sequences=True))
        else:
            model.add(SimpleRNN(units=10, return_sequences=True))
        model.add(Flatten())  # needed because the full time output sequence is returned
        model.add(Dense(64, activation="tanh"))
        model.add(Dense(2, activation="sigmoid"))
        model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
        model.save(modelName)
        print("saved recurrent model", modelName)

        if not os.path.exists(modelName):
            useKeras = False
            print("Error creating Keras recurrent model file - Skip using Keras")
        else:
            # book PyKeras method only if Keras model could be created
            print("Booking Keras model ", rnn_types[i])
            factory.BookMethod(
                dataloader,
                TMVA.Types.kPyKeras,
                "PyKeras_" + rnn_types[i],
                H=True,
                V=False,
                VarTransform=None,
                FilenameModel=modelName,
                FilenameTrainedModel=trainedModelName,
                NumEpochs=maxepochs,
                BatchSize=batchSize,
            )
442
443
# Use a BDT in case Keras or the TMVA deep-learning RNN is not used; otherwise
# book it only when requested explicitly through useTMVA_BDT.
if not useKeras or not useTMVA_RNN:
    useTMVA_BDT = True

## Book TMVA BDT
if useTMVA_BDT:
    factory.BookMethod(
        dataloader,
        TMVA.Types.kBDT,
        "BDTG",
        H=True,
        V=False,
        NTrees=100,
        MinNodeSize="2.5%",
        BoostType="Grad",
        Shrinkage=0.10,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        nCuts=20,
        MaxDepth=2,
    )
468
469
## Train all methods
factory.TrainAllMethods()

print("nthreads = {}".format(ROOT.GetThreadPoolSize()))

# ---- Evaluate all MVAs using the set of test events
factory.TestAllMethods()

# ----- Evaluate and compare performance of all configured MVAs
factory.EvaluateAllMethods()

# plot ROC curve
c1 = factory.GetROCCurve(dataloader)
c1.Draw()

# Close the output file, if one was opened.
if outputFile:
    outputFile.Close()
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:130
This is the main MVA steering class.
Definition Factory.h:80
th1 Draw()