ROOT   Reference Guide
Searching...
No Matches
TMVA_CNN_Classification.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook
4/// TMVA Classification Example Using a Convolutional Neural Network
5///
6/// This is an example of using a CNN in TMVA. We do classification using a toy image data set
7/// that is generated when running the example macro
8///
9/// \macro_image
10/// \macro_output
11/// \macro_code
12///
13/// \author Lorenzo Moneta
14
15/***
16
17 # TMVA Classification Example Using a Convolutional Neural Network
18
19
20**/
21
22/// Helper function to create input images data
23/// we create a signal and background 2D histograms from 2d gaussians
24/// with a location (means in X and Y) different for each event
25/// The difference between signal and background is in the gaussian width.
26/// The width for the background gaussian is slightly larger than the signal width by few % values
27///
28///
29void MakeImagesTree(int n, int nh, int nw)
30{
31
32 // image size (nh x nw)
33 const int ntot = nh * nw;
34 const TString fileOutName = TString::Format("images_data_%dx%d.root", nh, nw);
35 TFile f(fileOutName, "RECREATE");
36
37 const int nRndmEvts = 10000; // number of events we use to fill each image
38 double delta_sigma = 0.1; // 5% difference in the sigma
39 double pixelNoise = 5;
40
41 double sX1 = 3;
42 double sY1 = 3;
43 double sX2 = sX1 + delta_sigma;
44 double sY2 = sY1 - delta_sigma;
45
46 TH2D h1("h1", "h1", nh, 0, 10, nw, 0, 10);
47 TH2D h2("h2", "h2", nh, 0, 10, nw, 0, 10);
48
49 TF2 f1("f1", "xygaus");
50 TF2 f2("f2", "xygaus");
51
52 TTree sgn("sig_tree", "signal_tree");
53 TTree bkg("bkg_tree", "background_tree");
54
55
56 std::vector<float> x1(ntot);
57 std::vector<float> x2(ntot);
58
59 // create signal and background trees with a single branch
60 // an std::vector<float> of size nh x nw containing the image data
61
62 std::vector<float> *px1 = &x1;
63 std::vector<float> *px2 = &x2;
64
65 bkg.Branch("vars", "std::vector<float>", &px1);
66 sgn.Branch("vars", "std::vector<float>", &px2);
67
68 // std::cout << "create tree " << std::endl;
69
70 sgn.SetDirectory(&f);
71 bkg.SetDirectory(&f);
72
73 f1.SetParameters(1, 5, sX1, 5, sY1);
74 f2.SetParameters(1, 5, sX2, 5, sY2);
75 gRandom->SetSeed(0);
76 std::cout << "Filling ROOT tree " << std::endl;
77 for (int i = 0; i < n; ++i) {
78 if (i % 1000 == 0)
79 std::cout << "Generating image event ... " << i << std::endl;
80 h1.Reset();
81 h2.Reset();
82 // generate random means in range [3,7] to be not too much on the border
83 f1.SetParameter(1, gRandom->Uniform(3, 7));
84 f1.SetParameter(3, gRandom->Uniform(3, 7));
85 f2.SetParameter(1, gRandom->Uniform(3, 7));
86 f2.SetParameter(3, gRandom->Uniform(3, 7));
87
88 h1.FillRandom("f1", nRndmEvts);
89 h2.FillRandom("f2", nRndmEvts);
90
91 for (int k = 0; k < nh; ++k) {
92 for (int l = 0; l < nw; ++l) {
93 int m = k * nw + l;
94 // add some noise in each bin
95 x1[m] = h1.GetBinContent(k + 1, l + 1) + gRandom->Gaus(0, pixelNoise);
96 x2[m] = h2.GetBinContent(k + 1, l + 1) + gRandom->Gaus(0, pixelNoise);
97 }
98 }
99 sgn.Fill();
100 bkg.Fill();
101 }
102 sgn.Write();
103 bkg.Write();
104
105 Info("MakeImagesTree", "Signal and background tree with images data written to the file %s", f.GetName());
106 sgn.Print();
107 bkg.Print();
108 f.Close();
109}
110
111/// @brief Run the TMVA CNN Classification example
112/// @param nevts : number of signal/background events. Use by default a low value (1000)
113/// but increase to at least 5000 to get a good result
/// Note: nevts is only used when the input file images_data_16x16.root has to be
/// generated; if that file already exists it is reused as it is.
114/// @param opt : vector of bool with method used (default all on if available). The order is:
115/// - TMVA CNN
116/// - Keras CNN
117/// - TMVA DNN
118/// - TMVA BDT
119/// - PyTorch CNN
/// Entries missing from a shorter vector are treated as false (method switched off).
120void TMVA_CNN_Classification(int nevts = 1000, std::vector<bool> opt = {1, 1, 1, 1, 1})
121{
122
123 int imgSize = 16 * 16; // number of pixels of one 16x16 image
124 TString inputFileName = "images_data_16x16.root";
125
 // AccessPathName returns false when the path can be accessed, hence the negation
126 bool fileExist = !gSystem->AccessPathName(inputFileName);
127
128 // if the file does not exist, create it
129 if (!fileExist) {
130 MakeImagesTree(nevts, 16, 16);
131 }
132
 // decode the method switches; an index beyond the size of opt counts as "off"
133 bool useTMVACNN = (opt.size() > 0) ? opt[0] : false;
134 bool useKerasCNN = (opt.size() > 1) ? opt[1] : false;
135 bool useTMVADNN = (opt.size() > 2) ? opt[2] : false;
136 bool useTMVABDT = (opt.size() > 3) ? opt[3] : false;
137 bool usePyTorchCNN = (opt.size() > 4) ? opt[4] : false;
 // the TMVA CNN is switched off when neither the CPU nor the GPU build flag is defined
138#ifndef R__HAS_TMVACPU
139#ifndef R__HAS_TMVAGPU
140 Warning("TMVA_CNN_Classification",
141 "TMVA is not build with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for CNN");
142 useTMVACNN = false;
143#endif
144#endif
145
 // always true here, so the output file below is always opened
146 bool writeOutputFile = true;
147
148#ifdef R__USE_IMT
149 int num_threads = 4; // use by default 4 threads if value is not set before
150 // switch off MT in OpenBLAS to avoid conflict with tbb
 // NOTE(review): listing line 151 is missing from this copy (presumably the statement
 // the comment above refers to) - restore it from the original file
152
153 // do enable MT running
154 if (num_threads >= 0) {
 // NOTE(review): listing line 155 (the body of this if) is missing from this copy
156 }
157#endif
158
 // NOTE(review): listing line 159 is missing from this copy
160
161
162 std::cout << "Running with nthreads = " << ROOT::GetThreadPoolSize() << std::endl;
163
 // the Python-based methods (Keras, PyTorch) are only available when PyMVA is built
164#ifdef R__HAS_PYMVA
165 gSystem->Setenv("KERAS_BACKEND", "tensorflow");
166 // for using Keras
 // NOTE(review): listing line 167 is missing from this copy
168#else
169 useKerasCNN = false;
170 usePyTorchCNN = false;
171#endif
172
 // NOTE(review): outputFile is a raw pointer that is not checked for nullptr after
 // TFile::Open and is not closed on the early-return paths further below
173 TFile *outputFile = nullptr;
174 if (writeOutputFile)
175 outputFile = TFile::Open("TMVA_CNN_ClassificationOutput.root", "RECREATE");
176
177 /***
178 ## Create TMVA Factory
179
180 Create the Factory class. Later you can choose the methods
181 whose performance you'd like to investigate.
182
183 The factory is the major TMVA object you have to interact with. Here is the list of parameters you need to pass
184
185 - The first argument is the base of the name of all the output
186 weight files in the directory weight/ that will be created with the
187 method parameters
188
189 - The second argument is the output file for the training results
190
191 - The third argument is a string option defining some general configuration for the TMVA session.
192 For example all TMVA output can be suppressed by removing the "!" (not) in front of the "Silent" argument in the
193 option string
194
195 - note that we disable any pre-transformation of the input variables and we avoid computing correlations between
196 input variables
197 ***/
198
199 TMVA::Factory factory(
200 "TMVA_CNN_Classification", outputFile,
201 "!V:ROC:!Silent:Color:AnalysisType=Classification:Transformations=None:!Correlations");
202
203 /***
204
 NOTE(review): listing line 205 (the section heading) is missing from this copy
206
207 The next step is to declare the DataLoader class that deals with input variables
208
209 Define the input variables that shall be used for the MVA training.
210 Note that you may also use variable expressions, which can be parsed by TTree::Draw( "expression" )
211
212 In this case the input data consists of an image of 16x16 pixels. All the pixels of an image are stored together
 in a single std::vector<float> branch (named "vars") of a ROOT TTree
213
214 **/
215
 // NOTE(review): listing line 216 is missing from this copy (presumably the declaration
 // of the DataLoader object referred to as "loader" in the comments below)
217
218 /***
219
220 ## Setup Dataset(s)
221
222 Define input data file and signal and background trees
223
224 **/
225
226 std::unique_ptr<TFile> inputFile{TFile::Open(inputFileName)};
227 if (!inputFile) {
228 Error("TMVA_CNN_Classification", "Error opening input file %s - exit", inputFileName.Data());
229 return;
230 }
231
232 // --- Register the training and test trees
233
234 auto signalTree = inputFile->Get<TTree>("sig_tree");
235 auto backgroundTree = inputFile->Get<TTree>("bkg_tree");
236
237 if (!signalTree) {
238 Error("TMVA_CNN_Classification", "Could not find signal tree in file '%s'", inputFileName.Data());
239 return;
240 }
241 if (!backgroundTree) {
242 Error("TMVA_CNN_Classification", "Could not find background tree in file '%s'", inputFileName.Data());
243 return;
244 }
245
246 int nEventsSig = signalTree->GetEntries();
247 int nEventsBkg = backgroundTree->GetEntries();
248
249 // global event weights per tree (see below for setting event-wise weights)
250 Double_t signalWeight = 1.0;
251 Double_t backgroundWeight = 1.0;
252
253 // You can add an arbitrary number of signal or background trees
 // NOTE(review): listing lines 254-255 are missing from this copy (presumably the calls
 // that register the two trees with their weights)
256
257 /// add event variables (image)
258 /// use new method (from ROOT 6.20 to add a variable array for all image data)
 // NOTE(review): listing line 259 is missing from this copy; imgSize is not used in any
 // visible line, so it is presumably consumed there
260
261 // Set individual event weights (the variables must exist in the original TTree)
262 // for signal : factory->SetSignalWeightExpression ("weight1*weight2");
263 // for background: factory->SetBackgroundWeightExpression("weight1*weight2");
 // NOTE(review): listing line 264 is missing from this copy
265
266 // Apply additional cuts on the signal and background samples (can be different)
267 TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
268 TCut mycutb = ""; // for example: TCut mycutb = "abs(var1)<0.5";
269
270 // Tell the factory how to use the training and testing events
271 //
272 // If no numbers of events are given, half of the events in the tree are used
273 // for training, and the other half for testing:
274 // loader.PrepareTrainingAndTestTree( mycut, "SplitMode=random:!V" );
275 // It is possible also to specify the number of training and testing events,
276 // note we disable the computation of the correlation matrix of the input variables
277
 // use 80% of the events of each tree for training (the conversion to int truncates)
278 int nTrainSig = 0.8 * nEventsSig;
279 int nTrainBkg = 0.8 * nEventsBkg;
280
281 // build the string options for DataLoader::PrepareTrainingAndTestTree
282 TString prepareOptions = TString::Format(
283 "nTrain_Signal=%d:nTrain_Background=%d:SplitMode=Random:SplitSeed=100:NormMode=NumEvents:!V:!CalcCorrelations",
284 nTrainSig, nTrainBkg);
285
 // NOTE(review): listing line 286 is missing from this copy (presumably the call that
 // consumes mycuts, mycutb and prepareOptions)
287
288 /***
289
290 DataSetInfo : [dataset] : Added class "Signal"
291 : Add Tree sig_tree of type Signal with 10000 events
292 DataSetInfo : [dataset] : Added class "Background"
293 : Add Tree bkg_tree of type Background with 10000 events
294
295
296
297 **/
298
299 /****
300 # Booking Methods
301
302 Here we book the TMVA methods. We book a Boosted Decision Tree method (BDT)
303
304 **/
305
306 // Boosted Decision Trees
307 if (useTMVABDT) {
 // NOTE(review): listing lines 308-309 (the start of the booking call whose option
 // string ends on the next line) are missing from this copy
310 "UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20");
311 }
312 /**
313
314 #### Booking Deep Neural Network
315
316 Here we book the DNN of TMVA. See the example TMVA_Higgs_Classification.C for a detailed description of the
317 options
318
319 **/
320
 // NOTE(review): listing line 321 is missing from this copy (presumably the opening of
 // the block closed on listing line 356)
