Logo ROOT  
Reference Guide
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Loading...
Searching...
No Matches
TMVA_CNN_Classification.py
Go to the documentation of this file.
1## \file
2## \ingroup tutorial_tmva
3## \notebook
4## TMVA Classification Example Using a Convolutional Neural Network
5##
6## This is an example of using a CNN in TMVA. We do classification using a toy image data set
7## that is generated when running the example macro
8##
9## \macro_image
10## \macro_output
11## \macro_code
12##
13## \author Harshal Shende
14
15
16# TMVA Classification Example Using a Convolutional Neural Network
17
18
## Helper function to create the input image data.
## We create signal and background 2D histograms from 2D Gaussians
## with a location (means in X and Y) that is different for each event.
## The difference between signal and background is in the Gaussian widths,
## which differ by a few percent.
24
25import os
26import importlib.util
27
# Steering flags, one entry per method in the order
# [TMVA CNN, Keras CNN, TMVA DNN, TMVA BDT, PyTorch CNN].
# A method whose entry is missing from the list is switched off.
opt = [1, 1, 1, 1, 1]
useTMVACNN = opt[0] if len(opt) > 0 else False
useKerasCNN = opt[1] if len(opt) > 1 else False
useTMVADNN = opt[2] if len(opt) > 2 else False
useTMVABDT = opt[3] if len(opt) > 3 else False
usePyTorchCNN = opt[4] if len(opt) > 4 else False

# Keras is only usable when tensorflow can be found.
tf_spec = importlib.util.find_spec("tensorflow")
if tf_spec is None:
    useKerasCNN = False
    print("TMVA_CNN_Classification", "Skip using Keras since tensorflow is not installed")
else:
    import tensorflow

# PyTorch has to be imported before ROOT to avoid crashes because of clashing
# std::regexp symbols that are exported by cppyy.
# See also: https://github.com/wlav/cppyy/issues/227
torch_spec = importlib.util.find_spec("torch")
if torch_spec is None:
    usePyTorchCNN = False
    print("TMVA_CNN_Classification", "Skip using PyTorch since torch is not installed")
else:
    import torch


import ROOT

# NOTE: multi-threading in OpenMP (BLAS) is switched off further below via
# OMP_NUM_THREADS, in the branch that enables ROOT implicit multi-threading.

TMVA = ROOT.TMVA
TFile = ROOT.TFile

# NOTE(review): this statement was missing from the extracted listing (line 60);
# restored as the TMVA tools singleton initialisation - confirm against upstream.
TMVA.Tools.Instance()

