Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
PyTorch_Generate_CNN_Model.py
Go to the documentation of this file.
1import torch
2from torch import nn
3
4# Define model
5
# Custom Reshape Layer
class Reshape(torch.nn.Module):
    """Reshape the input into a ``(-1, channels, height, width)`` image batch.

    The defaults (1, 16, 16) reproduce the original hard-coded target shape,
    so ``Reshape()`` behaves as before: every 256 input elements become one
    single-channel 16x16 image.
    """

    def __init__(self, channels: int = 1, height: int = 16, width: int = 16):
        super().__init__()
        self.channels = channels
        self.height = height
        self.width = width

    def forward(self, x):
        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed or sliced tensors; reshape only copies when it must.
        return x.reshape(-1, self.channels, self.height, self.width)
10
# CNN Model Definition
# Reshape turns the input into (N, 1, 16, 16) images. Two 3x3 convolutions
# with padding=1 keep the 16x16 size, 2x2 max-pooling halves it to 8x8, and a
# two-layer MLP maps the flattened features to 2 sigmoid outputs.
net = nn.Sequential(
    Reshape(),
    # feature extractor
    nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=10),
    nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    # classifier head: 10 channels x 8 x 8 pooled features
    nn.Flatten(),
    nn.Linear(in_features=10 * 8 * 8, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=2),
    nn.Sigmoid(),
)
26
# Construct loss function and optimizer.
# `optimizer` is the optimizer class itself, not an instance: `fit` calls it
# on the model parameters to build the actual optimizer.
criterion = torch.nn.BCELoss(reduction="mean")
optimizer = torch.optim.Adam
30
31
def _train_one_epoch(model, loader, trainer, criterion, device, epoch):
    """Run one training pass over ``loader``, printing the mean loss every 4 batches."""
    # Set to train mode
    model.train()
    running_train_loss = 0.0
    for i, (X, y) in enumerate(loader):
        trainer.zero_grad()
        X, y = X.to(device), y.to(device)
        train_loss = criterion(model(X), y)
        train_loss.backward()
        trainer.step()

        # print train statistics
        running_train_loss += train_loss.item()
        if i % 4 == 3:  # print every 4 mini-batches
            print(f"[{epoch+1}, {i+1}] train loss: {running_train_loss / 4 :.3f}")
            running_train_loss = 0.0


def _validation_loss(model, loader, criterion, device):
    """Return the mean ``criterion`` value over the batches of ``loader`` in eval mode."""
    # Set to eval mode
    model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            running_val_loss += criterion(model(X), y).item()
    return running_val_loss / len(loader)


def fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler, lr=0.01):
    """Train ``model`` and run a validation pass after every epoch.

    Args:
        model: the ``torch.nn.Module`` to train.
        train_loader: iterable of ``(X, y)`` training batches.
        val_loader: sized iterable of ``(X, y)`` validation batches; its
            ``len`` is used to average the validation loss.
        num_epochs: number of passes over ``train_loader``.
        batch_size: not used here (batching is done by the loaders); kept
            for interface compatibility.
        optimizer: optimizer factory (e.g. ``torch.optim.Adam``), called as
            ``optimizer(model.parameters(), lr=lr)``.
        criterion: loss function called as ``criterion(output, target)``.
        save_best: callable ``(model, curr_val, best_val) -> best_val``
            invoked after every epoch, or a falsy value to disable it.
        scheduler: ``(schedule, steps)`` pair. When ``schedule`` is truthy it
            is called once per epoch, after training and before validation,
            as ``schedule(trainer, epoch, steps)`` where ``trainer`` is the
            instantiated optimizer.
        lr: initial learning rate handed to ``optimizer`` (default 0.01).

    Returns:
        The trained model, moved to the selected device.
    """
    schedule, schedulerSteps = scheduler
    best_val = None

    # Setup GPU. The model is moved before the optimizer is built, as the
    # PyTorch docs recommend, so the optimizer references on-device params.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    trainer = optimizer(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, trainer, criterion, device, epoch)

        if schedule:
            # Pass the instantiated optimizer, not the `optimizer` factory:
            # only the instance has the `param_groups` a schedule would adjust.
            schedule(trainer, epoch, schedulerSteps)

        curr_val = _validation_loss(model, val_loader, criterion, device)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print(f"[{epoch+1}] val loss: {curr_val :.3f}")

    # num_epochs, not the loop variable: `epoch` is unbound when no epoch ran.
    print(f"Finished Training on {num_epochs} Epochs!")

    return model
89
90
def predict(model, test_X, batch_size=100):
    """Run ``model`` over ``test_X`` in eval mode and return the outputs as NumPy.

    Args:
        model: the ``torch.nn.Module`` to evaluate.
        test_X: array-like input (NumPy array, nested list or tensor). It is
            converted to a ``float32`` tensor whose first axis indexes samples.
        batch_size: number of samples per forward pass.

    Returns:
        A NumPy array with the model outputs for all samples, concatenated
        along the first axis in input order.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # Set to eval mode
    model.eval()

    # torch.as_tensor instead of the legacy torch.Tensor constructor: it
    # accepts arrays, lists and tensors of any dtype and yields float32 here.
    test_tensor = torch.as_tensor(test_X, dtype=torch.float32)
    test_dataset = torch.utils.data.TensorDataset(test_tensor)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        predictions = [model(X.to(device)) for (X,) in test_loader]
        if not predictions:
            # Empty input yields no batches and torch.cat([]) would raise, so
            # run the empty tensor through the model once to get a correctly
            # shaped empty result.
            predictions = [model(test_tensor.to(device))]
    preds = torch.cat(predictions)

    return preds.cpu().numpy()
112
113
# Name -> object mapping exposing this file's optimizer factory, loss and
# train/predict functions (presumably read by the code that loads the saved
# model, as the name suggests — confirm against the consumer).
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=criterion,
    train_func=fit,
    predict_func=predict,
)
115
# Store model to file: compile `net` with TorchScript and serialize it.
m = torch.jit.script(net)
m.save("PyTorchModelCNN.pt")
predict(model, test_X, batch_size=100)
fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler)