{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a454b7eb",
   "metadata": {},
   "source": [
    "# MulticlassPyTorch\n",
    "This tutorial shows how to do multiclass classification in TMVA with neural\n",
    "networks trained with PyTorch.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:22 PM.</small></i>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5dd4fbe",
   "metadata": {},
   "source": [
    "PyTorch has to be imported before ROOT to avoid crashes because of clashing\n",
    "std::regexp symbols that are exported by cppyy.\n",
    "See also: https://github.com/wlav/cppyy/issues/227"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49cedb88",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from ROOT import TMVA, TFile, TCut, gROOT\n",
    "from os.path import isfile"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "233f27d3",
   "metadata": {},
   "source": [
    "Setup TMVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7367943",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "TMVA.Tools.Instance()\n",
    "TMVA.PyMethodBase.PyInitialize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddda607a",
   "metadata": {},
   "source": [
    "create factory without output file since it is not needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95860962",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "factory = TMVA.Factory('TMVAClassification',\n",
    "    '!V:!Silent:Color:!DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2069e94",
   "metadata": {},
   "source": [
    "Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7dd910c",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if not isfile('tmva_example_multiple_background.root'):\n",
    "    createDataMacro = str(gROOT.GetTutorialDir()) + '/machine_learning/createData.C'\n",
    "    print(createDataMacro)\n",
    "    gROOT.ProcessLine('.L {}'.format(createDataMacro))\n",
    "    gROOT.ProcessLine('create_MultipleBackground(4000)')\n",
    "\n",
    "data = TFile.Open('tmva_example_multiple_background.root')\n",
    "signal = data.Get('TreeS')\n",
    "background0 = data.Get('TreeB0')\n",
    "background1 = data.Get('TreeB1')\n",
    "background2 = data.Get('TreeB2')\n",
    "\n",
    "dataloader = TMVA.DataLoader('dataset')\n",
    "for branch in signal.GetListOfBranches():\n",
    "    dataloader.AddVariable(branch.GetName())\n",
    "\n",
    "dataloader.AddTree(signal, 'Signal')\n",
    "dataloader.AddTree(background0, 'Background_0')\n",
    "dataloader.AddTree(background1, 'Background_1')\n",
    "dataloader.AddTree(background2, 'Background_2')\n",
    "dataloader.PrepareTrainingAndTestTree(TCut(''),\n",
    "        'SplitMode=Random:NormMode=NumEvents:!V')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f295f2ea",
   "metadata": {},
   "source": [
    "Generate model\n",
    "Define model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d80a5220",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = nn.Sequential()\n",
    "model.add_module('linear_1', nn.Linear(in_features=4, out_features=32))\n",
    "model.add_module('relu', nn.ReLU())\n",
    "model.add_module('linear_2', nn.Linear(in_features=32, out_features=4))\n",
    "model.add_module('softmax', nn.Softmax(dim=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c5d8435",
   "metadata": {},
   "source": [
    "Set loss and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "972e4d3c",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "loss = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7638a748",
   "metadata": {},
   "source": [
    "Define train function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6758a7de",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):\n",
    "    trainer = optimizer(model.parameters(), lr=0.01)\n",
    "    schedule, schedulerSteps = scheduler\n",
    "    best_val = None\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        # Training Loop\n",
    "        # Set to train mode\n",
    "        model.train()\n",
    "        running_train_loss = 0.0\n",
    "        running_val_loss = 0.0\n",
    "        for i, (X, y) in enumerate(train_loader):\n",
    "            trainer.zero_grad()\n",
    "            output = model(X)\n",
    "            target = torch.max(y, 1)[1]\n",
    "            train_loss = criterion(output, target)\n",
    "            train_loss.backward()\n",
    "            trainer.step()\n",
    "\n",
    "            # print train statistics\n",
    "            running_train_loss += train_loss.item()\n",
    "            if i % 32 == 31:    # print every 32 mini-batches\n",
    "                print(\"[{}, {}] train loss: {:.3f}\".format(epoch+1, i+1, running_train_loss / 32))\n",
    "                running_train_loss = 0.0\n",
    "\n",
    "        if schedule:\n",
    "            schedule(optimizer, epoch, schedulerSteps)\n",
    "\n",
    "        # Validation Loop\n",
    "        # Set to eval mode\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            for i, (X, y) in enumerate(val_loader):\n",
    "                output = model(X)\n",
    "                target = torch.max(y, 1)[1]\n",
    "                val_loss = criterion(output, target)\n",
    "                running_val_loss += val_loss.item()\n",
    "\n",
    "            curr_val = running_val_loss / len(val_loader)\n",
    "            if save_best:\n",
    "               if best_val==None:\n",
    "                   best_val = curr_val\n",
    "               best_val = save_best(model, curr_val, best_val)\n",
    "\n",
    "            # print val statistics per epoch\n",
    "            print(\"[{}] val loss: {:.3f}\".format(epoch+1, curr_val))\n",
    "            running_val_loss = 0.0\n",
    "\n",
    "    print(\"Finished Training on {} Epochs!\".format(epoch+1))\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3d7388f",
   "metadata": {},
   "source": [
    "Define predict function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a156e7b6",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def predict(model, test_X, batch_size=32):\n",
    "    # Set to eval mode\n",
    "    model.eval()\n",
    "\n",
    "    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))\n",
    "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    predictions = []\n",
    "    with torch.no_grad():\n",
    "        for i, data in enumerate(test_loader):\n",
    "            X = data[0]\n",
    "            outputs = model(X)\n",
    "            predictions.append(outputs)\n",
    "        preds = torch.cat(predictions)\n",
    "\n",
    "    return preds.numpy()\n",
    "\n",
    "\n",
    "load_model_custom_objects = {\"optimizer\": optimizer, \"criterion\": loss, \"train_func\": train, \"predict_func\": predict}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59cc1020",
   "metadata": {},
   "source": [
    "Store model to file\n",
    "Convert the model to torchscript before saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e9b9276",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "m = torch.jit.script(model)\n",
    "torch.jit.save(m, \"modelMultiClass.pt\")\n",
    "print(m)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48656192",
   "metadata": {},
   "source": [
    "Book methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e2e6949",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',\n",
    "        '!H:!V:Fisher:VarTransform=D,G')\n",
    "factory.BookMethod(dataloader, TMVA.Types.kPyTorch, \"PyTorch\",\n",
    "        'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.pt:FilenameTrainedModel=trainedModelMultiClass.pt:NumEpochs=20:BatchSize=32')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29e029e8",
   "metadata": {},
   "source": [
    "Run TMVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca51fe7b",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "factory.TrainAllMethods()\n",
    "factory.TestAllMethods()\n",
    "factory.EvaluateAllMethods()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8039d35d",
   "metadata": {},
   "source": [
    "Plot ROC Curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fd76235",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "roc = factory.GetROCCurve(dataloader)\n",
    "roc.SaveAs('ROC_MulticlassPyTorch.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
