Logo ROOT  
Reference Guide
PyTorch_Generate_CNN_Model.py
Go to the documentation of this file.
1import torch
2from torch import nn
3
# --- Model definition -------------------------------------------------------
# Announce that the model-building part of the script is starting.
print("running Torch code defining the model....")
7
# Custom Reshape Layer
class Reshape(torch.nn.Module):
    """Reshape each input sample into a single-channel 16x16 image.

    The convolutional layers that follow expect tensors of shape
    ``(batch, 1, 16, 16)``; this layer regroups the incoming elements into
    that shape, inferring the batch dimension. Assumes every sample holds
    256 values (e.g. flattened ``(batch, 256)`` rows) — TODO confirm with
    the data pipeline.
    """

    def forward(self, x):
        # ``reshape`` rather than ``view``: ``view`` raises a RuntimeError on
        # non-contiguous inputs (transposed / sliced tensors), whereas
        # ``reshape`` returns a view when possible and copies only if needed.
        return x.reshape(-1, 1, 16, 16)
12
# CNN model definition.
# Per-sample shapes: (1, 16, 16) -> two 3x3 convolutions with padding 1 keep
# the 16x16 size (10 channels) -> 2x2 max-pool gives (10, 8, 8) -> flattened
# to 10*8*8 features -> 256 hidden units -> 2 sigmoid outputs.
net = nn.Sequential(
    # Feature extractor
    Reshape(),
    nn.Conv2d(1, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(10),
    nn.Conv2d(10, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    # Classifier head
    nn.Flatten(),
    nn.Linear(10 * 8 * 8, 256),
    nn.ReLU(),
    nn.Linear(256, 2),
    nn.Sigmoid(),
)
28
# Loss function (binary cross-entropy on the sigmoid outputs) and the
# optimizer *class*; fit() instantiates it with the model's parameters.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam
32
33
def fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` and evaluate it on the validation set after every epoch.

    Args:
        model: the ``torch.nn.Module`` to train. It is moved to the GPU when
            CUDA is available, otherwise to the CPU.
        train_loader: iterable of ``(X, y)`` training mini-batches.
        val_loader: sized iterable of ``(X, y)`` validation mini-batches. The
            epoch validation loss is averaged over ``len(val_loader)``, so it
            must yield at least one batch.
        num_epochs: number of passes over ``train_loader``.
        batch_size: kept for interface compatibility; not used here (batching
            is done by the loaders).
        optimizer: optimizer *class* (e.g. ``torch.optim.Adam``); it is
            instantiated here with the model parameters and ``lr=0.01``.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        save_best: optional callable ``save_best(model, curr_val, best_val)``
            returning the updated best validation loss; skipped when falsy.
        scheduler: ``(schedule, schedulerSteps)`` pair. When ``schedule`` is
            truthy it is called after each training epoch as
            ``schedule(optimizer_instance, epoch, schedulerSteps)``.
            ``None`` is also accepted and means "no scheduler".

    Returns:
        The trained model (on the selected device).
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    # Accept ``None`` as "no scheduler" instead of failing while unpacking.
    schedule, schedulerSteps = scheduler if scheduler is not None else (None, None)
    best_val = None

    # Setup GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    for epoch in range(num_epochs):
        # Training loop
        model.train()
        running_train_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            X, y = X.to(device), y.to(device)
            train_loss = criterion(model(X), y)
            train_loss.backward()
            trainer.step()

            # Report the mean loss of every group of 4 mini-batches.
            running_train_loss += train_loss.item()
            if i % 4 == 3:
                print(f"[{epoch+1}, {i+1}] train loss: {running_train_loss / 4 :.3f}")
                running_train_loss = 0.0

        if schedule:
            # Pass the optimizer *instance*: a learning-rate schedule works on
            # its ``param_groups``, which the optimizer class does not have.
            schedule(trainer, epoch, schedulerSteps)

        # Validation loop
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                running_val_loss += criterion(model(X), y).item()

        curr_val = running_val_loss / len(val_loader)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print(f"[{epoch+1}] val loss: {curr_val :.3f}")

    # ``epoch`` is unbound when num_epochs == 0, so report num_epochs instead
    # (identical output whenever at least one epoch ran).
    print(f"Finished Training on {num_epochs} Epochs!")

    return model
91
92
def predict(model, test_X, batch_size=100):
    """Run ``model`` over ``test_X`` in order, ``batch_size`` rows at a time.

    The model is moved to the GPU when CUDA is available (otherwise the CPU)
    and switched to eval mode. ``test_X`` is converted with ``torch.Tensor``
    and evaluated without gradient tracking.

    Returns:
        A NumPy array holding the per-batch model outputs concatenated along
        the first dimension.
    """
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(target_device)
    model.eval()

    inputs = torch.utils.data.TensorDataset(torch.Tensor(test_X))
    loader = torch.utils.data.DataLoader(inputs, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        # Each loader item is a one-element batch tuple: (features,).
        batch_outputs = [model(batch[0].to(target_device)) for batch in loader]
        stacked = torch.cat(batch_outputs)

    return stacked.cpu().numpy()
114
115
# Custom objects bundled with the saved model: the optimizer class, the loss
# function, and the training / prediction entry points. Presumably read by the
# model loader (e.g. ROOT TMVA's PyTorch method) — confirm against the caller.
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=criterion,
    train_func=fit,
    predict_func=predict,
)
117
# Compile the network to TorchScript and write the result to disk.
m = torch.jit.script(net)
m.save("PyTorchModelCNN.pt")
print("The PyTorch CNN model is created and saved as PyTorchModelCNN.pt")
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
def fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler)
def predict(model, test_X, batch_size=100)