TMVA_CNN_Classification.py

## \file
## \ingroup tutorial_tmva
## \notebook
## TMVA Classification Example Using a Convolutional Neural Network
##
## This is an example of using a CNN in TMVA. We do classification using a toy image data set
## that is generated when running the example macro
##
## \macro_image
## \macro_output
## \macro_code
##
## \author Harshal Shende


# TMVA Classification Example Using a Convolutional Neural Network

## Helper function to create the input image data.
## We create signal and background 2D histograms from 2D Gaussians
## whose location (means in X and Y) is different for each event.
## The difference between signal and background is in the Gaussian width:
## the background width is slightly larger than the signal width (by a few percent).

import ROOT

# switch off MT in OpenMP (BLAS)
ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")

TMVA = ROOT.TMVA
TFile = ROOT.TFile


import os
import importlib

def MakeImagesTree(n, nh, nw):
    # image size (nh x nw)
    ntot = nh * nw
    fileOutName = "images_data_16x16.root"
    nRndmEvts = 10000  # number of events we use to fill each image
    delta_sigma = 0.1  # small difference in the sigma (a few % of the signal width)
    pixelNoise = 5

    sX1 = 3
    sY1 = 3
    sX2 = sX1 + delta_sigma
    sY2 = sY1 - delta_sigma
    h1 = ROOT.TH2D("h1", "h1", nh, 0, 10, nw, 0, 10)
    h2 = ROOT.TH2D("h2", "h2", nh, 0, 10, nw, 0, 10)
    f1 = ROOT.TF2("f1", "xygaus")
    f2 = ROOT.TF2("f2", "xygaus")
    sgn = ROOT.TTree("sig_tree", "signal_tree")
    bkg = ROOT.TTree("bkg_tree", "background_tree")

    f = TFile(fileOutName, "RECREATE")
    x1 = ROOT.std.vector["float"](ntot)
    x2 = ROOT.std.vector["float"](ntot)

    # create signal and background trees with a single branch
    # an std::vector<float> of size nh x nw containing the image data
    bkg.Branch("vars", "std::vector<float>", x1)
    sgn.Branch("vars", "std::vector<float>", x2)

    sgn.SetDirectory(f)
    bkg.SetDirectory(f)

    f1.SetParameters(1, 5, sX1, 5, sY1)
    f2.SetParameters(1, 5, sX2, 5, sY2)
    ROOT.gRandom.SetSeed(0)
    ROOT.Info("TMVA_CNN_Classification", "Filling ROOT tree \n")
    for i in range(n):
        if i % 1000 == 0:
            print("Generating image event ...", i)

        h1.Reset()
        h2.Reset()
        # generate random means in range [3,7] so they are not too close to the border
        f1.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f1.SetParameter(3, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(3, ROOT.gRandom.Uniform(3, 7))

        h1.FillRandom("f1", nRndmEvts)
        h2.FillRandom("f2", nRndmEvts)

        for k in range(nh):
            for l in range(nw):
                m = k * nw + l
                # add some noise in each bin
                x1[m] = h1.GetBinContent(k + 1, l + 1) + ROOT.gRandom.Gaus(0, pixelNoise)
                x2[m] = h2.GetBinContent(k + 1, l + 1) + ROOT.gRandom.Gaus(0, pixelNoise)

        sgn.Fill()
        bkg.Fill()

    sgn.Write()
    bkg.Write()

    print("Signal and background trees with image data written to the file %s" % f.GetName())
    sgn.Print()
    bkg.Print()
    f.Close()

hasGPU = ROOT.gSystem.GetFromPipe("root-config --has-tmva-gpu") == "yes"
hasCPU = ROOT.gSystem.GetFromPipe("root-config --has-tmva-cpu") == "yes"

nevt = 1000  # use a larger value to get better results
opt = [1, 1, 1, 1, 1]
useTMVACNN = opt[0] if len(opt) > 0 else False
useKerasCNN = opt[1] if len(opt) > 1 else False
useTMVADNN = opt[2] if len(opt) > 2 else False
useTMVABDT = opt[3] if len(opt) > 3 else False
usePyTorchCNN = opt[4] if len(opt) > 4 else False

if not hasCPU and not hasGPU:
    ROOT.Warning(
        "TMVA_CNN_Classification",
        "ROOT was built without tmva-cpu and tmva-gpu support - skipping TMVA-DNN and TMVA-CNN",
    )
    useTMVACNN = False
    useTMVADNN = False

if ROOT.gSystem.GetFromPipe("root-config --has-tmva-pymva") != "yes":
    useKerasCNN = False
    usePyTorchCNN = False
else:
    TMVA.PyMethodBase.PyInitialize()

tf_spec = importlib.util.find_spec("tensorflow")
if tf_spec is None:
    useKerasCNN = False
    ROOT.Warning("TMVA_CNN_Classification", "Skip using Keras since tensorflow is not installed")

torch_spec = importlib.util.find_spec("torch")
if torch_spec is None:
    usePyTorchCNN = False
    ROOT.Warning("TMVA_CNN_Classification", "Skip using PyTorch since torch is not installed")

if not useTMVACNN:
    ROOT.Warning(
        "TMVA_CNN_Classification",
        "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for CNN",
    )

writeOutputFile = True

num_threads = 4  # use 4 threads (0 would use the ROOT default)
max_epochs = 10  # maximum number of epochs used for training


# enable multi-threaded running
if num_threads >= 0:
    ROOT.EnableImplicitMT(num_threads)

print("Running with nthreads = ", ROOT.GetThreadPoolSize())



outputFile = None
if writeOutputFile:
    outputFile = TFile.Open("TMVA_CNN_ClassificationOutput.root", "RECREATE")


## Create TMVA Factory

# Create the Factory class. Later you can choose the methods
# whose performance you'd like to investigate.

# The factory is the major TMVA object you have to interact with. Here is the list of parameters you need to pass:

# - The first argument is the base of the name of all the output
#   weight files in the directory weight/ that will be created with the
#   method parameters

# - The second argument is the output file for the training results

# - The third argument is a string option defining some general configuration for the TMVA session.
#   For example, all TMVA output can be suppressed by passing Silent=True (in the option-string form
#   this corresponds to removing the "!" in front of "Silent")

# - note that we disable any pre-transformation of the input variables and we avoid computing correlations between
#   input variables

factory = TMVA.Factory(
    "TMVA_CNN_Classification",
    outputFile,
    V=False,
    ROC=True,
    Silent=False,
    Color=True,
    AnalysisType="Classification",
    Transformations=None,
    Correlations=False,
)


## Declare DataLoader(s)

# The next step is to declare the DataLoader class that deals with input variables

# Define the input variables that shall be used for the MVA training
# note that you may also use variable expressions, which can be parsed by TTree::Draw( "expression" )

# In this case the input data consists of an image of 16x16 pixels. Each single pixel is a branch in a ROOT TTree

loader = TMVA.DataLoader("dataset")


## Setup Dataset(s)

# Define the input data file and the signal and background trees


imgSize = 16 * 16
inputFileName = "images_data_16x16.root"

# if the input file does not exist, create it
if ROOT.gSystem.AccessPathName(inputFileName):
    MakeImagesTree(nevt, 16, 16)

inputFile = TFile.Open(inputFileName)
if inputFile is None:
    ROOT.Warning("TMVA_CNN_Classification", "Error opening input file %s - exit", inputFileName)


# inputFileName = "tmva_class_example.root"


# --- Register the training and test trees

signalTree = inputFile.Get("sig_tree")
backgroundTree = inputFile.Get("bkg_tree")

nEventsSig = signalTree.GetEntries()
nEventsBkg = backgroundTree.GetEntries()

# global event weights per tree (see below for setting event-wise weights)
signalWeight = 1.0
backgroundWeight = 1.0

# You can add an arbitrary number of signal or background trees
loader.AddSignalTree(signalTree, signalWeight)
loader.AddBackgroundTree(backgroundTree, backgroundWeight)

## add event variables (image)
## use the new method (available since ROOT 6.20) to add a variable array for all image data
loader.AddVariablesArray("vars", imgSize)
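## For reference, before AddVariablesArray was available one had to register every pixel
## as a separate input variable, roughly as sketched below (illustrative only; do not run it
## together with AddVariablesArray, since the variables would then be registered twice):
# for i in range(imgSize):
#     loader.AddVariable("vars[{}]".format(i), "F")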

# Set individual event weights (the variables must exist in the original TTree),
# for signal    : loader.SetSignalWeightExpression("weight1*weight2")
# for background: loader.SetBackgroundWeightExpression("weight1*weight2")
# loader.SetBackgroundWeightExpression("weight")

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1"
mycutb = ""  # for example: mycutb = "abs(var1)<0.5"

# Tell the factory how to use the training and testing events.
# If no numbers of events are given, half of the events in the tree are used
# for training, and the other half for testing:
# loader.PrepareTrainingAndTestTree(mycuts, mycutb, "SplitMode=Random:!V")
# It is also possible to specify the number of training and testing events;
# note we disable the computation of the correlation matrix of the input variables

nTrainSig = 0.8 * nEventsSig
nTrainBkg = 0.8 * nEventsBkg

# build the string options for DataLoader::PrepareTrainingAndTestTree

loader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)

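## The keyword arguments above are converted by PyROOT into a single TMVA option string.
## For reference, an equivalent call with an explicit option string would look roughly like
## the sketch below (the event numbers assume nevt = 1000, i.e. 800 training events per class):
# loader.PrepareTrainingAndTestTree(
#     ROOT.TCut(mycuts),
#     ROOT.TCut(mycutb),
#     "nTrain_Signal=800:nTrain_Background=800:SplitMode=Random:SplitSeed=100:"
#     "NormMode=NumEvents:!V:!CalcCorrelations",
# )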

# DataSetInfo              : [dataset] : Added class "Signal"
#                          : Add Tree sig_tree of type Signal with 10000 events
# DataSetInfo              : [dataset] : Added class "Background"
#                          : Add Tree bkg_tree of type Background with 10000 events

# signalTree.Print();

## Booking Methods

# Here we book the TMVA methods. We book a Boosted Decision Tree method (BDT)


# Boosted Decision Trees
if useTMVABDT:
    factory.BookMethod(
        loader,
        TMVA.Types.kBDT,
        "BDT",
        V=False,
        NTrees=400,
        MinNodeSize="2.5%",
        MaxDepth=2,
        BoostType="AdaBoost",
        AdaBoostBeta=0.5,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        SeparationType="GiniIndex",
        nCuts=20,
    )


#### Booking Deep Neural Network

# Here we book the DNN of TMVA. See the example TMVA_Higgs_Classification.C for a detailed description of the
# options

if useTMVADNN:
    layoutString = ROOT.TString(
        "DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,DENSE|1|LINEAR"
    )

    # Training strategies
    # one can concatenate several training strings with different parameters (e.g. learning rates or regularization
    # parameters); the training strings must be concatenated with the `|` delimiter
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0"
    )  # + "|" + trainingString2 + ...
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    # Build now the full DNN option string
    dnnMethodName = "TMVA_DNN_CPU"

    # use GPU if available
    dnnOptions = "CPU"
    if hasGPU:
        dnnOptions = "GPU"
        dnnMethodName = "TMVA_DNN_GPU"

    factory.BookMethod(
        loader,
        TMVA.Types.kDL,
        dnnMethodName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        Layout=layoutString,
        TrainingStrategy=trainingString1,
        Architecture=dnnOptions,
    )


### Book Convolutional Neural Network in TMVA

# For building a CNN one needs to define

# - Input Layout : number of channels (in this case = 1) | image height | image width
# - Batch Layout : batch size | number of channels | image size = (height*width)

# Then one adds Convolutional layers and MaxPool layers.

# - For a Convolutional layer the option string has to be:
#   - CONV | number of units | filter height | filter width | stride height | stride width | padding height | padding
#     width | activation function

# - note in this case we are using a 3x3 filter with padding=1 and stride=1, so the output dimension of the
#   conv layer is equal to the input (a small sanity check of these dimensions is shown below)

# - note we use a batch normalization layer after the first convolutional layer. This seems to help the
#   convergence significantly

# - For the MaxPool layer:
#   - MAXPOOL | pool height | pool width | stride height | stride width

# The RESHAPE layer is needed to flatten the output before the Dense layer

# Note that running the CNN requires CPU or GPU support


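## A quick, illustrative sanity check of the layer dimensions used in the Layout string below:
## a 3x3 convolution with stride 1 and padding 1 preserves the 16x16 image size, while the
## 2x2 max-pooling with stride 1 (MAXPOOL|2|2|1|1) reduces it to 15x15 before flattening.
def conv_output_size(n, kernel, stride, pad):
    # standard convolution/pooling output-size formula: floor((n + 2*pad - kernel) / stride) + 1
    return (n + 2 * pad - kernel) // stride + 1


assert conv_output_size(16, kernel=3, stride=1, pad=1) == 16  # CONV|10|3|3|1|1|1|1
assert conv_output_size(16, kernel=2, stride=1, pad=0) == 15  # MAXPOOL|2|2|1|1
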
if useTMVACNN:
    # Training strategies.
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0"
    )
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    ## New DL (CNN)
    cnnMethodName = "TMVA_CNN_CPU"
    cnnOptions = "CPU"
    # use GPU if available
    if hasGPU:
        cnnOptions = "GPU"
        cnnMethodName = "TMVA_CNN_GPU"

    factory.BookMethod(
        loader,
        TMVA.Types.kDL,
        cnnMethodName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        InputLayout="1|16|16",
        Layout="CONV|10|3|3|1|1|1|1|RELU,BNORM,CONV|10|3|3|1|1|1|1|RELU,MAXPOOL|2|2|1|1,RESHAPE|FLAT,DENSE|100|RELU,DENSE|1|LINEAR",
        TrainingStrategy=trainingString1,
        Architecture=cnnOptions,
    )


### Book Convolutional Neural Network in PyTorch using a generated model


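## As a rough illustration (this is an assumption, not the content of the actual
## PyTorch_Generate_CNN_Model.py shipped with ROOT), a PyTorch CNN mirroring the Keras
## model built further below could be defined and saved along these lines:
# import torch
# from torch import nn
#
# net = nn.Sequential(
#     nn.Unflatten(1, (1, 16, 16)),                  # flat 256-pixel input -> 1 channel, 16x16
#     nn.Conv2d(1, 10, kernel_size=3, padding=1),    # 10 filters, 3x3, padding keeps 16x16
#     nn.ReLU(),
#     nn.Conv2d(10, 10, kernel_size=3, padding=1),
#     nn.ReLU(),
#     nn.MaxPool2d(kernel_size=2),                   # 16x16 -> 8x8
#     nn.Flatten(),
#     nn.Linear(10 * 8 * 8, 64),
#     nn.Tanh(),
#     nn.Linear(64, 2),
#     nn.Sigmoid(),
# )
# torch.jit.script(net).save("PyTorchModelCNN.pt")   # file name used by FilenameModel below
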
if usePyTorchCNN:
    ROOT.Info("TMVA_CNN_Classification", "Using Convolutional PyTorch Model")
    pyTorchFileName = str(ROOT.gROOT.GetTutorialDir())
    pyTorchFileName += "/tmva/PyTorch_Generate_CNN_Model.py"
    # check that PyTorch can be imported and that the file defining the model exists
    torch_spec = importlib.util.find_spec("torch")
    if torch_spec is not None and os.path.exists(pyTorchFileName):
        # cmd = str(ROOT.TMVA.Python_Executable()) + " " + pyTorchFileName
        # os.system(cmd)
        # import PyTorch_Generate_CNN_Model
        ROOT.Info("TMVA_CNN_Classification", "Booking PyTorch CNN model")
        factory.BookMethod(
            loader,
            TMVA.Types.kPyTorch,
            "PyTorch",
            H=True,
            V=False,
            VarTransform=None,
            FilenameModel="PyTorchModelCNN.pt",
            FilenameTrainedModel="PyTorchTrainedModelCNN.pt",
            NumEpochs=max_epochs,
            BatchSize=100,
            UserCode=str(pyTorchFileName),
        )
    else:
        ROOT.Warning(
            "TMVA_CNN_Classification",
            "PyTorch is not installed or the model building file does not exist - skip using PyTorch",
        )

### Book Convolutional Neural Network in Keras using a generated model

if useKerasCNN:
    ROOT.Info("TMVA_CNN_Classification", "Building convolutional Keras model")
    # create the Keras CNN model:
    # 2 Conv2D layers + MaxPool + Dense
    import tensorflow
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    # from keras.initializers import TruncatedNormal
    # from keras import initializations
    from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Conv2D, MaxPooling2D, Reshape

    # from keras.callbacks import ReduceLROnPlateau
    model = Sequential()
    model.add(Reshape((16, 16, 1), input_shape=(256,)))
    model.add(Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"))
    model.add(Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"))
    # stride for maxpool is equal to pool size
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation="tanh"))
    # model.add(Dropout(0.2))
    model.add(Dense(2, activation="sigmoid"))
    model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
    model.save("model_cnn.h5")
    model.summary()

    if not os.path.exists("model_cnn.h5"):
        raise FileNotFoundError("Error creating Keras model file - skip using Keras")
    else:
        # book the PyKeras method only if the Keras model could be created
        ROOT.Info("TMVA_CNN_Classification", "Booking convolutional Keras model")
        factory.BookMethod(
            loader,
            TMVA.Types.kPyKeras,
            "PyKeras",
            H=True,
            V=False,
            VarTransform=None,
            FilenameModel="model_cnn.h5",
            FilenameTrainedModel="trained_model_cnn.h5",
            NumEpochs=max_epochs,
            BatchSize=100,
            GpuOptions="allow_growth=True",  # needed for RTX NVidia cards and to avoid that TF allocates all the GPU memory
        )


## Train Methods

factory.TrainAllMethods()

## Test and Evaluate Methods

factory.TestAllMethods()

factory.EvaluateAllMethods()

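## As an illustration of the programmatic evaluation API, one can also query the ROC
## integral (AUC) of a booked method from the factory, e.g. for the BDT booked above:
if useTMVABDT:
    print("BDT ROC integral (AUC):", factory.GetROCIntegral(loader, "BDT"))
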
## Plot ROC Curve

c1 = factory.GetROCCurve(loader)
c1.Draw()

# close the output file to save the results
outputFile.Close()