Logo ROOT  
Reference Guide
MethodPyTorch.cxx
Go to the documentation of this file.
1// @(#)root/tmva/pymva $Id$
2// Author: Anirudh Dagar, 2020
3
4#include <Python.h>
6
7#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
8#include <numpy/arrayobject.h>
9
10#include "TMVA/Types.h"
11#include "TMVA/Config.h"
13#include "TMVA/Results.h"
16#include "TMVA/Tools.h"
17#include "TMVA/Timer.h"
18
19using namespace TMVA;
20
21namespace TMVA {
22namespace Internal {
23class PyGILRAII {
24 PyGILState_STATE m_GILState;
25
26public:
27 PyGILRAII() : m_GILState(PyGILState_Ensure()) {}
28 ~PyGILRAII() { PyGILState_Release(m_GILState); }
29};
30} // namespace Internal
31} // namespace TMVA
32
33REGISTER_METHOD(PyTorch)
34
36
37
38MethodPyTorch::MethodPyTorch(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption)
39 : PyMethodBase(jobName, Types::kPyTorch, methodTitle, dsi, theOption) {
40 fNumEpochs = 10;
41 fBatchSize = 100;
42
43 fContinueTraining = false;
44 fSaveBestOnly = true;
45 fLearningRateSchedule = ""; // empty string deactivates learning rate scheduler
46 fFilenameTrainedModel = ""; // empty string sets output model filename to default (in "weights/" directory.)
47}
48
49
50MethodPyTorch::MethodPyTorch(DataSetInfo &theData, const TString &theWeightFile)
51 : PyMethodBase(Types::kPyTorch, theData, theWeightFile) {
52 fNumEpochs = 10;
53 fBatchSize = 100;
54
55 fContinueTraining = false;
56 fSaveBestOnly = true;
57 fLearningRateSchedule = ""; // empty string deactivates learning rate scheduler
58 fFilenameTrainedModel = ""; // empty string sets output model filename to default (in "weights/" directory.)
59}
60
61
63}
64
65
67 if (type == Types::kRegression) return kTRUE;
68 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
69 if (type == Types::kMulticlass && numberClasses >= 2) return kTRUE;
70 return kFALSE;
71}
72
73
75 DeclareOptionRef(fFilenameModel, "FilenameModel", "Filename of the initial PyTorch model");
76 DeclareOptionRef(fFilenameTrainedModel, "FilenameTrainedModel", "Filename of the trained output PyTorch model");
77 DeclareOptionRef(fBatchSize, "BatchSize", "Training batch size");
78 DeclareOptionRef(fNumEpochs, "NumEpochs", "Number of training epochs");
79
80 DeclareOptionRef(fContinueTraining, "ContinueTraining", "Load weights from previous training");
81 DeclareOptionRef(fSaveBestOnly, "SaveBestOnly", "Store only weights with smallest validation loss");
82 DeclareOptionRef(fLearningRateSchedule, "LearningRateSchedule", "Set new learning rate during training at specific epochs, e.g., \"50,0.01;70,0.005\"");
83
84 DeclareOptionRef(fNumValidationString = "20%", "ValidationSize", "Part of the training data to use for validation."
85 "Specify as 0.2 or 20% to use a fifth of the data set as validation set."
86 "Specify as 100 to use exactly 100 events. (Default: 20%)");
87 DeclareOptionRef(fUserCodeName = "", "UserCode", "Necessary python code provided by the user to be executed before loading and training the PyTorch Model");
88
89}
90
91
92////////////////////////////////////////////////////////////////////////////////
93/// Validation of the ValidationSize option. Allowed formats are 20%, 0.2 and
94/// 100 etc.
95/// - 20% and 0.2 both select 20% of the training set as validation data.
96/// - 100 selects 100 events as the validation data.
97///
98/// @return number of samples in validation set
99///
101{
102 Int_t nValidationSamples = 0;
103 UInt_t trainingSetSize = GetEventCollection(Types::kTraining).size();
104
105 // Parsing + Validation
106 // --------------------
107 if (fNumValidationString.EndsWith("%")) {
108 // Relative spec. format 20%
109 TString intValStr = TString(fNumValidationString.Strip(TString::kTrailing, '%'));
110
111 if (intValStr.IsFloat()) {
112 Double_t valSizeAsDouble = fNumValidationString.Atof() / 100.0;
113 nValidationSamples = GetEventCollection(Types::kTraining).size() * valSizeAsDouble;
114 } else {
115 Log() << kFATAL << "Cannot parse number \"" << fNumValidationString
116 << "\". Expected string like \"20%\" or \"20.0%\"." << Endl;
117 }
118 } else if (fNumValidationString.IsFloat()) {
119 Double_t valSizeAsDouble = fNumValidationString.Atof();
120
121 if (valSizeAsDouble < 1.0) {
122 // Relative spec. format 0.2
123 nValidationSamples = GetEventCollection(Types::kTraining).size() * valSizeAsDouble;
124 } else {
125 // Absolute spec format 100 or 100.0
126 nValidationSamples = valSizeAsDouble;
127 }
128 } else {
129 Log() << kFATAL << "Cannot parse number \"" << fNumValidationString << "\". Expected string like \"0.2\" or \"100\"."
130 << Endl;
131 }
132
133 // Value validation
134 // ----------------
135 if (nValidationSamples < 0) {
136 Log() << kFATAL << "Validation size \"" << fNumValidationString << "\" is negative." << Endl;
137 }
138
139 if (nValidationSamples == 0) {
140 Log() << kFATAL << "Validation size \"" << fNumValidationString << "\" is zero." << Endl;
141 }
142
143 if (nValidationSamples >= (Int_t)trainingSetSize) {
144 Log() << kFATAL << "Validation size \"" << fNumValidationString
145 << "\" is larger than or equal in size to training set (size=\"" << trainingSetSize << "\")." << Endl;
146 }
147
148 return nValidationSamples;
149}
150
151
153 // Set default filename for trained model if option is not used
155 fFilenameTrainedModel = GetWeightFileDir() + "/TrainedModel_" + GetName() + ".pt";
156 }
157
158 // - set up number of threads for CPU if NumThreads option was specified
159 // `torch.set_num_threads` sets the number of threads that can be used to
160 // perform cpu operations like conv or mm (usually used by OpenMP or MKL).
161
162 Log() << kINFO << "Using PyTorch - setting special configuration options " << Endl;
163 PyRunString("import torch", "Error importing pytorch");
164
165 // run these above lines also in global namespace to make them visible overall
166 PyRun_String("import torch", Py_single_input, fGlobalNS, fGlobalNS);
167
168 // check pytorch version
169 PyRunString("torch_major_version = int(torch.__version__.split('.')[0])");
170 PyObject *pyTorchVersion = PyDict_GetItemString(fLocalNS, "torch_major_version");
171 int torchVersion = PyLong_AsLong(pyTorchVersion);
172 Log() << kINFO << "Using PyTorch version " << torchVersion << Endl;
173
174 // in case specify number of threads
175 int num_threads = fNumThreads;
176 if (num_threads > 0) {
177 Log() << kINFO << "Setting the CPU number of threads = " << num_threads << Endl;
178
179 PyRunString(TString::Format("torch.set_num_threads(%d)", num_threads));
180 PyRunString(TString::Format("torch.set_num_interop_threads(%d)", num_threads));
181 }
182
183 // Setup model, either the initial model from `fFilenameModel` or
184 // the trained model from `fFilenameTrainedModel`
185 if (fContinueTraining) Log() << kINFO << "Continue training with trained model" << Endl;
187}
188
189
/// Prepare the Python-side objects needed to train or evaluate the model.
///
/// Steps, in order:
///  - if fUserCodeName is set, run that file with PyRun_SimpleFile
///  - bind the Python names `fit`, `optimizer`, `criterion` and `predict` from
///    the Python dictionary `load_model_custom_objects` (keys 'train_func',
///    'optimizer', 'criterion', 'predict_func'); `optimizer` falls back to
///    torch.optim.SGD when its key is absent
///  - load the model with torch.jit.load
///  - allocate fVals, resize fOutput and publish both in fLocalNS as the
///    numpy arrays "vals" (1 x fNVars) and "output" (1 x fNOutputs)
///  - set fModelIsSetup to true
///
/// \param loadTrainedModel if true the model is read from fFilenameTrainedModel,
///                         otherwise from fFilenameModel
///
/// NOTE(review): `load_model_custom_objects` is not defined in this function;
/// presumably the user code file provides it - confirm.
190void MethodPyTorch::SetupPyTorchModel(bool loadTrainedModel) {
191 /*
192 * Load PyTorch model from file
193 */
194
195 Log() << kINFO << " Setup PyTorch Model for training" << Endl;
196
197 if (!fUserCodeName.IsNull()) {
198 Log() << kINFO << " Executing user initialization code from " << fUserCodeName << Endl;
199
200 // run some python code provided by user for method initializations
201 FILE* fp;
202 fp = fopen(fUserCodeName, "r");
203 if (fp) {
 // NOTE(review): the return value of PyRun_SimpleFile is not inspected here.
204 PyRun_SimpleFile(fp, fUserCodeName);
205 fclose(fp);
206 }
207 else
208 Log() << kFATAL << "Input user code is not existing : " << fUserCodeName << Endl;
209 }
210
211 PyRunString("print('custom objects for loading model : ',load_model_custom_objects)");
212
213 // Setup the training method
214 PyRunString("fit = load_model_custom_objects[\"train_func\"]",
215 "Failed to load train function from file. Please use key: 'train_func' and pass training loop function as the value.");
216 Log() << kINFO << "Loaded pytorch train function: " << Endl;
217
218
219 // Setup Optimizer. Use SGD Optimizer as Default
220 PyRunString("if 'optimizer' in load_model_custom_objects:\n"
221 " optimizer = load_model_custom_objects['optimizer']\n"
222 "else:\n"
223 " optimizer = torch.optim.SGD\n",
224 "Please use key: 'optimizer' and pass a pytorch optimizer as the value for a custom optimizer.");
225 Log() << kINFO << "Loaded pytorch optimizer: " << Endl;
226
227
228 // Setup the loss criterion
 // NOTE(review): the message below mentions an MSE default, but no fallback
 // assignment to `criterion` is made in this function.
229 PyRunString("criterion = load_model_custom_objects[\"criterion\"]",
230 "Failed to load loss function from file. Using MSE Loss as default. Please use key: 'criterion' and pass a pytorch loss function as the value.");
231 Log() << kINFO << "Loaded pytorch loss function: " << Endl;
232
233
234 // Setup the predict method
 // NOTE(review): the lookup key is 'predict_func' while the message below names 'predict'.
235 PyRunString("predict = load_model_custom_objects[\"predict_func\"]",
236 "Can't find user predict function object from file. Please use key: 'predict' and pass a predict function for evaluating the model as the value.");
237 Log() << kINFO << "Loaded pytorch predict function: " << Endl;
238
239
240 // Load already trained model or initial model
241 TString filenameLoadModel;
242 if (loadTrainedModel) {
243 filenameLoadModel = fFilenameTrainedModel;
244 }
245 else {
246 filenameLoadModel = fFilenameModel;
247 }
 // The file name is pasted between single quotes into the Python statement.
248 PyRunString("model = torch.jit.load('"+filenameLoadModel+"')",
249 "Failed to load PyTorch model from file: "+filenameLoadModel);
250 Log() << kINFO << "Loaded model from file: " << filenameLoadModel << Endl;
251
252
253 /*
254 * Init variables and weights
255 */
256
257 // Get variables, classes and target numbers
 // NOTE(review): this listing jumps from line 257 to 261; the statements that
 // precede the `else` below (presumably the fNVars / fNOutputs assignments)
 // are not visible here - confirm against the original file.
261 else Log() << kFATAL << "Selected analysis type is not implemented" << Endl;
262
263 // Init evaluation (needed for getMvaValue)
 // NOTE(review): fVals is allocated with new[] on every call and no matching
 // delete[] is visible in this function.
264 fVals = new float[fNVars]; // holds values used for classification and regression
265 npy_intp dimsVals[2] = {(npy_intp)1, (npy_intp)fNVars};
266 PyArrayObject* pVals = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsVals, NPY_FLOAT, (void*)fVals);
267 PyDict_SetItemString(fLocalNS, "vals", (PyObject*)pVals);
268
269 fOutput.resize(fNOutputs); // holds classification probabilities or regression output
270 npy_intp dimsOutput[2] = {(npy_intp)1, (npy_intp)fNOutputs};
 // The array is built on the address of fOutput's first element.