def MakeImagesTree(n, nh, nw):
    """Generate the toy image data set and write it to ``images_data_16x16.root``.

    For each of the ``n`` events two 2D histograms with ``nh`` x ``nw`` bins over
    [0, 10] x [0, 10] are filled with ``nRndmEvts`` points sampled from an
    ``xygaus`` function whose parameters 1 and 3 (the means) are drawn uniformly
    in [3, 7]. Gaussian noise of width ``pixelNoise`` is added to every bin and
    the flattened bin contents are stored in a ``std::vector<float>`` branch
    named ``vars``.

    The histogram sampled from ``f1`` fills ``bkg_tree`` and the one sampled
    from ``f2`` fills ``sig_tree``.

    Args:
        n: number of images (tree entries) to generate per tree.
        nh: number of histogram bins along x (image height).
        nw: number of histogram bins along y (image width).

    Note:
        The output file name is fixed and does not depend on ``nh``/``nw``.
    """
    # image size (nh x nw)
    ntot = nh * nw
    fileOutName = "images_data_16x16.root"
    nRndmEvts = 10000  # number of events we use to fill each image
    delta_sigma = 0.1  # absolute shift of the widths (about 3% of sX1 / sY1)
    pixelNoise = 5

    sX1 = 3
    sY1 = 3
    sX2 = sX1 + delta_sigma
    sY2 = sY1 - delta_sigma
    h1 = ROOT.TH2D("h1", "h1", nh, 0, 10, nw, 0, 10)
    h2 = ROOT.TH2D("h2", "h2", nh, 0, 10, nw, 0, 10)
    f1 = ROOT.TF2("f1", "xygaus")
    f2 = ROOT.TF2("f2", "xygaus")
    sgn = ROOT.TTree("sig_tree", "signal_tree")
    bkg = ROOT.TTree("bkg_tree", "background_tree")

    f = TFile(fileOutName, "RECREATE")
    x1 = ROOT.std.vector["float"](ntot)
    x2 = ROOT.std.vector["float"](ntot)

    # create signal and background trees with a single branch
    # an std::vector<float> of size nh x nw containing the image data
    bkg.Branch("vars", "std::vector<float>", x1)
    sgn.Branch("vars", "std::vector<float>", x2)

    # the trees were created before the file was opened: attach them to it
    sgn.SetDirectory(f)
    bkg.SetDirectory(f)

    f1.SetParameters(1, 5, sX1, 5, sY1)
    f2.SetParameters(1, 5, sX2, 5, sY2)
    ROOT.gRandom.SetSeed(0)
    ROOT.Info("TMVA_CNN_Classification", "Filling ROOT tree \n")
    for i in range(n):
        if i % 1000 == 0:
            print("Generating image event ...", i)

        h1.Reset()
        h2.Reset()
        # generate random means in range [3,7] to be not too much on the border
        f1.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f1.SetParameter(3, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(3, ROOT.gRandom.Uniform(3, 7))

        h1.FillRandom("f1", nRndmEvts)
        h2.FillRandom("f2", nRndmEvts)

        for k in range(nh):
            for l in range(nw):
                # row-major flattening of the image; histogram bins start at 1
                m = k * nw + l
                # add some noise in each bin
                x1[m] = h1.GetBinContent(k + 1, l + 1) + ROOT.gRandom.Gaus(0, pixelNoise)
                x2[m] = h2.GetBinContent(k + 1, l + 1) + ROOT.gRandom.Gaus(0, pixelNoise)

        sgn.Fill()
        bkg.Fill()

    sgn.Write()
    bkg.Write()

    # print() does not interpolate "%s": format the file name explicitly
    print("Signal and background tree with images data written to the file {}".format(f.GetName()))
    sgn.Print()
    bkg.Print()
    f.Close()

129
# Detect which TMVA back-ends this ROOT build provides.
hasGPU = "tmva-gpu" in ROOT.gROOT.GetConfigFeatures()
hasCPU = "tmva-cpu" in ROOT.gROOT.GetConfigFeatures()

nevt = 1000  # use a larger value to get better results

if not hasCPU and not hasGPU:
    ROOT.Warning(
        "TMVA_CNN_Classification",
        "ROOT is not supporting tmva-cpu and tmva-gpu skip using TMVA-DNN and TMVA-CNN",
    )
    useTMVACNN = False
    useTMVADNN = False

if "tmva-pymva" not in ROOT.gROOT.GetConfigFeatures():
    # without PyMVA neither the Keras nor the PyTorch method can be booked
    useKerasCNN = False
    usePyTorchCNN = False
else:
    # NOTE(review): this statement was missing from the extracted listing
    # (line 144), leaving an empty `else:` body; restored as the PyMVA
    # interpreter initialisation - confirm against upstream.
    TMVA.PyMethodBase.PyInitialize()

if not useTMVACNN:
    ROOT.Warning(
        "TMVA_CNN_Classification",
        "TMVA is not build with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for CNN",
    )

writeOutputFile = True

num_threads = 4  # number of threads requested for ROOT implicit multi-threading
max_epochs = 10  # maximum number of epochs used for training


# do enable MT running
if "imt" in ROOT.gROOT.GetConfigFeatures():
    ROOT.EnableImplicitMT(num_threads)
    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")  # switch OFF MT in OpenBLAS
    print("Running with nthreads = {}".format(ROOT.GetThreadPoolSize()))
else:
    print("Running in serial mode since ROOT does not support MT")
165
166
167
168
# Output file receiving the training results (None when nothing is written).
outputFile = TFile.Open("TMVA_CNN_ClassificationOutput.root", "RECREATE") if writeOutputFile else None


## Create TMVA Factory

# The Factory is the main TMVA object to interact with; the methods whose
# performance should be investigated are booked on it later.
#
# Arguments passed below:
# - the base name used for all output weight files created in the weight/
#   directory together with the method parameters
# - the output file for the training results
# - general configuration options of the TMVA session; for example, all TMVA
#   output can be suppressed by enabling the "Silent" option
#
# Pre-transformations of the input variables are disabled and no correlations
# between the input variables are computed.

factoryOptions = {
    "V": False,
    "ROC": True,
    "Silent": False,
    "Color": True,
    "AnalysisType": "Classification",
    "Transformations": None,
    "Correlations": False,
}
factory = TMVA.Factory("TMVA_CNN_Classification", outputFile, **factoryOptions)


## Declare DataLoader(s)

# The DataLoader deals with the input variables used for the MVA training.
# Variable expressions that can be parsed by TTree::Drat-style "expression"
# strings may be used as well; here the input data is an image of 16x16 pixels.

loader = TMVA.DataLoader("dataset")
218
219
220## Setup Dataset(s)
221
222# Define input data file and signal and background trees
223
224
imgSize = 16 * 16
inputFileName = "images_data_16x16.root"

# if the input file does not exist create it
if ROOT.gSystem.AccessPathName(inputFileName):
    MakeImagesTree(nevt, 16, 16)

inputFile = TFile.Open(inputFileName)
# A failed TFile.Open yields a null (falsy) object rather than None, and the
# trees are read from the file right below, so stop here with a clear error.
if not inputFile or inputFile.IsZombie():
    raise OSError("Error opening input file {} - exit".format(inputFileName))


# inputFileName = "tmva_class_example.root"


# --- Register the training and test trees

signalTree = inputFile.Get("sig_tree")
backgroundTree = inputFile.Get("bkg_tree")

nEventsSig = signalTree.GetEntries()
nEventsBkg = backgroundTree.GetEntries()

# global event weights per tree (see below for setting event-wise weights)
signalWeight = 1.0
backgroundWeight = 1.0

# You can add an arbitrary number of signal or background trees
loader.AddSignalTree(signalTree, signalWeight)
loader.AddBackgroundTree(backgroundTree, backgroundWeight)

## add event variables (image)
## use new method (from ROOT 6.20 to add a variable array for all image data)
loader.AddVariablesArray("vars", imgSize)

# Set individual event weights (the variables must exist in the original TTree)
# for signal : factory->SetSignalWeightExpression ("weight1*weight2");
# for background: factory->SetBackgroundWeightExpression("weight1*weight2");
# loader->SetBackgroundWeightExpression( "weight" );

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
mycutb = ""  # for example: TCut mycutb = "abs(var1)<0.5";

# Tell the factory how to use the training and testing events
# If no numbers of events are given, half of the events in the tree are used
# for training, and the other half for testing:
# loader.PrepareTrainingAndTestTree( mycut, "SplitMode=random:!V" );
# It is possible also to specify the number of training and testing events,
# note we disable the computation of the correlation matrix of the input variables

# use 80% of the available events of each class for training
nTrainSig = 0.8 * nEventsSig
nTrainBkg = 0.8 * nEventsBkg

# build the string options for DataLoader::PrepareTrainingAndTestTree

loader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)
292
293
294# DataSetInfo : [dataset] : Added class "Signal"
295# : Add Tree sig_tree of type Signal with 10000 events
296# DataSetInfo : [dataset] : Added class "Background"
297# : Add Tree bkg_tree of type Background with 10000 events
298
299# signalTree.Print();
300
# Booking Methods

