ROOT
Reference Guide
PyTorch_Generate_CNN_Model Namespace Reference

Classes

class  Reshape
 

Functions

def fit (model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler)
 
def predict (model, test_X, batch_size=100)
 

Variables

 criterion = nn.BCELoss()
 
dictionary load_model_custom_objects = {"optimizer": optimizer, "criterion": criterion, "train_func": fit, "predict_func": predict}
 
 m = torch.jit.script(net)
 
 net
 
 optimizer = torch.optim.Adam
 

Function Documentation

◆ fit()

def PyTorch_Generate_CNN_Model.fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler)

Definition at line 32 of file PyTorch_Generate_CNN_Model.py.
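The page documents only the signature of fit(). The following is a minimal, hypothetical sketch of a training loop with that signature, not the file's actual body. It assumes that `optimizer` is an optimizer class (consistent with the `optimizer = torch.optim.Adam` variable), that `criterion` is a loss module, and that `save_best` is either None or a hypothetical callback `save_best(model, val_loss, best_val)` returning the new best value. The learning rate is an arbitrary choice, and `batch_size` and `scheduler` are accepted but unused here because the loaders already batch the data.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer,
        criterion, save_best, scheduler):
    """Hypothetical training loop matching the documented signature.

    Assumption: `batch_size` and `scheduler` are ignored in this sketch.
    """
    trainer = optimizer(model.parameters(), lr=0.01)  # lr is an assumption
    best_val = None
    for epoch in range(num_epochs):
        model.train()  # BatchNorm uses batch statistics while training
        for X, y in train_loader:
            trainer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            trainer.step()

        model.eval()  # BatchNorm switches to its running statistics
        val_loss = 0.0
        with torch.no_grad():  # no gradients needed for validation
            for X, y in val_loader:
                val_loss += criterion(model(X), y).item()
        val_loss /= len(val_loader)  # mean loss per validation batch
        if save_best is not None:  # hypothetical checkpoint callback
            best_val = val_loss if best_val is None else best_val
            best_val = save_best(model, val_loss, best_val)
    return model


# Toy usage: 4 input features, 2 sigmoid outputs with one-hot targets.
torch.manual_seed(0)
X = torch.randn(64, 4)
y = nn.functional.one_hot(torch.randint(0, 2, (64,)), 2).float()
loader = DataLoader(TensorDataset(X, y), batch_size=16)
toy = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid())
trained = fit(toy, loader, loader, num_epochs=2, batch_size=16,
              optimizer=torch.optim.Adam, criterion=nn.BCELoss(),
              save_best=None, scheduler=None)
```

Passing the optimizer as a class lets the training function bind it to whatever model it receives.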

◆ predict()

def PyTorch_Generate_CNN_Model.predict(model, test_X, batch_size=100)

Definition at line 91 of file PyTorch_Generate_CNN_Model.py.
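As with fit(), only the signature of predict() is documented. The following is a hypothetical sketch of batched inference with that signature. It assumes that `test_X` is array-like with shape (n_events, n_features) and that the caller wants a NumPy array of outputs back; neither assumption is confirmed by this page.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def predict(model, test_X, batch_size=100):
    """Hypothetical batched inference matching the documented signature."""
    model.eval()  # put layers such as BatchNorm2d into inference mode
    data = torch.as_tensor(np.asarray(test_X), dtype=torch.float32)
    loader = DataLoader(TensorDataset(data), batch_size=batch_size,
                        shuffle=False)  # keep the input order
    outputs = []
    with torch.no_grad():  # no autograd graph is needed for inference
        for (X,) in loader:
            outputs.append(model(X))
    return torch.cat(outputs).numpy()


toy = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid())
preds = predict(toy, np.random.rand(250, 4), batch_size=100)
print(preds.shape)  # (250, 2)
```

With 250 events and batch_size=100, the loader yields batches of 100, 100 and 50, which are concatenated back in order.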

Variable Documentation

◆ criterion

PyTorch_Generate_CNN_Model.criterion = nn.BCELoss()

Definition at line 28 of file PyTorch_Generate_CNN_Model.py.
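With its default reduction='mean', nn.BCELoss() averages -[y·log(x) + (1 - y)·log(1 - x)] over every output element. Because `net` ends in nn.Sigmoid() with two output units, the targets are presumably two-column one-hot labels, although this page does not say so. The following plain-Python sketch reproduces the computation for one such event. Unlike PyTorch, it does not clamp the log terms at -100.

```python
import math


def bce_loss(outputs, targets):
    """Mean binary cross-entropy over all elements (reduction='mean').

    `outputs` are probabilities in (0, 1); `targets` are 0/1 labels.
    """
    terms = [
        -(y * math.log(x) + (1 - y) * math.log(1 - x))
        for x, y in zip(outputs, targets)
    ]
    return sum(terms) / len(terms)


# One event: sigmoid outputs [0.8, 0.3] against the one-hot target [1, 0].
loss = bce_loss([0.8, 0.3], [1.0, 0.0])
print(round(loss, 4))  # 0.2899
```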

◆ load_model_custom_objects

dictionary PyTorch_Generate_CNN_Model.load_model_custom_objects = {"optimizer": optimizer, "criterion": criterion, "train_func": fit, "predict_func": predict}

Definition at line 114 of file PyTorch_Generate_CNN_Model.py.

◆ m

PyTorch_Generate_CNN_Model.m = torch.jit.script(net)

Definition at line 117 of file PyTorch_Generate_CNN_Model.py.
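torch.jit.script compiles a module and all of its submodules to TorchScript. For `net`, this means the custom Reshape module must also have a scriptable forward(). This page does not show how the file stores `m`. The following sketch uses a small stand-in network and an arbitrary file name to show a TorchScript save/load round trip with torch.jit.save and torch.jit.load.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in for `net`; the real network also contains Reshape and conv layers.
net = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid())
m = torch.jit.script(net)  # compile the module tree to TorchScript

path = os.path.join(tempfile.mkdtemp(), "model.pt")  # arbitrary file name
torch.jit.save(m, path)
loaded = torch.jit.load(path)  # reload without the Python class definitions

x = torch.randn(3, 4)
print(torch.allclose(loaded(x), m(x)))  # True
```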

◆ net

PyTorch_Generate_CNN_Model.net
Initial value:
net = torch.nn.Sequential(
    Reshape(),
    nn.Conv2d(1, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(10),
    nn.Conv2d(10, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(),
    nn.Linear(10*8*8, 256),
    nn.ReLU(),
    nn.Linear(256, 2),
    nn.Sigmoid()
)

Definition at line 12 of file PyTorch_Generate_CNN_Model.py.
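The input size of the first linear layer, 10*8*8 = 640, is consistent with Reshape producing a single-channel 16×16 image. The 16×16 size is inferred from the layer sizes and is not stated on this page. Both 3×3 convolutions with padding=1 preserve the spatial size. MaxPool2d(kernel_size=2) halves it, because its stride defaults to the kernel size. Flatten then multiplies the 10 channels by the 8×8 map. The following sketch traces this with the output-size formula from the PyTorch Conv2d/MaxPool2d documentation.

```python
def conv_out(size, kernel, padding=0, stride=1, dilation=1):
    """Output length along one spatial axis for Conv2d or MaxPool2d
    (floor mode): floor((size + 2p - d(k - 1) - 1) / s) + 1."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


side = 16  # assumed height/width after Reshape (inferred, not documented)
side = conv_out(side, kernel=3, padding=1)  # Conv2d(1, 10, 3, padding=1)  -> 16
side = conv_out(side, kernel=3, padding=1)  # Conv2d(10, 10, 3, padding=1) -> 16
side = conv_out(side, kernel=2, stride=2)   # MaxPool2d(kernel_size=2)     -> 8
flat_features = 10 * side * side            # Flatten over 10 channels
print(flat_features)  # 640
```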

◆ optimizer

PyTorch_Generate_CNN_Model.optimizer = torch.optim.Adam

Definition at line 29 of file PyTorch_Generate_CNN_Model.py.
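`optimizer` is bound to the class torch.optim.Adam, not to an instance. An Adam instance needs the model's parameters, so whichever code receives it (presumably fit(), via load_model_custom_objects) has to construct it. The following minimal sketch uses a stand-in linear model and an illustrative learning rate.

```python
import torch
from torch import nn

optimizer = torch.optim.Adam  # the class, as in the documented variable
model = nn.Linear(4, 2)       # stand-in for `net`
trainer = optimizer(model.parameters(), lr=1e-3)  # lr chosen for illustration

# One optimisation step on a toy loss.
loss = model(torch.randn(8, 4)).pow(2).mean()
trainer.zero_grad()
loss.backward()
trainer.step()
print(type(trainer).__name__)  # Adam
```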