{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5d72d2a3",
   "metadata": {},
   "source": [
    "# PyTorch_Generate_CNN_Model\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** Harshal Shende  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:22 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "38a026d3",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:22:23.779944Z",
     "iopub.status.busy": "2026-05-19T20:22:23.779828Z",
     "iopub.status.idle": "2026-05-19T20:22:24.739245Z",
     "shell.execute_reply": "2026-05-19T20:22:24.738642Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f0c1e1d",
   "metadata": {},
   "source": [
    "Define model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "93a5ad47",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:22:24.741563Z",
     "iopub.status.busy": "2026-05-19T20:22:24.741375Z",
     "iopub.status.idle": "2026-05-19T20:22:24.744446Z",
     "shell.execute_reply": "2026-05-19T20:22:24.743917Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running Torch code defining the model....\n"
     ]
    }
   ],
   "source": [
    "print(\"running Torch code defining the model....\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "521f444a",
   "metadata": {},
   "source": [
    "Custom Reshape Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9499c50a",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:22:24.745895Z",
     "iopub.status.busy": "2026-05-19T20:22:24.745778Z",
     "iopub.status.idle": "2026-05-19T20:22:24.748826Z",
     "shell.execute_reply": "2026-05-19T20:22:24.747902Z"
    }
   },
   "outputs": [],
   "source": [
    "class Reshape(torch.nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x.view(-1,1,16,16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "700264ba",
   "metadata": {},
   "source": [
    "CNN Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "536b90b0",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:22:24.750096Z",
     "iopub.status.busy": "2026-05-19T20:22:24.749978Z",
     "iopub.status.idle": "2026-05-19T20:22:24.755036Z",
     "shell.execute_reply": "2026-05-19T20:22:24.754233Z"
    }
   },
   "outputs": [],
   "source": [
    "net = torch.nn.Sequential(\n",
    "    Reshape(),\n",
    "    nn.Conv2d(1, 10, kernel_size=3, padding=1),\n",
    "    nn.ReLU(),\n",
    "    nn.BatchNorm2d(10),\n",
    "    nn.Conv2d(10, 10, kernel_size=3, padding=1),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(kernel_size=2),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(10*8*8, 256),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(256, 2),\n",
    "    nn.Sigmoid()\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7e37747",
   "metadata": {},
   "source": [
    "Construct loss function and Optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "55d37d54",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:22:24.756395Z",
     "iopub.status.busy": "2026-05-19T20:22:24.756271Z",
     "iopub.status.idle": "2026-05-19T20:22:24.763303Z",
     "shell.execute_reply": "2026-05-19T20:22:24.762778Z"
    }
   },
   "outputs": [],
   "source": [
    "criterion = nn.BCELoss()\n",
    "optimizer = torch.optim.Adam\n",
    "\n",
    "\n",
    "def fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):\n",
    "    trainer = optimizer(model.parameters(), lr=0.01)\n",
    "    schedule, schedulerSteps = scheduler\n",
    "    best_val = None\n",
    "\n",
    "    # Setup GPU\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    model = model.to(device)\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        # Training Loop\n",
    "        # Set to train mode\n",
    "        model.train()\n",
    "        running_train_loss = 0.0\n",
    "        running_val_loss = 0.0\n",
    "        for i, (X, y) in enumerate(train_loader):\n",
    "            trainer.zero_grad()\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            output = model(X)\n",
    "            target = y\n",
    "            train_loss = criterion(output, target)\n",
    "            train_loss.backward()\n",
    "            trainer.step()\n",
    "\n",
    "            # print train statistics\n",
    "            running_train_loss += train_loss.item()\n",
    "            if i % 4 == 3:    # print every 4 mini-batches\n",
    "                print(f\"[{epoch+1}, {i+1}] train loss: {running_train_loss / 4 :.3f}\")\n",
    "                running_train_loss = 0.0\n",
    "\n",
    "        if schedule:\n",
    "            schedule(optimizer, epoch, schedulerSteps)\n",
    "\n",
    "        # Validation Loop\n",
    "        # Set to eval mode\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            for i, (X, y) in enumerate(val_loader):\n",
    "                X, y = X.to(device), y.to(device)\n",
    "                output = model(X)\n",
    "                target = y\n",
    "                val_loss = criterion(output, target)\n",
    "                running_val_loss += val_loss.item()\n",
    "\n",
    "            curr_val = running_val_loss / len(val_loader)\n",
    "            if save_best:\n",
    "               if best_val==None:\n",
    "                   best_val = curr_val\n",
    "               best_val = save_best(model, curr_val, best_val)\n",
    "\n",
    "            # print val statistics per epoch\n",
    "            print(f\"[{epoch+1}] val loss: {curr_val :.3f}\")\n",
    "            running_val_loss = 0.0\n",
    "\n",
    "    print(f\"Finished Training on {epoch+1} Epochs!\")\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "def predict(model, test_X, batch_size=100):\n",
    "    # Set to eval mode\n",
    "\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    model = model.to(device)\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "\n",
    "    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))\n",
    "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    predictions = []\n",
    "    with torch.no_grad():\n",
    "        for i, data in enumerate(test_loader):\n",
    "            X = data[0].to(device)\n",
    "            outputs = model(X)\n",
    "            predictions.append(outputs)\n",
    "        preds = torch.cat(predictions)\n",
    "\n",
    "    return preds.cpu().numpy()\n",
    "\n",
    "\n",
    "load_model_custom_objects = {\"optimizer\": optimizer, \"criterion\": criterion, \"train_func\": fit, \"predict_func\": predict}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51afa710",
   "metadata": {},
   "source": [
    "Store model to file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0d74f1cf",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:22:24.764764Z",
     "iopub.status.busy": "2026-05-19T20:22:24.764643Z",
     "iopub.status.idle": "2026-05-19T20:22:24.798041Z",
     "shell.execute_reply": "2026-05-19T20:22:24.797376Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The PyTorch CNN model is created and saved as PyTorchModelCNN.pt\n"
     ]
    }
   ],
   "source": [
    "m = torch.jit.script(net)\n",
    "torch.jit.save(m,\"PyTorchModelCNN.pt\")\n",
    "print(\"The PyTorch CNN model is created and saved as PyTorchModelCNN.pt\") "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a42a79f8",
   "metadata": {},
   "source": [
    "Draw all canvases "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d20bfe27",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:22:24.799617Z",
     "iopub.status.busy": "2026-05-19T20:22:24.799479Z",
     "iopub.status.idle": "2026-05-19T20:22:25.736009Z",
     "shell.execute_reply": "2026-05-19T20:22:25.735552Z"
    }
   },
   "outputs": [],
   "source": [
    "from ROOT import gROOT \n",
    "gROOT.GetListOfCanvases().Draw()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