322
 // network layout: four dense layers of 100 units with RELU activation, the first three
 // followed by batch normalization, and a single linear output unit
323 TString layoutString(
324 "Layout=DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,DENSE|1|LINEAR");
325
326 // Training strategies
327 // one can concatenate several training strings with different parameters (e.g. learning rates or regularization
328 // parameters). The training strings must be concatenated with the | delimiter
329 TString trainingString1("LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
330 "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
331 "MaxEpochs=10,WeightDecay=1e-4,Regularization=None,"
 // NOTE(review): listing line 332 (the end of this string literal and of the statement)
 // is missing from this copy
333
334 TString trainingStrategyString("TrainingStrategy=");
335 trainingStrategyString += trainingString1; // + "|" + trainingString2 + ....
336
337 // Build now the full DNN Option string
338
339 TString dnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
340 "WeightInitialization=XAVIER");
341 dnnOptions.Append(":");
342 dnnOptions.Append(layoutString);
343 dnnOptions.Append(":");
344 dnnOptions.Append(trainingStrategyString);
345
346 TString dnnMethodName = "TMVA_DNN_CPU";
347// use GPU if available
348#ifdef R__HAS_TMVAGPU
349 dnnOptions += ":Architecture=GPU";
350 dnnMethodName = "TMVA_DNN_GPU";
351#elif defined(R__HAS_TMVACPU)
352 dnnOptions += ":Architecture=CPU";
353#endif
354
 // NOTE(review): listing line 355 is missing from this copy (presumably the booking call
 // that uses dnnMethodName and dnnOptions)
