Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
TMVA_CNN_Classification.py
Go to the documentation of this file.
1## \file
2## \ingroup tutorial_tmva
3## \notebook
4## TMVA Classification Example Using a Convolutional Neural Network
5##
6## This is an example of using a CNN in TMVA. We do classification using a toy image data set
7## that is generated when running the example macro
8##
9## \macro_image
10## \macro_output
11## \macro_code
12##
13## \author Harshal Shende
14
15
16# TMVA Classification Example Using a Convolutional Neural Network
17
18
19## Helper function to create input images data
20## we create a signal and background 2D histograms from 2d gaussians
21## with a location (means in X and Y) different for each event
22## The difference between signal and background is in the gaussian width.
23## The widths of the signal gaussian differ from those of the background gaussian by a few percent
24
25import os
26import importlib.util
27
# Per-method switches, in the order
# [TMVA CNN, Keras CNN, TMVA DNN, TMVA BDT, PyTorch CNN].
# A method whose entry is absent from `opt` is switched off.
opt = [1, 1, 1, 1, 1]
useTMVACNN = opt[0] if len(opt) > 0 else False
useKerasCNN = opt[1] if len(opt) > 1 else False
useTMVADNN = opt[2] if len(opt) > 2 else False
useTMVABDT = opt[3] if len(opt) > 3 else False
usePyTorchCNN = opt[4] if len(opt) > 4 else False
34
35tf_spec = importlib.util.find_spec("tensorflow")
36if tf_spec is None:
37 useKerasCNN = False
38 print("TMVA_CNN_Classificaton","Skip using Keras since tensorflow is not installed")
39else:
40 import tensorflow
41
42# PyTorch has to be imported before ROOT to avoid crashes because of clashing
43# std::regexp symbols that are exported by cppyy.
44# See also: https://github.com/wlav/cppyy/issues/227
45torch_spec = importlib.util.find_spec("torch")
46if torch_spec is None:
47 usePyTorchCNN = False
48 print("TMVA_CNN_Classificaton","Skip using PyTorch since torch is not installed")
49else:
50 import torch
51
52
53import ROOT
54
55
56TMVA = ROOT.TMVA
57TFile = ROOT.TFile
58
60
def MakeImagesTree(n, nh, nw):
    """Generate the toy image data set and write it to ``images_data_16x16.root``.

    Every event is an image of ``nh`` x ``nw`` pixels obtained by filling a
    2D histogram from a 2D gaussian and adding gaussian noise to each pixel.
    Two trees are written, ``sig_tree`` and ``bkg_tree``, each with a single
    ``std::vector<float>`` branch named ``vars`` of length ``nh * nw``.
    The two classes differ only in the widths of the generating gaussian.

    Parameters
    ----------
    n : int
        Number of events (images) written to each tree.
    nh : int
        Number of histogram bins along X.
    nw : int
        Number of histogram bins along Y.
    """
    # image size (nh x nw)
    ntot = nh * nw
    fileOutName = "images_data_16x16.root"
    nRndmEvts = 10000  # number of events we use to fill each image
    delta_sigma = 0.1  # absolute width shift: 0.1 on a sigma of 3, i.e. about 3%
    pixelNoise = 5  # sigma of the gaussian noise added to every pixel

    sX1 = 3
    sY1 = 3
    sX2 = sX1 + delta_sigma
    sY2 = sY1 - delta_sigma
    h1 = ROOT.TH2D("h1", "h1", nh, 0, 10, nw, 0, 10)
    h2 = ROOT.TH2D("h2", "h2", nh, 0, 10, nw, 0, 10)
    f1 = ROOT.TF2("f1", "xygaus")
    f2 = ROOT.TF2("f2", "xygaus")
    sgn = ROOT.TTree("sig_tree", "signal_tree")
    bkg = ROOT.TTree("bkg_tree", "background_tree")

    f = TFile(fileOutName, "RECREATE")
    x1 = ROOT.std.vector["float"](ntot)
    x2 = ROOT.std.vector["float"](ntot)

    # create signal and background trees with a single branch
    # an std::vector<float> of size nh x nw containing the image data
    # (x1, filled from h1/f1, is the background image; x2, from h2/f2, the signal)
    bkg.Branch("vars", "std::vector<float>", x1)
    sgn.Branch("vars", "std::vector<float>", x2)

    # The trees were created before the file was opened: attach them to it
    # explicitly so that they are stored in (and owned by) the output file.
    sgn.SetDirectory(f)
    bkg.SetDirectory(f)

    # xygaus parameters: amplitude, mean X, sigma X, mean Y, sigma Y
    f1.SetParameters(1, 5, sX1, 5, sY1)
    f2.SetParameters(1, 5, sX2, 5, sY2)
    ROOT.Info("TMVA_CNN_Classification", "Filling ROOT tree \n")
    for i in range(n):
        if i % 1000 == 0:
            print("Generating image event ...", i)

        h1.Reset()
        h2.Reset()
        # generate random means in range [3,7] to be not too much on the border
        f1.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f1.SetParameter(3, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(3, ROOT.gRandom.Uniform(3, 7))

        h1.FillRandom("f1", nRndmEvts)
        h2.FillRandom("f2", nRndmEvts)

        for ix in range(nh):
            for iy in range(nw):
                # flattened pixel index; histogram bins are numbered from 1
                m = ix * nw + iy
                # add some noise in each bin
                x1[m] = h1.GetBinContent(ix + 1, iy + 1) + ROOT.gRandom.Gaus(0, pixelNoise)
                x2[m] = h2.GetBinContent(ix + 1, iy + 1) + ROOT.gRandom.Gaus(0, pixelNoise)

        sgn.Fill()
        bkg.Fill()

    sgn.Write()
    bkg.Write()

    # format the file name into the message (a bare "%s" would be printed literally)
    print("Signal and background tree with images data written to the file {}".format(f.GetName()))
    sgn.Print()
    bkg.Print()
    f.Close()
128
# Which TMVA deep-learning back ends this ROOT build provides
hasGPU = "tmva-gpu" in ROOT.gROOT.GetConfigFeatures()
hasCPU = "tmva-cpu" in ROOT.gROOT.GetConfigFeatures()

nevt = 1000  # use a larger value to get better results

# The TMVA DNN and CNN need at least one of the two back ends
if not hasCPU and not hasGPU:
    ROOT.Warning("TMVA_CNN_Classification", "ROOT is not supporting tmva-cpu and tmva-gpu skip using TMVA-DNN and TMVA-CNN")
    useTMVACNN = False
    useTMVADNN = False

# The Keras and PyTorch methods are provided by PyMVA
if "tmva-pymva" not in ROOT.gROOT.GetConfigFeatures():
    useKerasCNN = False
    usePyTorchCNN = False
else:
    # initialise the PyMVA interface before any Python-based method is booked
    TMVA.PyMethodBase.PyInitialize()

if not useTMVACNN:
    ROOT.Warning(
        "TMVA_CNN_Classification",
        "TMVA is not build with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for CNN",
    )
150
writeOutputFile = True  # store the training results in a ROOT file

num_threads = 4  # use max 4 threads
max_epochs = 10  # maximum number of epochs used for training


# do enable MT running
if "imt" in ROOT.gROOT.GetConfigFeatures():
    ROOT.EnableImplicitMT(num_threads)
    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")  # switch OFF MT in OpenBLAS
    print("Running with nthreads = {}".format(ROOT.GetThreadPoolSize()))
else:
    print("Running in serial mode since ROOT does not support MT")


# File receiving the training results; stays None when no output is requested
outputFile = None
if writeOutputFile:
    outputFile = TFile.Open("TMVA_CNN_ClassificationOutput.root", "RECREATE")
171
172
173## Create TMVA Factory
174
175# Create the Factory class. Later you can choose the methods
176# whose performance you'd like to investigate.
177
178# The factory is the major TMVA object you have to interact with. Here is the list of parameters you need to pass
179
180# - The first argument is the base of the name of all the output
181# weight files in the directory weight/ that will be created with the
182# method parameters
183
184# - The second argument is the output file for the training results
185
186# - The third argument is a string option defining some general configuration for the TMVA session.
187# For example all TMVA output can be suppressed by passing Silent=True (in the C++ option string this corresponds to
188# removing the "!" (not) in front of "Silent")
189
190# - note that we disable any pre-transformation of the input variables and we avoid computing correlations between
191# input variables
192
193
# Each keyword argument below is one entry of the TMVA option string;
# transformations and correlations of the input variables are switched off.
factory = TMVA.Factory(
    "TMVA_CNN_Classification",
    outputFile,
    V=False,
    ROC=True,
    Silent=False,
    Color=True,
    AnalysisType="Classification",
    Transformations=None,
    Correlations=False,
)
205
206
207## Declare DataLoader(s)
208
209# The next step is to declare the DataLoader class that deals with input variables
210
211# Define the input variables that shall be used for the MVA training
212# note that you may also use variable expressions, which can be parsed by TTree::Draw( "expression" )
213
214# In this case the input data consists of an image of 16x16 pixels. All the pixels of an image are stored in a single std::vector<float> branch ("vars") of a ROOT TTree
215
loader = TMVA.DataLoader("dataset")


## Setup Dataset(s)

# Define input data file and signal and background trees


imgSize = 16 * 16
inputFileName = "images_data_16x16.root"

# if the input file does not exist create it
# (TSystem::AccessPathName returns true when the path is NOT accessible)
if ROOT.gSystem.AccessPathName(inputFileName):
    MakeImagesTree(nevt, 16, 16)

inputFile = TFile.Open(inputFileName)
# A failed open yields a null or zombie TFile rather than None, so test its
# truth value, and stop here instead of failing later on inputFile.Get().
if not inputFile or inputFile.IsZombie():
    ROOT.Error("TMVA_CNN_Classification", "Error opening input file {} - exit".format(inputFileName))
    raise SystemExit(1)
234
235
236# inputFileName = "tmva_class_example.root"
237
238
239# --- Register the training and test trees
240
signalTree = inputFile.Get("sig_tree")
backgroundTree = inputFile.Get("bkg_tree")

# tree sizes, used further down to choose the number of training events
nEventsSig = signalTree.GetEntries()
nEventsBkg = backgroundTree.GetEntries()

# global event weights per tree (see below for setting event-wise weights)
signalWeight = 1.0
backgroundWeight = 1.0

# You can add an arbitrary number of signal or background trees
loader.AddSignalTree(signalTree, signalWeight)
loader.AddBackgroundTree(backgroundTree, backgroundWeight)

## add event variables (image)
## use the new method (available from ROOT 6.20) to add a variable array for all image data
loader.AddVariablesArray("vars", imgSize)
258
259# Set individual event weights (the variables must exist in the original TTree)
260# for signal : factory->SetSignalWeightExpression ("weight1*weight2");
261# for background: factory->SetBackgroundWeightExpression("weight1*weight2");
262# loader->SetBackgroundWeightExpression( "weight" );
263
264# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
mycutb = ""  # for example: TCut mycutb = "abs(var1)<0.5";

# Tell the factory how to use the training and testing events
# If no numbers of events are given, half of the events in the tree are used
# for training, and the other half for testing:
#    loader.PrepareTrainingAndTestTree( mycut, "SplitMode=random:!V" );
# It is possible also to specify the number of training and testing events,
# note we disable the computation of the correlation matrix of the input variables

# use 80% of each sample for training; event counts must be integers, a float
# would be rendered as e.g. "800.0" in the option string
nTrainSig = int(0.8 * nEventsSig)
nTrainBkg = int(0.8 * nEventsBkg)

# build the string options for DataLoader::PrepareTrainingAndTestTree

loader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)
291
292
293# DataSetInfo : [dataset] : Added class "Signal"
294# : Add Tree sig_tree of type Signal with 10000 events
295# DataSetInfo : [dataset] : Added class "Background"
296# : Add Tree bkg_tree of type Background with 10000 events
297
298# signalTree.Print();
299
300# Booking Methods
301
302# Here we book the TMVA methods. We book a Boosted Decision Tree method (BDT)
303
304
305# Boosted Decision Trees
if useTMVABDT:
    # Book the BDT; the keyword arguments are the method options
    factory.BookMethod(
        loader,
        TMVA.Types.kBDT,
        "BDT",
        V=False,
        NTrees=400,
        MinNodeSize="2.5%",
        MaxDepth=2,
        BoostType="AdaBoost",
        AdaBoostBeta=0.5,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        SeparationType="GiniIndex",
        nCuts=20,
    )