# The TMVA methods are booked from here on, starting with a
# Boosted Decision Tree (BDT).


# Boosted Decision Trees
if useTMVABDT:
    bdtOptions = {
        "V": False,
        "NTrees": 400,
        "MinNodeSize": "2.5%",
        "MaxDepth": 2,
        "BoostType": "AdaBoost",
        "AdaBoostBeta": 0.5,
        "UseBaggedBoost": True,
        "BaggedSampleFraction": 0.5,
        "SeparationType": "GiniIndex",
        "nCuts": 20,
    }
    factory.BookMethod(loader, TMVA.Types.kBDT, "BDT", **bdtOptions)
323
324
#### Booking Deep Neural Network

# Book the DNN of TMVA. The options are described in detail in the example
# TMVA_Higgs_Classification.C

if useTMVADNN:
    # use GPU if available, otherwise fall back to the CPU implementation
    dnnOptions = "GPU" if hasGPU else "CPU"
    dnnMethodName = "TMVA_DNN_GPU" if hasGPU else "TMVA_DNN_CPU"

    layoutString = ROOT.TString(
        "DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,DENSE|1|LINEAR"
    )

    # Training strategies
    # Several training strings with different parameters (e.g. learning rates
    # or regularization parameters) can be chained using the `|` delimiter,
    # i.e. trainingString1 + "|" + trainingString2 + ...
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0."
    )
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    # Full set of DNN options handed to the factory
    dnnSettings = {
        "H": False,
        "V": True,
        "ErrorStrategy": "CROSSENTROPY",
        "VarTransform": None,
        "WeightInitialization": "XAVIER",
        "Layout": layoutString,
        "TrainingStrategy": trainingString1,
        "Architecture": dnnOptions,
    }
    factory.BookMethod(loader, TMVA.Types.kDL, dnnMethodName, **dnnSettings)
368
369
### Book Convolutional Neural Network in TMVA

# Building a CNN requires two layout definitions:
#
# - Input Layout : number of channels (in this case = 1) | image height | image width
# - Batch Layout : batch size | number of channels | image size = (height*width)
#
# followed by the Convolutional and MaxPool layers.
#
# - Option string of a Convolutional layer:
#   CONV | number of units | filter height | filter width | stride height |
#   stride width | padding height | padding width | activation function
#
#   Here a 3x3 filter with padding=1 and stride=1 is used, so the output
#   dimension of the conv layer is equal to the input.
#
#   A batch normalization layer follows the first convolutional layer. This
#   seems to help the convergence significantly.
#
# - Option string of a MaxPool layer:
#   MAXPOOL | pool height | pool width | stride height | stride width
#
# The RESHAPE layer is needed to flatten the output before the Dense layer.
#
# Running the CNN requires CPU or GPU support.