356 }
357
358 /***
359 ### Book Convolutional Neural Network in TMVA
360
361 For building a CNN one needs to define
362
363 - Input Layout : number of channels (in this case = 1) | image height | image width
364 - Batch Layout : batch size | number of channels | image size = (height*width)
365
366 Then one adds Convolutional layers and MaxPool layers.
367
368 - For Convolutional layer the option string has to be:
369 - CONV | number of units | filter height | filter width | stride height | stride width | padding height | padding
370 width | activation function
371
372 - note in this case we are using a 3x3 filter with padding=1 and stride=1, so the output dimension of the
373 conv layer is equal to the input
374
375 - note that after the first convolutional layer we use a batch normalization layer. This seems to significantly
376 help the convergence
377
378 - For the MaxPool layer:
379 - MAXPOOL | pool height | pool width | stride height | stride width
380
381 The RESHAPE layer is needed to flatten the output before the Dense layer
382
383
384 Note that running the CNN requires CPU or GPU support
385
386 ***/
387
388 if (useTMVACNN) {
389
 // Input layout: 1 channel, 16 x 16 pixels
390 TString inputLayoutString("InputLayout=1|16|16");
391
392 // Layer layout: two convolutional layers (the first followed by batch normalization), one max-pooling layer,
 // then the flattening and two dense layers
393 TString layoutString("Layout=CONV|10|3|3|1|1|1|1|RELU,BNORM,CONV|10|3|3|1|1|1|1|RELU,MAXPOOL|2|2|1|1,"
394 "RESHAPE|FLAT,DENSE|100|RELU,DENSE|1|LINEAR");
395
396 // Training strategies.
397 TString trainingString1("LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
398 "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
399 "MaxEpochs=10,WeightDecay=1e-4,Regularization=None,"
 // NOTE(review): listing line 400 (the end of this string literal and of the statement)
 // is missing from this copy
401
402 TString trainingStrategyString("TrainingStrategy=");
403 trainingStrategyString +=
404 trainingString1; // + "|" + trainingString2 + "|" + trainingString3; for concatenating more training strings
405
406 // Build full CNN Options.
407 TString cnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
408 "WeightInitialization=XAVIER");
409
410 cnnOptions.Append(":");
411 cnnOptions.Append(inputLayoutString);
412 cnnOptions.Append(":");
413 cnnOptions.Append(layoutString);
414 cnnOptions.Append(":");
415 cnnOptions.Append(trainingStrategyString);
416
417 //// New DL (CNN)
418 TString cnnMethodName = "TMVA_CNN_CPU";
419// use GPU if available
420#ifdef R__HAS_TMVAGPU
421 cnnOptions += ":Architecture=GPU";
422 cnnMethodName = "TMVA_CNN_GPU";
423#else
424 cnnOptions += ":Architecture=CPU";
425 cnnMethodName = "TMVA_CNN_CPU"; // same value as the initialization above
426#endif
427
 // NOTE(review): listing line 428 is missing from this copy (presumably the booking call
 // that uses cnnMethodName and cnnOptions)
429 }
430
431 /**
432 ### Book Convolutional Neural Network in Keras using a generated model
433
434 **/
435
436#ifdef R__HAS_PYMVA
437 // The next section uses Python packages, execute it only if PyMVA is available
 // use the Python executable reported by TMVA, falling back to plain "python" if it is empty