322
323
324#### Booking Deep Neural Network
325
326# Here we book the DNN of TMVA. See the example TMVA_Higgs_Classification.C for a detailed description of the
327# options
328
if useTMVADNN:
    # Network layout: four dense RELU layers of 100 units, with batch
    # normalisation after the first three, and one linear output unit
    layoutString = ROOT.TString(
        "DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,DENSE|1|LINEAR"
    )

    # Training strategies
    # one can concatenate several training strings with different parameters (e.g. learning rates or
    # regularization parameters). The training strings must be joined with the `|` delimiter
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0."
    )  # + "|" + trainingString2 + ...
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    # Build now the full DNN Option string
    dnnMethodName = "TMVA_DNN_CPU"

    # use GPU if available
    dnnOptions = "CPU"
    if hasGPU:
        dnnOptions = "GPU"
        dnnMethodName = "TMVA_DNN_GPU"

    factory.BookMethod(
        loader,
        TMVA.Types.kDL,
        dnnMethodName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        Layout=layoutString,
        TrainingStrategy=trainingString1,
        Architecture=dnnOptions,
    )
367
368
369### Book Convolutional Neural Network in TMVA
370
371# For building a CNN one needs to define
372
373# - Input Layout : number of channels (in this case = 1) | image height | image width
374# - Batch Layout : batch size | number of channels | image size = (height*width)
375
376# Then one add Convolutional layers and MaxPool layers.
377
378# - For Convolutional layer the option string has to be:
379#    - CONV | number of units | filter height | filter width | stride height | stride width | padding height | padding
380# width | activation function
381
382#    - note in this case we are using a filter 3x3 and padding=1 and stride=1 so we get the output dimension of the
383# conv layer equal to the input
384
385# - note we use after the first convolutional layer a batch normalization layer. This seems to help significantly the
386# convergence
387
388# - For the MaxPool layer:
389# - MAXPOOL | pool height | pool width | stride height | stride width
390
391# The RESHAPE layer is needed to flatten the output before the Dense layer
392
393# Note that running the CNN requires CPU or GPU support
394
395
if useTMVACNN:
    # Training strategies.
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0"
    )
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    ## New DL (CNN)
    cnnMethodName = "TMVA_CNN_CPU"
    cnnOptions = "CPU"
    # use GPU if available
    if hasGPU:
        cnnOptions = "GPU"
        cnnMethodName = "TMVA_CNN_GPU"

    # InputLayout is channels|height|width; Layout chains two 3x3 convolutions
    # (batch normalisation after the first), a max-pool, a flattening reshape
    # and two dense layers
    factory.BookMethod(
        loader,
        TMVA.Types.kDL,
        cnnMethodName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        InputLayout="1|16|16",
        Layout="CONV|10|3|3|1|1|1|1|RELU,BNORM,CONV|10|3|3|1|1|1|1|RELU,MAXPOOL|2|2|1|1,RESHAPE|FLAT,DENSE|100|RELU,DENSE|1|LINEAR",
        TrainingStrategy=trainingString1,
        Architecture=cnnOptions,
    )
