Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RegressionPyTorch.py
Go to the documentation of this file.
1#!/usr/bin/env python
2## \file
3## \ingroup tutorial_tmva_pytorch
4## \notebook -nodraw
5## This tutorial shows how to do regression in TMVA with neural networks
6## trained with PyTorch.
7##
8## \macro_code
9##
10## \date 2020
11## \author Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee
12
13
14from ROOT import TMVA, TFile, TTree, TCut
15from subprocess import call
16from os.path import isfile
17
18import torch
19from torch import nn
20
21
# Setup TMVA
TMVA.Tools.Instance()
# Initialize the Python interpreter used by TMVA's Python-based methods
# (PyMethodBase), such as the PyTorch method booked further down.
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVARegression', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')
29
30
# Load data
if not isfile('tmva_reg_example.root'):
    # --fail: exit non-zero instead of saving an HTTP error page as the data
    # file; --location: follow server redirects. Abort if the download failed
    # rather than continuing with a missing or bogus file.
    curl_status = call(['curl', '--fail', '--location', '-O',
                        'https://root.cern.ch/files/tmva_reg_example.root'])
    if curl_status != 0:
        raise RuntimeError('Downloading tmva_reg_example.root failed '
                           '(curl exit code {})'.format(curl_status))

data = TFile.Open('tmva_reg_example.root')
# TFile.Open yields a null object on failure; fail early with a clear message.
if not data or data.IsZombie():
    raise RuntimeError('Could not open tmva_reg_example.root')
tree = data.Get('TreeR')
if not tree:
    raise RuntimeError("Tree 'TreeR' not found in tmva_reg_example.root")

# Every branch except the regression target 'fvalue' becomes an input variable.
dataloader = TMVA.DataLoader('dataset')
for branch in tree.GetListOfBranches():
    name = branch.GetName()
    if name != 'fvalue':
        dataloader.AddVariable(name)
dataloader.AddTarget('fvalue')

dataloader.AddRegressionTree(tree, 1.0)
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')
48
49
# Generate model

# Define model
# in_features must equal the number of input variables registered with the
# dataloader (every branch of TreeR except 'fvalue'); the example tree is
# assumed to provide two of them.
model = nn.Sequential()
model.add_module('linear_1', nn.Linear(in_features=2, out_features=64))
# The module name matches the activation it actually holds (Tanh).
model.add_module('tanh', nn.Tanh())
model.add_module('linear_2', nn.Linear(in_features=64, out_features=1))


# Construct loss function and Optimizer.
loss = torch.nn.MSELoss()
# An optimizer *class*, not an instance: the train function instantiates it
# with the model's parameters.
optimizer = torch.optim.SGD
62
63
# Define train function
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` and report training/validation losses per epoch.

    Registered as ``train_func`` in ``load_model_custom_objects``.

    Args:
        model: torch module, optimized in place.
        train_loader: iterable of ``(X, y)`` mini-batches used for training.
        val_loader: sized iterable of ``(X, y)`` mini-batches used for
            validation; the validation loss is averaged over ``len(val_loader)``.
        num_epochs: number of passes over ``train_loader``.
        batch_size: not used here; batching is done by the loaders.
        optimizer: optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated here with ``lr=0.01``.
        criterion: loss function, called as ``criterion(output, y)``.
        save_best: optional callable ``save_best(model, curr_val, best_val)``
            returning the updated best validation loss; falsy to disable.
        scheduler: pair ``(schedule, schedulerSteps)``. When ``schedule`` is
            truthy it is called once per epoch as
            ``schedule(optimizer_instance, epoch, schedulerSteps)``.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None

    for epoch in range(num_epochs):
        # Training loop
        model.train()
        running_train_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            train_loss = criterion(output, y)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 32 == 31:  # print every 32 mini-batches
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            # A learning-rate schedule can only act on the optimizer instance;
            # the class received as ``optimizer`` has no parameter groups.
            schedule(trainer, epoch, schedulerSteps)

        # Validation loop
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                output = model(X)
                val_loss = criterion(output, y)
                running_val_loss += val_loss.item()

        curr_val = running_val_loss / len(val_loader)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))

    # Use num_epochs rather than the loop variable, which is unbound when
    # num_epochs == 0.
    print("Finished Training on {} Epochs!".format(num_epochs))

    return model
114
115
# Define predict function
def predict(model, test_X, batch_size=32):
    """Evaluate ``model`` on ``test_X`` batch by batch and return the outputs.

    Registered as ``predict_func`` in ``load_model_custom_objects``.

    Args:
        model: torch module used for inference; it is switched to eval mode.
        test_X: array-like of input rows, converted with ``torch.Tensor``.
        batch_size: number of rows evaluated per forward pass.

    Returns:
        NumPy array of the model outputs for all rows, in input order.
    """
    model.eval()

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(test_X)),
        batch_size=batch_size,
        shuffle=False,
    )

    # No gradients are needed at inference time.
    with torch.no_grad():
        batch_outputs = [model(inputs) for (inputs,) in loader]

    return torch.cat(batch_outputs).numpy()
133
134
# Custom objects for the PyTorch method: optimizer class, loss, and the
# train/predict callables defined above (presumably looked up by TMVA by
# this variable name — keep the name and keys unchanged).
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=loss,
    train_func=train,
    predict_func=predict,
)


# Store model to file
# Convert the model to TorchScript, then write it to the file that the
# PyTorch method is booked with (FilenameModel=model.pt).
m = torch.jit.script(model)
m.save("model.pt")
print(m)
143
144
# Book methods
factory.BookMethod(dataloader, TMVA.Types.kPyTorch, 'PyTorch',
                   'H:!V:VarTransform=D,G:FilenameModel=model.pt:NumEpochs=20:BatchSize=32')
factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',
                   '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')


# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly, so the results stored in TMVA.root are
# written out at a well-defined point rather than during interpreter teardown.
output.Close()
A specialized string object used for TTree selections.
Definition TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:3997
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:75