438 TString tmva_python_exe{TMVA::Python_Executable()};
439 TString python_exe = tmva_python_exe.IsNull() ? "python" : tmva_python_exe;
440
441 if (useKerasCNN) {
442
443 Info("TMVA_CNN_Classification", "Building convolutional keras model");
444 // create python script which can be executed
445 // create 2 conv2d layer + maxpool + dense
446 TMacro m;
 // NOTE(review): most of the lines that fill the macro (listing lines 447-450, 452-464
 // and 466-467) are missing from this copy; only fragments of the script are visible
451 "from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Conv2D, MaxPooling2D, Reshape, BatchNormalization");
465 m.AddLine("model.compile(loss = 'binary_crossentropy', optimizer = Adam(learning_rate = 0.001), weighted_metrics = ['accuracy'])");
468
 // write the script to disk and run it with the selected Python executable
469 m.SaveSource("make_cnn_model.py");
470 // execute
471 gSystem->Exec(python_exe + " make_cnn_model.py");
472
 // AccessPathName returns true when the file is NOT accessible, i.e. the model was not created
473 if (gSystem->AccessPathName("model_cnn.h5")) {
474 Warning("TMVA_CNN_Classification", "Error creating Keras model file - skip using Keras");
475 } else {
476 // book PyKeras method only if Keras model could be created
477 Info("TMVA_CNN_Classification", "Booking tf.Keras CNN model");
478 factory.BookMethod(
 // NOTE(review): listing line 479 (the first arguments of this call) is missing from this copy
480 "H:!V:VarTransform=None:FilenameModel=model_cnn.h5:tf.keras:"
481 "FilenameTrainedModel=trained_model_cnn.h5:NumEpochs=10:BatchSize=100:"
482 "GpuOptions=allow_growth=True"); // needed for RTX NVidia cards and to prevent TF from allocating all GPU memory
483 }
484 }
485
486 if (usePyTorchCNN) {
487
488 Info("TMVA_CNN_Classification", "Using Convolutional PyTorch Model");
489 TString pyTorchFileName = gROOT->GetTutorialDir() + TString("/tmva/PyTorch_Generate_CNN_Model.py");
490 // check that pytorch can be imported and that the file defining the model (used later when booking the
491 // method) exists; both calls evaluate to true on failure
492 if (gSystem->Exec(python_exe + " -c 'import torch'") || gSystem->AccessPathName(pyTorchFileName)) {
493 Warning("TMVA_CNN_Classification", "PyTorch is not installed or model building file is not existing - skip using PyTorch");
494 } else {
495 // book PyTorch method only if PyTorch model could be created
496 Info("TMVA_CNN_Classification", "Booking PyTorch CNN model");
497 TString methodOpt = "H:!V:VarTransform=None:FilenameModel=PyTorchModelCNN.pt:"
498 "FilenameTrainedModel=PyTorchTrainedModelCNN.pt:NumEpochs=10:BatchSize=100";
499 methodOpt += TString(":UserCode=") + pyTorchFileName;
 // NOTE(review): listing line 500 is missing from this copy (presumably the booking call
 // that uses methodOpt)
501 }
502 }
503#endif
504
505 //// ## Train Methods
506
507 factory.TrainAllMethods();
508
509 /// ## Test and Evaluate Methods
510
511 factory.TestAllMethods();
512
513 factory.EvaluateAllMethods();
514
515 /// ## Plot ROC Curve
516
 // NOTE(review): listing line 517, which must declare c1, is missing from this copy