428
429
430### Book Convolutional Neural Networks in PyTorch and Keras using generated models
431
432
if usePyTorchCNN:
    ROOT.Info("TMVA_CNN_Classification", "Using Convolutional PyTorch Model")
    # script shipped with the ROOT tutorials, handed to the method as UserCode
    pyTorchFileName = str(ROOT.gROOT.GetTutorialDir())
    pyTorchFileName += "/tmva/PyTorch_Generate_CNN_Model.py"
    # check that pytorch can be imported and file defining the model exists
    torch_spec = importlib.util.find_spec("torch")
    if torch_spec is not None and os.path.exists(pyTorchFileName):
        # cmd = str(ROOT.TMVA.Python_Executable()) + "  " + pyTorchFileName
        # os.system(cmd)
        # import PyTorch_Generate_CNN_Model
        ROOT.Info("TMVA_CNN_Classification", "Booking PyTorch CNN model")
        factory.BookMethod(
            loader,
            TMVA.Types.kPyTorch,
            "PyTorch",
            H=True,
            V=False,
            VarTransform=None,
            FilenameModel="PyTorchModelCNN.pt",
            FilenameTrainedModel="PyTorchTrainedModelCNN.pt",
            NumEpochs=max_epochs,
            BatchSize=100,
            UserCode=str(pyTorchFileName),
        )
    else:
        ROOT.Warning(
            "TMVA_CNN_Classification",
            "PyTorch is not installed or model building file is not existing - skip using PyTorch",
        )