if useTMVACNN:
    ## New DL (CNN): use GPU if available
    cnnOptions = "GPU" if hasGPU else "CPU"
    cnnMethodName = "TMVA_CNN_GPU" if hasGPU else "TMVA_CNN_CPU"

    # Training strategies.
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0"
    )
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    cnnSettings = {
        "H": False,
        "V": True,
        "ErrorStrategy": "CROSSENTROPY",
        "VarTransform": None,
        "WeightInitialization": "XAVIER",
        "InputLayout": "1|16|16",
        "Layout": "CONV|10|3|3|1|1|1|1|RELU,BNORM,CONV|10|3|3|1|1|1|1|RELU,MAXPOOL|2|2|1|1,RESHAPE|FLAT,DENSE|100|RELU,DENSE|1|LINEAR",
        "TrainingStrategy": trainingString1,
        "Architecture": cnnOptions,
    }
    factory.BookMethod(loader, TMVA.Types.kDL, cnnMethodName, **cnnSettings)
429
430
### Book Convolutional Neural Network in PyTorch using a model-definition script


if usePyTorchCNN:
    ROOT.Info("TMVA_CNN_Classification", "Using Convolutional PyTorch Model")
    # script shipped with the ROOT tutorials that defines the model
    pyTorchFileName = str(ROOT.gROOT.GetTutorialDir()) + "/tmva/PyTorch_Generate_CNN_Model.py"
    # book only when pytorch can be imported and the model script exists
    torch_spec = importlib.util.find_spec("torch")
    if torch_spec is None or not os.path.exists(pyTorchFileName):
        ROOT.Warning(
            "TMVA_CNN_Classification",
            "PyTorch is not installed or model building file is not existing - skip using PyTorch",
        )
    else:
        ROOT.Info("TMVA_CNN_Classification", "Booking PyTorch CNN model")
        pyTorchSettings = {
            "H": True,
            "V": False,
            "VarTransform": None,
            "FilenameModel": "PyTorchModelCNN.pt",
            "FilenameTrainedModel": "PyTorchTrainedModelCNN.pt",
            "NumEpochs": max_epochs,
            "BatchSize": 100,
            "UserCode": str(pyTorchFileName),
        }
        factory.BookMethod(loader, TMVA.Types.kPyTorch, "PyTorch", **pyTorchSettings)

463
if useKerasCNN:
    ROOT.Info("TMVA_CNN_Classification", "Building convolutional keras model")
    import tensorflow
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Conv2D, MaxPooling2D, Reshape

    # Network: reshape of the flat input to a 16x16x1 image, 2 conv2d layers,
    # one maxpool layer (stride equal to the pool size) and 2 dense layers.
    kerasLayers = [
        Reshape((16, 16, 1), input_shape=(256,)),
        Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"),
        Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(64, activation="tanh"),
        Dense(2, activation="sigmoid"),
    ]
    model = Sequential()
    for kerasLayer in kerasLayers:
        model.add(kerasLayer)
    model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
    model.save("model_cnn.h5")
    model.summary()

    # the PyKeras method is booked only if the Keras model file could be created
    if not os.path.exists("model_cnn.h5"):
        raise FileNotFoundError("Error creating Keras model file - skip using Keras")

    ROOT.Info("TMVA_CNN_Classification", "Booking convolutional keras model")
    kerasSettings = {
        "H": True,
        "V": False,
        "VarTransform": None,
        "FilenameModel": "model_cnn.h5",
        "FilenameTrainedModel": "trained_model_cnn.h5",
        "NumEpochs": max_epochs,
        "BatchSize": 100,
        # needed for RTX NVidia card and to avoid TF allocates all GPU memory
        "GpuOptions": "allow_growth=True",
    }
    factory.BookMethod(loader, TMVA.Types.kPyKeras, "PyKeras", **kerasSettings)
509
510
511
## Train Methods

factory.TrainAllMethods()

## Test and Evaluate Methods

factory.TestAllMethods()

factory.EvaluateAllMethods()

## Plot ROC Curve

c1 = factory.GetROCCurve(loader)
c1.Draw()

# Close the output file so that its content is saved. There is no file to
# close when writeOutputFile is False (outputFile is None) or opening failed.
if outputFile:
    outputFile.Close()
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:51
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4053
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:71
void EnableImplicitMT(UInt_t numthreads=0)
Enable ROOT's implicit multi-threading for all objects and methods that provide an internal paralleli...
Definition TROOT.cxx:527
UInt_t GetThreadPoolSize()
Returns the size of ROOT's thread pool.
Definition TROOT.cxx:565