518 c1->Draw();
519
520 // close outputfile to save output file
 // NOTE(review): outputFile is dereferenced without a null check - confirm that
 // TFile::Open cannot have failed at this point
521 outputFile->Close();
522}
#define f(i)
Definition RSha256.hxx:104
double Double_t
Definition RtypesCore.h:59
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition TError.cxx:218
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:185
void Warning(const char *location, const char *msgfmt,...)
Use this function in warning situations.
Definition TError.cxx:229
Option_t Option_t TPoint TPoint const char x2
Option_t Option_t TPoint TPoint const char x1
#define gROOT
Definition TROOT.h:406
R__EXTERN TRandom * gRandom
Definition TRandom.h:62
R__EXTERN TSystem * gSystem
Definition TSystem.h:560
A specialized string object used for TTree selections.
Definition TCut.h:25
virtual void SetParameters(const Double_t *params)
Definition TF1.h:670
virtual void SetParameter(Int_t param, Double_t value)
Definition TF1.h:660
A 2-Dim function with parameters.
Definition TF2.h:29
A ROOT file is composed of a header, followed by consecutive data records (TKey instances) with a wel...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4070
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:931
void Reset(Option_t *option="") override
Reset.
Definition TH1.cxx:10271
virtual void FillRandom(const char *fname, Int_t ntimes=5000, TRandom *rng=nullptr)
Fill histogram following distribution in function fname.
Definition TH1.cxx:3515
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition TH1.cxx:5025
2-D histogram with a double per channel (see TH1 documentation)
Definition TH2.h:338
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:71
Class supporting a collection of lines with C++ code.
Definition TMacro.h:31
virtual Double_t Gaus(Double_t mean=0, Double_t sigma=1)
Samples a random number from the standard Normal (Gaussian) Distribution with the given mean and sigm...
Definition TRandom.cxx:275
virtual void SetSeed(ULong_t seed=0)
Set the random generator seed.
Definition TRandom.cxx:615
virtual Double_t Uniform(Double_t x1=1)
Returns a uniform deviate on the interval (0, x1).
Definition TRandom.cxx:682
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:378
Bool_t IsNull() const
Definition TString.h:416
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition TSystem.cxx:641
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1284
virtual void Setenv(const char *name, const char *value)
Set environment variable.
Definition TSystem.cxx:1637
A TTree represents a columnar dataset.
Definition TTree.h:79
virtual Long64_t GetEntries() const
Definition TTree.h:463
return c1
Definition legend1.C:41
const Int_t n
Definition legend1.C:16
TH1F * h1
Definition legend1.C:5
TF1 * f1
Definition legend1.C:11