462
if useKerasCNN:
    ROOT.Info("TMVA_CNN_Classification", "Building convolutional keras model")
    # create python script which can be executed
    # create 2 conv2d layer + maxpool + dense
    import tensorflow
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    # from keras.initializers import TruncatedNormal
    # from keras import initializations
    from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Conv2D, MaxPooling2D, Reshape

    # from keras.callbacks import ReduceLROnPlateau
    model = Sequential()
    # the 256 input values are reshaped into a 16x16 image with one channel
    model.add(Reshape((16, 16, 1), input_shape=(256,)))
    model.add(Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"))
    model.add(Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"))
    # stride for maxpool is equal to pool size
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation="tanh"))
    # model.add(Dropout(0.2))
    model.add(Dense(2, activation="sigmoid"))
    model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
    model.save("model_cnn.h5")

    if not os.path.exists("model_cnn.h5"):
        raise FileNotFoundError("Error creating Keras model file - skip using Keras")

    # book PyKeras method only if Keras model could be created
    ROOT.Info("TMVA_CNN_Classification", "Booking convolutional keras model")
    factory.BookMethod(
        loader,
        TMVA.Types.kPyKeras,
        "PyKeras",
        H=True,
        V=False,
        VarTransform=None,
        FilenameModel="model_cnn.h5",
        FilenameTrainedModel="trained_model_cnn.h5",
        NumEpochs=max_epochs,
        BatchSize=100,
        GpuOptions="allow_growth=True",
    )  # needed for RTX NVidia card and to avoid TF allocates all GPU memory
508
509
510
## Train Methods

factory.TrainAllMethods()

## Test and Evaluate Methods

factory.TestAllMethods()

factory.EvaluateAllMethods()

## Plot ROC Curve

c1 = factory.GetROCCurve(loader)
c1.Draw()

# close outputfile to save output file
# (outputFile is None when no output file was requested)
if outputFile:
    outputFile.Close()
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
This is the main MVA steering class.
Definition Factory.h:80