Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MulticlassPyTorch.py
Go to the documentation of this file.
1#!/usr/bin/env python
2## \file
3## \ingroup tutorial_tmva_pytorch
4## \notebook -nodraw
5## This tutorial shows how to do multiclass classification in TMVA with neural
6## networks trained with PyTorch.
7##
8## \macro_code
9##
10## \date 2020
11## \author Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee
12
13
14from ROOT import TMVA, TFile, TTree, TCut, gROOT
15from os.path import isfile
16
17import torch
18from torch import nn
19
20
# Setup TMVA
TMVA.Tools.Instance()
# Initialize the Python interpreter support in TMVA; presumably required before the
# Python-based kPyTorch method booked further below can be used.
TMVA.PyMethodBase.PyInitialize()

# create factory without output file since it is not needed
factory = TMVA.Factory('TMVAClassification',
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')
28
29
# Load data.
# The example file is produced on demand by the create_MultipleBackground function
# of ROOT's tmva/createData.C tutorial macro.
if not isfile('tmva_example_multiple_background.root'):
    createDataMacro = '{}/tmva/createData.C'.format(str(gROOT.GetTutorialDir()))
    print(createDataMacro)
    gROOT.ProcessLine('.L ' + createDataMacro)
    gROOT.ProcessLine('create_MultipleBackground(4000)')

data = TFile.Open('tmva_example_multiple_background.root')
signal = data.Get('TreeS')
background0, background1, background2 = (data.Get('TreeB{}'.format(idx)) for idx in range(3))

# Every branch of the signal tree becomes an input variable of the dataloader.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# One tree per class: signal plus three background classes.
for class_tree, class_label in ((signal, 'Signal'),
                                (background0, 'Background_0'),
                                (background1, 'Background_1'),
                                (background2, 'Background_2')):
    dataloader.AddTree(class_tree, class_label)

dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'SplitMode=Random:NormMode=NumEvents:!V')
53
54
# Generate model
# Fully connected network: 4 inputs -> 32 hidden ReLU units -> 4 outputs, softmax over dim 1.
# The 4 inputs presumably match the number of dataloader variables (signal-tree branches)
# and the 4 outputs the four classes added to the dataloader -- TODO confirm if the data changes.
# NOTE(review): nn.CrossEntropyLoss applies log-softmax to its input itself, so training on
# these already-softmaxed outputs normalizes twice; the final Softmax is presumably kept so
# the model outputs can be read directly as class probabilities -- confirm before removing.
model = nn.Sequential()
for _layer_name, _layer in (
    ('linear_1', nn.Linear(in_features=4, out_features=32)),
    ('relu', nn.ReLU()),
    ('linear_2', nn.Linear(in_features=32, out_features=4)),
    ('softmax', nn.Softmax(dim=1)),
):
    model.add_module(_layer_name, _layer)


# Set loss and optimizer.
# The optimizer is the SGD *class*, not an instance: the training function instantiates it
# with the model parameters.
optimizer = torch.optim.SGD
loss = nn.CrossEntropyLoss()
67
68
# Define train function
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` and report training and validation losses.

    Args:
        model: torch.nn.Module, trained in place.
        train_loader: iterable of ``(X, y)`` batches used for the parameter updates.
            The target class index is taken as the argmax of ``y`` along dim 1
            (presumably one-hot labels).
        val_loader: iterable of ``(X, y)`` batches, same layout as ``train_loader``,
            used for the per-epoch validation loss; must support ``len()``.
        num_epochs: number of passes over ``train_loader``.
        batch_size: not used here (the loaders do the batching); kept for the
            caller's calling convention.
        optimizer: optimizer class/factory, called as ``optimizer(params, lr=0.01)``.
        criterion: loss callable, called as ``criterion(output, target_indices)``.
        save_best: optional callable ``save_best(model, curr_val, best_val)`` returning
            the new best validation loss; called once per epoch when truthy.
        scheduler: pair ``(schedule, schedulerSteps)``. When ``schedule`` is truthy it is
            called once per epoch as ``schedule(optimizer_instance, epoch, schedulerSteps)``.

    Returns:
        The trained ``model``.
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None

    for epoch in range(num_epochs):
        # Training loop
        model.train()
        running_train_loss = 0.0
        running_val_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            target = torch.max(y, 1)[1]  # class index = position of the largest label entry
            train_loss = criterion(output, target)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 32 == 31:  # print every 32 mini-batches
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            # Hand over the optimizer *instance*: the `optimizer` argument is only the class,
            # and a learning-rate schedule needs the instance (e.g. its param_groups).
            schedule(trainer, epoch, schedulerSteps)

        # Validation loop (eval mode, no gradient tracking)
        model.eval()
        with torch.no_grad():
            for X, y in val_loader:
                output = model(X)
                target = torch.max(y, 1)[1]
                val_loss = criterion(output, target)
                running_val_loss += val_loss.item()

        curr_val = running_val_loss / len(val_loader)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))

    # Use num_epochs rather than the loop variable so that num_epochs == 0 does not fail.
    print("Finished Training on {} Epochs!".format(num_epochs))

    return model
121
122
# Define predict function
def predict(model, test_X, batch_size=32):
    """Evaluate ``model`` on ``test_X`` without gradients and return the outputs as NumPy.

    Args:
        model: torch.nn.Module; switched to eval mode before inference.
        test_X: array-like of input rows, converted with ``torch.Tensor``.
        batch_size: number of rows fed to the model per forward pass.

    Returns:
        numpy.ndarray with the per-batch model outputs concatenated along dim 0,
        in the same order as the rows of ``test_X``.
    """
    model.eval()

    batches = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.Tensor(test_X)),
        batch_size=batch_size,
        shuffle=False,
    )

    with torch.no_grad():
        # Each batch is a 1-element sequence holding the input tensor.
        stacked = torch.cat([model(features) for (features,) in batches])

    return stacked.numpy()
140
141
# Objects handed to TMVA's PyTorch method under fixed keys; presumably looked up by the
# name `load_model_custom_objects` from this script's namespace -- keep the name and keys.
load_model_custom_objects = dict(optimizer=optimizer,
                                 criterion=loss,
                                 train_func=train,
                                 predict_func=predict)


# Store model to file: compile the (still untrained) model to TorchScript and write it to
# the file named by the FilenameModel option of the PyTorch booking.
m = torch.jit.script(model)
m.save("modelMultiClass.pt")
print(m)
150
151
# Book methods: a Fisher discriminant as a baseline, and the PyTorch model saved above.
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
pytorch_options = ('H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.pt:'
                   'FilenameTrainedModel=trainedModelMultiClass.pt:NumEpochs=20:BatchSize=32')
factory.BookMethod(dataloader, TMVA.Types.kPyTorch, "PyTorch", pytorch_options)


# Run TMVA: train, then test, then evaluate every booked method.
for _stage in (factory.TrainAllMethods, factory.TestAllMethods, factory.EvaluateAllMethods):
    _stage()

# Plot ROC Curves and write them to a PNG file.
roc = factory.GetROCCurve(dataloader)
roc.SaveAs('ROC_MulticlassPyTorch.png')
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A specialized string object used for TTree selections.
Definition TCut.h:25
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4082
This is the main MVA steering class.
Definition Factory.h:80
static void PyInitialize()
Initialize Python interpreter.
static Tools & Instance()
Definition Tools.cxx:71