271 PyArrayObject* pOutput = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsOutput, NPY_FLOAT, (void*)&fOutput[0]);
272 PyDict_SetItemString(fLocalNS, "output", (PyObject*)pOutput);
 // NOTE(review): pVals and pOutput are not Py_DECREF'd after being stored in fLocalNS.
273
274 // Mark the model as setup
275 fModelIsSetup = true;
276}
277
278
280
282
283 if (!PyIsInitialized()) {
284 Log() << kFATAL << "Python is not initialized" << Endl;
285 }
286 _import_array(); // required to use numpy arrays
287
288 // Import PyTorch
289 PyRunString("import sys; sys.argv = ['']", "Set sys.argv failed");
290 PyRunString("import torch", "import PyTorch failed");
291 // do import also in global namespace
292 auto ret = PyRun_String("import torch", Py_single_input, fGlobalNS, fGlobalNS);
293 if (!ret)
294 Log() << kFATAL << "import torch in global namespace failed!" << Endl;
295
296 // Set flag that model is not setup
297 fModelIsSetup = false;
298}
299
300
302 if(!fModelIsSetup) Log() << kFATAL << "Model is not setup for training" << Endl;
303
304 /*
305 * Load training data to numpy array.
306 * NOTE: These are later forced to be converted into torch tensors throught the training loop which may not be the ideal method.
307 */
308
309 UInt_t nAllEvents = Data()->GetNTrainingEvents();
310 UInt_t nValEvents = GetNumValidationSamples();
311 UInt_t nTrainingEvents = nAllEvents - nValEvents;
312
313 Log() << kINFO << "Split TMVA training data in " << nTrainingEvents << " training events and "
314 << nValEvents << " validation events" << Endl;
315
316 float* trainDataX = new float[nTrainingEvents*fNVars];
317 float* trainDataY = new float[nTrainingEvents*fNOutputs];
318 float* trainDataWeights = new float[nTrainingEvents];
319 for (UInt_t i=0; i<nTrainingEvents; i++) {
320 const TMVA::Event* e = GetTrainingEvent(i);
321 // Fill variables
322 for (UInt_t j=0; j<fNVars; j++) {
323 trainDataX[j + i*fNVars] = e->GetValue(j);
324 }
325 // Fill targets
326 // NOTE: For classification, convert class number in one-hot vector,
327 // e.g., 1 -> [0, 1] or 0 -> [1, 0] for binary classification
329 for (UInt_t j=0; j<fNOutputs; j++) {
330 trainDataY[j + i*fNOutputs] = 0;
331 }
332 trainDataY[e->GetClass() + i*fNOutputs] = 1;
333 }
334 else if (GetAnalysisType() == Types::kRegression) {
335 for (UInt_t j=0; j<fNOutputs; j++) {
336 trainDataY[j + i*fNOutputs] = e->GetTarget(j);
337 }
338 }
339 else Log() << kFATAL << "Can not fill target vector because analysis type is not known" << Endl;
340 // Fill weights
341 // NOTE: If no weight branch is given, this defaults to ones for all events
342 trainDataWeights[i] = e->GetWeight();
343 }
344
345 npy_intp dimsTrainX[2] = {(npy_intp)nTrainingEvents, (npy_intp)fNVars};
346 npy_intp dimsTrainY[2] = {(npy_intp)nTrainingEvents, (npy_intp)fNOutputs};
347 npy_intp dimsTrainWeights[1] = {(npy_intp)nTrainingEvents};
348 PyArrayObject* pTrainDataX = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsTrainX, NPY_FLOAT, (void*)trainDataX);
349 PyArrayObject* pTrainDataY = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsTrainY, NPY_FLOAT, (void*)trainDataY);
350 PyArrayObject* pTrainDataWeights = (PyArrayObject*)PyArray_SimpleNewFromData(1, dimsTrainWeights, NPY_FLOAT, (void*)trainDataWeights);
351 PyDict_SetItemString(fLocalNS, "trainX", (PyObject*)pTrainDataX);
352 PyDict_SetItemString(fLocalNS, "trainY", (PyObject*)pTrainDataY);
353 PyDict_SetItemString(fLocalNS, "trainWeights", (PyObject*)pTrainDataWeights);
354
355 /*
356 * Load validation data to numpy array
357 */
358
359 // NOTE: TMVA Validation data is a subset of all the training data
360 // we will not use test data for validation. They will be used for the real testing
361
362
363 float* valDataX = new float[nValEvents*fNVars];
364 float* valDataY = new float[nValEvents*fNOutputs];
365 float* valDataWeights = new float[nValEvents];
366 //validation events follows the trainig one in the TMVA training vector
367 for (UInt_t i=0; i< nValEvents ; i++) {
368 UInt_t ievt = nTrainingEvents + i; // TMVA event index
369 const TMVA::Event* e = GetTrainingEvent(ievt);
370 // Fill variables
371 for (UInt_t j=0; j<fNVars; j++) {
372 valDataX[j + i*fNVars] = e->GetValue(j);
373 }
374 // Fill targets
376 for (UInt_t j=0; j<fNOutputs; j++) {
377 valDataY[j + i*fNOutputs] = 0;
378 }
379 valDataY[e->GetClass() + i*fNOutputs] = 1;
380 }
381 else if (GetAnalysisType() == Types::kRegression) {
382 for (UInt_t j=0; j<fNOutputs; j++) {
383 valDataY[j + i*fNOutputs] = e->GetTarget(j);
384 }
385 }
386 else Log() << kFATAL << "Can not fill target vector because analysis type is not known" << Endl;
387 // Fill weights
388 valDataWeights[i] = e->GetWeight();
389 }
390
391 npy_intp dimsValX[2] = {(npy_intp)nValEvents, (npy_intp)fNVars};
392 npy_intp dimsValY[2] = {(npy_intp)nValEvents, (npy_intp)fNOutputs};
393 npy_intp dimsValWeights[1] = {(npy_intp)nValEvents};
394 PyArrayObject* pValDataX = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsValX, NPY_FLOAT, (void*)valDataX);
395 PyArrayObject* pValDataY = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsValY, NPY_FLOAT, (void*)valDataY);
396 PyArrayObject* pValDataWeights = (PyArrayObject*)PyArray_SimpleNewFromData(1, dimsValWeights, NPY_FLOAT, (void*)valDataWeights);
397 PyDict_SetItemString(fLocalNS, "valX", (PyObject*)pValDataX);
398 PyDict_SetItemString(fLocalNS, "valY", (PyObject*)pValDataY);
399 PyDict_SetItemString(fLocalNS, "valWeights", (PyObject*)pValDataWeights);
400
401 /*
402 * Train PyTorch model
403 */
404 Log() << kINFO << "Print Training Model Architecture" << Endl;
405 PyRunString("print(model)");
406
407 // Setup parameters
408
409 PyObject* pBatchSize = PyLong_FromLong(fBatchSize);
410 PyObject* pNumEpochs = PyLong_FromLong(fNumEpochs);
411 PyDict_SetItemString(fLocalNS, "batchSize", pBatchSize);
412 PyDict_SetItemString(fLocalNS, "numEpochs", pNumEpochs);
413
414 // Prepare PyTorch Training DataSet
415 PyRunString("train_dataset = torch.utils.data.TensorDataset(torch.Tensor(trainX), torch.Tensor(trainY))",
416 "Failed to create pytorch train Dataset.");
417 // Prepare PyTorch Training Dataloader
418 PyRunString("train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchSize, shuffle=False)",
419 "Failed to create pytorch train Dataloader.");
420
421
422 // Prepare PyTorch Validation DataSet
423 PyRunString("val_dataset = torch.utils.data.TensorDataset(torch.Tensor(valX), torch.Tensor(valY))",
424 "Failed to create pytorch validation Dataset.");
425 // Prepare PyTorch validation Dataloader
426 PyRunString("val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batchSize, shuffle=False)",
427 "Failed to create pytorch validation Dataloader.");
428
429
430 // Learning Rate Scheduler
431 if (fLearningRateSchedule!="") {
432 // Setup a python dictionary with the desired learning rate steps
433 PyRunString("strScheduleSteps = '"+fLearningRateSchedule+"'\n"
434 "schedulerSteps = {}\n"
435 "for c in strScheduleSteps.split(';'):\n"
436 " x = c.split(',')\n"
437 " schedulerSteps[int(x[0])] = float(x[1])\n",
438 "Failed to setup steps for scheduler function from string: "+fLearningRateSchedule,
439 Py_file_input);
440 // Set scheduler function as piecewise function with given steps
441 PyRunString("def schedule(optimizer, epoch, schedulerSteps=schedulerSteps):\n"
442 " if epoch in schedulerSteps:\n"
443 " for param_group in optimizer.param_groups:\n"
444 " param_group['lr'] = float(schedulerSteps[epoch])\n",
445 "Failed to setup scheduler function with string: "+fLearningRateSchedule,
446 Py_file_input);
447
448 Log() << kINFO << "Option LearningRateSchedule: Set learning rate during training: " << fLearningRateSchedule << Endl;
449 }
450 else{
451 PyRunString("schedule = None; schedulerSteps = None", "Failed to set scheduler to None.");
452 }
453
454
455 // Save only weights with smallest validation loss
456 if (fSaveBestOnly) {
457 PyRunString("def save_best(model, curr_val, best_val, save_path='"+fFilenameTrainedModel+"'):\n"
458 " if curr_val<=best_val:\n"
459 " best_val = curr_val\n"
460 " best_model_jitted = torch.jit.script(model)\n"
461 " torch.jit.save(best_model_jitted, save_path)\n"
462 " return best_val",
463 "Failed to setup training with option: SaveBestOnly");
464 Log() << kINFO << "Option SaveBestOnly: Only model weights with smallest validation loss will be stored" << Endl;
465 }
466 else{
467 PyRunString("save_best = None", "Failed to set save_best to None.");
468 }
469
470
471 // Note: Early Stopping should not be implemented here. Can be implemented inside train loop function by user if required.
472
473 // Train model
474 PyRunString("trained_model = fit(model, train_loader, val_loader, num_epochs=numEpochs, batch_size=batchSize,"
475 "optimizer=optimizer, criterion=criterion, save_best=save_best, scheduler=(schedule, schedulerSteps))",
476 "Failed to train model");
477
478
479 // Note: PyTorch doesn't store training history data unlike Keras. A user can append and save the loss,
480 // accuracy, other metrics etc to a file for later use.
481
482 /*
483 * Store trained model to file (only if option 'SaveBestOnly' is NOT activated,
484 * because we do not want to override the best model checkpoint)
485 */
486 if (!fSaveBestOnly) {
487 PyRunString("trained_model_jitted = torch.jit.script(trained_model)",
488 "Model not scriptable. Failed to convert to torch script.");
489 PyRunString("torch.jit.save(trained_model_jitted, '"+fFilenameTrainedModel+"')",
490 "Failed to save trained model: "+fFilenameTrainedModel);
491 Log() << kINFO << "Trained model written to file: " << fFilenameTrainedModel << Endl;
492 }
493
494 /*
495 * Clean-up
496 */
497
498 delete[] trainDataX;
499 delete[] trainDataY;
500 delete[] trainDataWeights;
501 delete[] valDataX;
502 delete[] valDataY;
503 delete[] valDataWeights;
504}
505
506
509}
510
512 // Cannot determine error
513 NoErrorCalc(errLower, errUpper);
514
515 // Check whether the model is setup
516 // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
517 if (!fModelIsSetup) {
518 // Setup the trained model
519 SetupPyTorchModel(true);
520 }
521
522 // Get signal probability (called mvaValue here)
523 const TMVA::Event* e = GetEvent();
524 for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
525 PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
526 "Failed to get predictions");
527
528
530}
531
532
533std::vector<Double_t> MethodPyTorch::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress) {
534 // Check whether the model is setup
535 // NOTE: Unfortunately this is needed because during evaluation ProcessOptions is not called again
536 if (!fModelIsSetup) {
537 // Setup the trained model
538 SetupPyTorchModel(true);
539 }
540
541 // Load data to numpy array
542 Long64_t nEvents = Data()->GetNEvents();
543 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
544 if (firstEvt < 0) firstEvt = 0;
545 nEvents = lastEvt-firstEvt;
546
547 // use timer
548 Timer timer( nEvents, GetName(), kTRUE );
549
550 if (logProgress)
551 Log() << kHEADER << Form("[%s] : ",DataInfo().GetName())
552 << "Evaluation of " << GetMethodName() << " on "
553 << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
554 << " sample (" << nEvents << " events)" << Endl;
555
556 float* data = new float[nEvents*fNVars];
557 for (UInt_t i=0; i<nEvents; i++) {
558 Data()->SetCurrentEvent(i);
559 const TMVA::Event *e = GetEvent();
560 for (UInt_t j=0; j<fNVars; j++) {
561 data[j + i*fNVars] = e->GetValue(j);
562 }
563 }
564
565 npy_intp dimsData[2] = {(npy_intp)nEvents, (npy_intp)fNVars};
566 PyArrayObject* pDataMvaValues = (PyArrayObject*)PyArray_SimpleNewFromData(2, dimsData, NPY_FLOAT, (void*)data);
567 if (pDataMvaValues==0) Log() << "Failed to load data to Python array" << Endl;
568
569
570 // Get prediction for all events
571 PyObject* pModel = PyDict_GetItemString(fLocalNS, "model");
572 if (pModel==0) Log() << kFATAL << "Failed to get model Python object" << Endl;
573
574 PyObject* pPredict = PyDict_GetItemString(fLocalNS, "predict");
575 if (pPredict==0) Log() << kFATAL << "Failed to get Python predict function" << Endl;
576
577
578 // Using PyTorch User Defined predict function for predictions
579 PyArrayObject* pPredictions = (PyArrayObject*) PyObject_CallFunctionObjArgs(pPredict, pModel, pDataMvaValues, NULL);
580 if (pPredictions==0) Log() << kFATAL << "Failed to get predictions" << Endl;
581 delete[] data;
582
583 // Load predictions to double vector
584 // NOTE: The signal probability is given at the output
585 std::vector<double> mvaValues(nEvents);
586 float* predictionsData = (float*) PyArray_DATA(pPredictions);
587 for (UInt_t i=0; i<nEvents; i++) {
588 mvaValues[i] = (double) predictionsData[i*fNOutputs + TMVA::Types::kSignal];
589 }
590
591 if (logProgress) {
592 Log() << kINFO
593 << "Elapsed time for evaluation of " << nEvents << " events: "
594 << timer.GetElapsedTime() << " " << Endl;
595 }
596
597 return mvaValues;
598}
599
600std::vector<Float_t>& MethodPyTorch::GetRegressionValues() {
601 // Check whether the model is setup
602 // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
603 if (!fModelIsSetup){
604 // Setup the model and load weights
605 SetupPyTorchModel(true);
606 }
607
608 // Get regression values
609 const TMVA::Event* e = GetEvent();
610 for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
611
612 PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
613 "Failed to get predictions");
614
615
616 // Use inverse transformation of targets to get final regression values
617 Event * eTrans = new Event(*e);
618 for (UInt_t i=0; i<fNOutputs; ++i) {
619 eTrans->SetTarget(i,fOutput[i]);
620 }
621
622 const Event* eTrans2 = GetTransformationHandler().InverseTransform(eTrans);
623 for (UInt_t i=0; i<fNOutputs; ++i) {
624 fOutput[i] = eTrans2->GetTarget(i);
625 }
626
627 return fOutput;
628}
629
630std::vector<Float_t>& MethodPyTorch::GetMulticlassValues() {
631 // Check whether the model is setup
632 // NOTE: unfortunately this is needed because during evaluation ProcessOptions is not called again
633 if (!fModelIsSetup){
634 // Setup the model and load weights
635 SetupPyTorchModel(true);
636 }
637
638 // Get class probabilites
639 const TMVA::Event* e = GetEvent();
640 for (UInt_t i=0; i<fNVars; i++) fVals[i] = e->GetValue(i);
641 PyRunString("for i,p in enumerate(predict(model, vals)): output[i]=p\n",
642 "Failed to get predictions");
643
644 return fOutput;
645}
646
647
649}
650
651
653 Log() << Endl;
654 Log() << "PyTorch is a scientific computing package supporting" << Endl;
655 Log() << "automatic differentiation. This method wraps the training" << Endl;
656 Log() << "and predictions steps of the PyTorch Python package for" << Endl;
657 Log() << "TMVA, so that dataloading, preprocessing and evaluation" << Endl;
658 Log() << "can be done within the TMVA system. To use this PyTorch" << Endl;
659 Log() << "interface, you need to generatea model with PyTorch first." << Endl;
660 Log() << "Then, this model can be loaded and trained in TMVA." << Endl;
661 Log() << Endl;
662}
#define REGISTER_METHOD(CLASS)
for example
_object PyObject
Definition: PyMethodBase.h:43
#define Py_single_input
Definition: PyMethodBase.h:44
#define e(i)
Definition: RSha256.hxx:103
const Bool_t kFALSE
Definition: RtypesCore.h:101
long long Long64_t
Definition: RtypesCore.h:80
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassImp(name)
Definition: Rtypes.h:375
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition: TString.cxx:2468
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Definition: Configurable.h:122
Class that contains all the data information.
Definition: DataSetInfo.h:62
UInt_t GetNClasses() const
Definition: DataSetInfo.h:155
UInt_t GetNTargets() const
Definition: DataSetInfo.h:128
Types::ETreeType GetCurrentType() const
Definition: DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:206
Long64_t GetNTrainingEvents() const
Definition: DataSet.h:68
void SetCurrentEvent(Long64_t ievt) const
Definition: DataSet.h:88
void SetTarget(UInt_t itgt, Float_t value)
set the target value (dimension itgt) to value
Definition: Event.cxx:367
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
const char * GetName() const
Definition: MethodBase.h:334
Types::EAnalysisType GetAnalysisType() const
Definition: MethodBase.h:437
const TString & GetWeightFileDir() const
Definition: MethodBase.h:492
const TString & GetMethodName() const
Definition: MethodBase.h:331
const Event * GetEvent() const
Definition: MethodBase.h:751
DataSetInfo & DataInfo() const
Definition: MethodBase.h:410
virtual void TestClassification()
initialization
UInt_t GetNVariables() const
Definition: MethodBase.h:345
TransformationHandler & GetTransformationHandler(Bool_t takeReroutedIfAvailable=true)
Definition: MethodBase.h:394
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
Definition: MethodBase.cxx:836
DataSet * Data() const
Definition: MethodBase.h:409
const Event * GetTrainingEvent(Long64_t ievt) const
Definition: MethodBase.h:771
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
std::vector< Float_t > & GetMulticlassValues()
std::vector< float > fOutput
Definition: MethodPyTorch.h:91
MethodPyTorch(const TString &jobName, const TString &methodTitle, DataSetInfo &dsi, const TString &theOption="")
virtual void TestClassification()
initialization
std::vector< Double_t > GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
get all the MVA values for the events of the current Data type
TString fNumValidationString
Definition: MethodPyTorch.h:85
UInt_t GetNumValidationSamples()
Validation of the ValidationSize option.
void GetHelpMessage() const
TString fLearningRateSchedule
Definition: MethodPyTorch.h:83
std::vector< Float_t > & GetRegressionValues()
TString fFilenameTrainedModel
Definition: MethodPyTorch.h:94
void SetupPyTorchModel(Bool_t loadTrainedModel)
Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper)
static int PyIsInitialized()
Check Python interpreter initialization status.
static PyObject * fGlobalNS
Definition: PyMethodBase.h:136
void PyRunString(TString code, TString errorMessage="Failed to run python code", int start=256)
Execute Python code from string.
PyObject * fLocalNS
Definition: PyMethodBase.h:137
Timing information for training and evaluation of MVA methods.
Definition: Timer.h:58
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition: Timer.cxx:146
const Event * InverseTransform(const Event *, Bool_t suppressIfNoTargets=true) const
Singleton class for Global types used by TMVA.
Definition: Types.h:71
@ kSignal
Never change this number - it is elsewhere assumed to be zero !
Definition: Types.h:135
EAnalysisType
Definition: Types.h:126
@ kMulticlass
Definition: Types.h:129
@ kClassification
Definition: Types.h:127
@ kRegression
Definition: Types.h:128
@ kTraining
Definition: Types.h:143
@ kHEADER
Definition: Types.h:63
@ kINFO
Definition: Types.h:58
@ kFATAL
Definition: Types.h:61
Basic string class.
Definition: TString.h:136
Bool_t IsFloat() const
Returns kTRUE if string contains a floating point or integer number.
Definition: TString.cxx:1837
@ kTrailing
Definition: TString.h:267
Bool_t IsNull() const
Definition: TString.h:407
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition: TString.cxx:2357
create variable transformations
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:148
Double_t Log(Double_t x)
Returns the natural logarithm of x.
Definition: TMath.h:754