Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_RNN_Classification.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_tmva
3/// \notebook
4/// TMVA Classification Example Using a Recurrent Neural Network
5///
6/// This is an example of using a RNN in TMVA. We do classification using a toy time dependent data set
7/// that is generated when running this example macro
8///
9/// \macro_image
10/// \macro_output
11/// \macro_code
12///
13/// \author Lorenzo Moneta
14/***
15
16 # TMVA Classification Example Using a Recurrent Neural Network
17
18 This is an example of using a RNN in TMVA.
19 We do the classification using a toy data set consisting of time series with `ntime` time steps,
20 each of dimension `ndim`, which is generated by running the provided function `MakeTimeData(nevents, ntime, ndim)`
21
22
23**/
24
25#include<TROOT.h>
26
27#include "TMVA/Factory.h"
28#include "TMVA/DataLoader.h"
29#include "TMVA/DataSetInfo.h"
30#include "TMVA/Config.h"
31#include "TMVA/MethodDL.h"
32
33
34#include "TFile.h"
35#include "TTree.h"
36
37/// Helper function to generate the time data set
38/// make some time data but not of fixed length.
39/// use a poisson with mu = 5 and troncated at 10
40///
41void MakeTimeData(int n, int ntime, int ndim )
42{
43
44 // const int ntime = 10;
45 // const int ndim = 30; // number of dim/time
46 TString fname = TString::Format("time_data_t%d_d%d.root", ntime, ndim);
47 std::vector<TH1 *> v1(ntime);
48 std::vector<TH1 *> v2(ntime);
49 int i = 0;
50 for (int i = 0; i < ntime; ++i) {
51 v1[i] = new TH1D(TString::Format("h1_%d", i), "h1", ndim, 0, 10);
52 v2[i] = new TH1D(TString::Format("h2_%d", i), "h2", ndim, 0, 10);
53 }
54
55 auto f1 = new TF1("f1", "gaus");
56 auto f2 = new TF1("f2", "gaus");
57
58 TTree sgn("sgn", "sgn");
59 TTree bkg("bkg", "bkg");
60 TFile f(fname, "RECREATE");
61
62 std::vector<std::vector<float>> x1(ntime);
63 std::vector<std::vector<float>> x2(ntime);
64
65 for (int i = 0; i < ntime; ++i) {
66 x1[i] = std::vector<float>(ndim);
67 x2[i] = std::vector<float>(ndim);
68 }
69
70 for (auto i = 0; i < ntime; i++) {
71 bkg.Branch(Form("vars_time%d", i), "std::vector<float>", &x1[i]);
72 sgn.Branch(Form("vars_time%d", i), "std::vector<float>", &x2[i]);
73 }
74
75 sgn.SetDirectory(&f);
76 bkg.SetDirectory(&f);
77 gRandom->SetSeed(0);
78
79 std::vector<double> mean1(ntime);
80 std::vector<double> mean2(ntime);
81 std::vector<double> sigma1(ntime);
82 std::vector<double> sigma2(ntime);
83 for (int j = 0; j < ntime; ++j) {
84 mean1[j] = 5. + 0.2 * sin(TMath::Pi() * j / double(ntime));
85 mean2[j] = 5. + 0.2 * cos(TMath::Pi() * j / double(ntime));
86 sigma1[j] = 4 + 0.3 * sin(TMath::Pi() * j / double(ntime));
87 sigma2[j] = 4 + 0.3 * cos(TMath::Pi() * j / double(ntime));
88 }
89 for (int i = 0; i < n; ++i) {
90
91 if (i % 1000 == 0)
92 std::cout << "Generating event ... " << i << std::endl;
93
94 for (int j = 0; j < ntime; ++j) {
95 auto h1 = v1[j];
96 auto h2 = v2[j];
97 h1->Reset();
98 h2->Reset();
99
100 f1->SetParameters(1, mean1[j], sigma1[j]);
101 f2->SetParameters(1, mean2[j], sigma2[j]);
102
103 h1->FillRandom("f1", 1000);
104 h2->FillRandom("f2", 1000);
105
106 for (int k = 0; k < ndim; ++k) {
107 // std::cout << j*10+k << " ";
108 x1[j][k] = h1->GetBinContent(k + 1) + gRandom->Gaus(0, 10);
109 x2[j][k] = h2->GetBinContent(k + 1) + gRandom->Gaus(0, 10);
110 }
111 }
112 // std::cout << std::endl;
113 sgn.Fill();
114 bkg.Fill();
115
116 if (n == 1) {
117 auto c1 = new TCanvas();
118 c1->Divide(ntime, 2);
119 for (int j = 0; j < ntime; ++j) {
120 c1->cd(j + 1);
121 v1[j]->Draw();
122 }
123 for (int j = 0; j < ntime; ++j) {
124 c1->cd(ntime + j + 1);
125 v2[j]->Draw();
126 }
127 gPad->Update();
128 }
129 }
130 if (n > 1) {
131 sgn.Write();
132 bkg.Write();
133 sgn.Print();
134 bkg.Print();
135 f.Close();
136 }
137}
138/// macro for performing a classification using a Recurrent Neural Network
139/// @param use_type
140/// use_type = 0 use Simple RNN network
141/// use_type = 1 use LSTM network
142/// use_type = 2 use GRU
143/// use_type = 3 build 3 different networks with RNN, LSTM and GRU
144
145void TMVA_RNN_Classification(int use_type = 1)
146{
147
148 const int ninput = 30;
149 const int ntime = 10;
150 const int batchSize = 100;
151 const int maxepochs = 20;
152
153 int nTotEvts = 10000; // total events to be generated for signal or background
154
155 bool useKeras = true;
156
157
158 bool useTMVA_RNN = true;
159 bool useTMVA_DNN = true;
160 bool useTMVA_BDT = false;
161
162 std::vector<std::string> rnn_types = {"RNN", "LSTM", "GRU"};
163 std::vector<bool> use_rnn_type = {1, 1, 1};
164 if (use_type >=0 && use_type < 3) {
165 use_rnn_type = {0,0,0};
166 use_rnn_type[use_type] = 1;
167 }
168 bool useGPU = true; // use GPU for TMVA if available
169
170#ifndef R__HAS_TMVAGPU
171 useGPU = false;
172#ifndef R__HAS_TMVACPU
173 Warning("TMVA_RNN_Classification", "TMVA is not build with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN");
174 useTMVA_RNN = false;
175#endif
176#endif
177
178
179 TString archString = (useGPU) ? "GPU" : "CPU";
180
181 bool writeOutputFile = true;
182
183
184
185 const char *rnn_type = "RNN";
186
187#ifdef R__HAS_PYMVA
189#else
190 useKeras = false;
191#endif
192
193 int num_threads = 0; // use by default all threads
194 // do enable MT running
195 if (num_threads >= 0) {
196 ROOT::EnableImplicitMT(num_threads);
197 if (num_threads > 0) gSystem->Setenv("OMP_NUM_THREADS", TString::Format("%d",num_threads));
198 }
199 else
200 gSystem->Setenv("OMP_NUM_THREADS", "1");
201
203
204 std::cout << "Running with nthreads = " << ROOT::GetThreadPoolSize() << std::endl;
205
206 TString inputFileName = "time_data_t10_d30.root";
207
208 bool fileExist = !gSystem->AccessPathName(inputFileName);
209
210 // if file does not exists create it
211 if (!fileExist) {
212 MakeTimeData(nTotEvts,ntime, ninput);
213 }
214
215
216 auto inputFile = TFile::Open(inputFileName);
217 if (!inputFile) {
218 Error("TMVA_RNN_Classification", "Error opening input file %s - exit", inputFileName.Data());
219 return;
220 }
221
222
223 std::cout << "--- RNNClassification : Using input file: " << inputFile->GetName() << std::endl;
224
225 // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
226 TString outfileName(TString::Format("data_RNN_%s.root", archString.Data()));
227 TFile *outputFile = nullptr;
228 if (writeOutputFile) outputFile = TFile::Open(outfileName, "RECREATE");
229
230 /**
231 ## Declare Factory
232
233 Create the Factory class. Later you can choose the methods
234 whose performance you'd like to investigate.
235
236 The factory is the major TMVA object you have to interact with. Here is the list of parameters you need to
237pass
238
239 - The first argument is the base of the name of all the output
240 weightfiles in the directory weight/ that will be created with the
241 method parameters
242
243 - The second argument is the output file for the training results
244
245 - The third argument is a string option defining some general configuration for the TMVA session.
246 For example all TMVA output can be suppressed by removing the "!" (not) in front of the "Silent" argument in
247the option string
248
249 **/
250
251 // Creating the factory object
252 TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
253 "!V:!Silent:Color:DrawProgressBar:Transformations=None:!Correlations:"
254 "AnalysisType=Classification:ModelPersistence");
255 TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
256
257 TTree *signalTree = (TTree *)inputFile->Get("sgn");
258 TTree *background = (TTree *)inputFile->Get("bkg");
259
260 const int nvar = ninput * ntime;
261
262 /// add variables - use new AddVariablesArray function
263 for (auto i = 0; i < ntime; i++) {
264 dataloader->AddVariablesArray(Form("vars_time%d", i), ninput);
265 }
266
267 dataloader->AddSignalTree(signalTree, 1.0);
268 dataloader->AddBackgroundTree(background, 1.0);
269
270 // check given input
271 auto &datainfo = dataloader->GetDataSetInfo();
272 auto vars = datainfo.GetListOfVariables();
273 std::cout << "number of variables is " << vars.size() << std::endl;
274 for (auto &v : vars)
275 std::cout << v << ",";
276 std::cout << std::endl;
277
278 int nTrainSig = 0.8 * nTotEvts;
279 int nTrainBkg = 0.8 * nTotEvts;
280
281 // build the string options for DataLoader::PrepareTrainingAndTestTree
282 TString prepareOptions = TString::Format("nTrain_Signal=%d:nTrain_Background=%d:SplitMode=Random:SplitSeed=100:NormMode=NumEvents:!V:!CalcCorrelations", nTrainSig, nTrainBkg);
283
284 // Apply additional cuts on the signal and background samples (can be different)
285 TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
286 TCut mycutb = "";
287
288 dataloader->PrepareTrainingAndTestTree(mycuts, mycutb, prepareOptions);
289
290 std::cout << "prepared DATA LOADER " << std::endl;
291
292 /**
293 ## Book TMVA recurrent models
294
295 Book the different types of recurrent models in TMVA (SimpleRNN, LSTM or GRU)
296
297 **/
298
299 if (useTMVA_RNN) {
300
301 for (int i = 0; i < 3; ++i) {
302
303 if (!use_rnn_type[i])
304 continue;
305
306 const char *rnn_type = rnn_types[i].c_str();
307
308 /// define the inputlayout string for RNN
309 /// the input data should be organize as following:
310 //// input layout for RNN: time x ndim
311
312 TString inputLayoutString = TString::Format("InputLayout=%d|%d", ntime, ninput);
313
314 /// Define RNN layer layout
315 /// it should be LayerType (RNN or LSTM or GRU) | number of units | number of inputs | time steps | remember output (typically no=0 | return full sequence
316 TString rnnLayout = TString::Format("%s|10|%d|%d|0|1", rnn_type, ninput, ntime);
317
318 /// add after RNN a reshape layer (needed top flatten the output) and a dense layer with 64 units and a last one
319 /// Note the last layer is linear because when using Crossentropy a Sigmoid is applied already
320 TString layoutString = TString("Layout=") + rnnLayout + TString(",RESHAPE|FLAT,DENSE|64|TANH,LINEAR");
321
322 /// Defining Training strategies. Different training strings can be concatenate. Use however only one
323 TString trainingString1 = TString::Format("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
324 "ConvergenceSteps=5,BatchSize=%d,TestRepetitions=1,"
325 "WeightDecay=1e-2,Regularization=None,MaxEpochs=%d,"
326 "Optimizer=ADAM,DropConfig=0.0+0.+0.+0.",
327 batchSize,maxepochs);
328
329 TString trainingStrategyString("TrainingStrategy=");
330 trainingStrategyString += trainingString1; // + "|" + trainingString2
331
332 /// Define the full RNN Noption string adding the final options for all network
333 TString rnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
334 "WeightInitialization=XAVIERUNIFORM:ValidationSize=0.2:RandomSeed=1234");
335
336 rnnOptions.Append(":");
337 rnnOptions.Append(inputLayoutString);
338 rnnOptions.Append(":");
339 rnnOptions.Append(layoutString);
340 rnnOptions.Append(":");
341 rnnOptions.Append(trainingStrategyString);
342 rnnOptions.Append(":");
343 rnnOptions.Append(TString::Format("Architecture=%s", archString.Data()));
344
345 TString rnnName = "TMVA_" + TString(rnn_type);
346 factory->BookMethod(dataloader, TMVA::Types::kDL, rnnName, rnnOptions);
347
348 }
349 }
350
351 /**
352 ## Book TMVA fully connected dense layer models
353
354 **/
355
356 if (useTMVA_DNN) {
357 // Method DL with Dense Layer
358 TString inputLayoutString = TString::Format("InputLayout=1|1|%d", ntime * ninput);
359
360 TString layoutString("Layout=DENSE|64|TANH,DENSE|TANH|64,DENSE|TANH|64,LINEAR");
361 // Training strategies.
362 TString trainingString1("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
363 "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
364 "WeightDecay=1e-4,Regularization=None,MaxEpochs=20"
365 "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM");
366 TString trainingStrategyString("TrainingStrategy=");
367 trainingStrategyString += trainingString1; // + "|" + trainingString2
368
369 // General Options.
370 TString dnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
371 "WeightInitialization=XAVIER:RandomSeed=0");
372
373 dnnOptions.Append(":");
374 dnnOptions.Append(inputLayoutString);
375 dnnOptions.Append(":");
376 dnnOptions.Append(layoutString);
377 dnnOptions.Append(":");
378 dnnOptions.Append(trainingStrategyString);
379 dnnOptions.Append(":");
380 dnnOptions.Append(archString);
381
382 TString dnnName = "TMVA_DNN";
383 factory->BookMethod(dataloader, TMVA::Types::kDL, dnnName, dnnOptions);
384 }
385
386 /**
387 ## Book Keras recurrent models
388
389 Book the different types of recurrent models in Keras (SimpleRNN, LSTM or GRU)
390
391 **/
392
393 if (useKeras) {
394
395 for (int i = 0; i < 3; i++) {
396
397 if (use_rnn_type[i]) {
398
399 TString modelName = TString::Format("model_%s.h5", rnn_types[i].c_str());
400 TString trainedModelName = TString::Format("trained_model_%s.h5", rnn_types[i].c_str());
401
402 Info("TMVA_RNN_Classification", "Building recurrent keras model using a %s layer", rnn_types[i].c_str());
403 // create python script which can be executed
404 // create 2 conv2d layer + maxpool + dense
405 TMacro m;
406 m.AddLine("import tensorflow");
407 m.AddLine("from tensorflow.keras.models import Sequential");
408 m.AddLine("from tensorflow.keras.optimizers import Adam");
409 m.AddLine("from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, SimpleRNN, GRU, LSTM, Reshape, "
410 "BatchNormalization");
411 m.AddLine("");
412 m.AddLine("model = Sequential() ");
413 m.AddLine("model.add(Reshape((10, 30), input_shape = (10*30, )))");
414 // add recurrent neural network depending on type / Use option to return the full output
415 if (rnn_types[i] == "LSTM")
416 m.AddLine("model.add(LSTM(units=10, return_sequences=True) )");
417 else if (rnn_types[i] == "GRU")
418 m.AddLine("model.add(GRU(units=10, return_sequences=True) )");
419 else
420 m.AddLine("model.add(SimpleRNN(units=10, return_sequences=True) )");
421
422 // m.AddLine("model.add(BatchNormalization())");
423 m.AddLine("model.add(Flatten())"); // needed if returning the full time output sequence
424 m.AddLine("model.add(Dense(64, activation = 'tanh')) ");
425 m.AddLine("model.add(Dense(2, activation = 'sigmoid')) ");
426 m.AddLine(
427 "model.compile(loss = 'binary_crossentropy', optimizer = Adam(lr = 0.001), metrics = ['accuracy'])");
428 m.AddLine(TString::Format("modelName = '%s'", modelName.Data()));
429 m.AddLine("model.save(modelName)");
430 m.AddLine("model.summary()");
431
432 m.SaveSource("make_rnn_model.py");
433 // execute
434 gSystem->Exec("python make_rnn_model.py");
435
436 if (gSystem->AccessPathName(modelName)) {
437 Warning("TMVA_RNN_Classification", "Error creating Keras recurrennt model file - Skip using Keras");
438 useKeras = false;
439 } else {
440 // book PyKeras method only if Keras model could be created
441 Info("TMVA_RNN_Classification", "Booking Keras %s model", rnn_types[i].c_str());
442 factory->BookMethod(dataloader, TMVA::Types::kPyKeras,
443 TString::Format("PyKeras_%s", rnn_types[i].c_str()),
444 TString::Format("!H:!V:VarTransform=None:FilenameModel=%s:tf.keras:"
445 "FilenameTrainedModel=%s:GpuOptions=allow_growth=True:"
446 "NumEpochs=%d:BatchSize=%d",
447 modelName.Data(), trainedModelName.Data(), maxepochs, batchSize));
448 }
449 }
450 }
451 }
452
453 // use BDT in case not using Keras or TMVA DL
454 if (!useKeras || !useTMVA_BDT)
455 useTMVA_BDT = true;
456
457 /**
458 ## Book TMVA BDT
459 **/
460
461 if (useTMVA_BDT) {
462
463 factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG",
464 "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:"
465 "BaggedSampleFraction=0.5:nCuts=20:"
466 "MaxDepth=2");
467
468 }
469
470 /// Train all methods
471 factory->TrainAllMethods();
472
473 std::cout << "nthreads = " << ROOT::GetThreadPoolSize() << std::endl;
474
475 // ---- Evaluate all MVAs using the set of test events
476 factory->TestAllMethods();
477
478 // ----- Evaluate and compare performance of all configured MVAs
479 factory->EvaluateAllMethods();
480
481 // check method
482
483 // plot ROC curve
484 auto c1 = factory->GetROCCurve(dataloader);
485 c1->Draw();
486
487 if (outputFile) outputFile->Close();
488}
#define f(i)
Definition RSha256.hxx:104
static const double x2[5]
static const double x1[5]
void Info(const char *location, const char *msgfmt,...)
Use this function for informational messages.
Definition TError.cxx:220
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:187
void Warning(const char *location, const char *msgfmt,...)
Use this function in warning situations.
Definition TError.cxx:231
double cos(double)
double sin(double)
R__EXTERN TRandom * gRandom
Definition TRandom.h:62
char * Form(const char *fmt,...)
R__EXTERN TSystem * gSystem
Definition TSystem.h:559
#define gPad
The Canvas class.
Definition TCanvas.h:23
A specialized string object used for TTree selections.
Definition TCut.h:25
1-Dim function class
Definition TF1.h:213
virtual void SetParameters(const Double_t *params)
Definition TF1.h:644
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition TFile.h:54
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3997
void Close(Option_t *option="") override
Close a file.
Definition TFile.cxx:879
virtual void Draw(Option_t *chopt="")
Draw this graph with its current attributes.
Definition TGraph.cxx:769
1-D histogram with a double per channel (see TH1 documentation)}
Definition TH1.h:618
virtual void Reset(Option_t *option="")
Reset.
Definition TH1.cxx:9947
virtual void FillRandom(const char *fname, Int_t ntimes=5000, TRandom *rng=nullptr)
Fill histogram following distribution in function fname.
Definition TH1.cxx:3525
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition TH1.cxx:4993
static Config & Instance()
static function: returns TMVA instance
Definition Config.cxx:106
void AddVariablesArray(const TString &expression, int size, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating array of variables in data set info in case input tree provides an array ...
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
DataSetInfo & GetDataSetInfo()
std::vector< TString > GetListOfVariables() const
returns list of variables
This is the main MVA steering class.
Definition Factory.h:80
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition Factory.cxx:1114
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition Factory.cxx:352
void TestAllMethods()
Evaluates all booked methods on the testing data and adds the output to the Results in the corresponi...
Definition Factory.cxx:1271
void EvaluateAllMethods(void)
Iterates over all MVAs that have been booked, and calls their evaluation methods.
Definition Factory.cxx:1376
TGraph * GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0, Types::ETreeType type=Types::kTesting)
Argument iClass specifies the class to generate the ROC curve in a multiclass setting.
Definition Factory.cxx:912
static void PyInitialize()
Initialize Python interpreter.
Class supporting a collection of lines with C++ code.
Definition TMacro.h:31
virtual TObjString * AddLine(const char *text)
Add line with text in the list of lines of this macro.
Definition TMacro.cxx:141
virtual Double_t Gaus(Double_t mean=0, Double_t sigma=1)
Samples a random number from the standard Normal (Gaussian) Distribution with the given mean and sigm...
Definition TRandom.cxx:274
virtual void SetSeed(ULong_t seed=0)
Set the random generator seed.
Definition TRandom.cxx:608
Basic string class.
Definition TString.h:136
const char * Data() const
Definition TString.h:369
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2331
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition TSystem.cxx:654
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if one can access a file using the specified access mode.
Definition TSystem.cxx:1294
virtual void Setenv(const char *name, const char *value)
Set environment variable.
Definition TSystem.cxx:1645
A TTree represents a columnar dataset.
Definition TTree.h:79
return c1
Definition legend1.C:41
const Int_t n
Definition legend1.C:16
TH1F * h1
Definition legend1.C:5
TF1 * f1
Definition legend1.C:11
void EnableImplicitMT(UInt_t numthreads=0)
Enable ROOT's implicit multi-threading for all objects and methods that provide an internal paralleli...
Definition TROOT.cxx:525
UInt_t GetThreadPoolSize()
Returns the size of ROOT's thread pool.
Definition TROOT.cxx:563
constexpr Double_t Pi()
Definition TMath.h:37
auto * m
Definition textangle.C:8