{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "adf502c3",
   "metadata": {},
   "source": [
    "# ApplicationClassificationPyTorch\n",
    "This tutorial shows how to apply a PyTorch classifier, trained and saved with TMVA, to new data using `TMVA.Reader`.\n",
    "\n",
    "**Author:** Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:21 PM.</small></i>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48ca7b89",
   "metadata": {},
   "source": [
    "PyTorch has to be imported before ROOT to avoid crashes caused by clashing\n",
    "`std::regex` symbols exported by cppyy.\n",
    "See also: https://github.com/wlav/cppyy/issues/227"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffac825a",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from ROOT import TMVA, TFile, TString, gROOT\n",
    "from array import array\n",
    "from os.path import isfile"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9047bcf",
   "metadata": {},
   "source": [
    "Set up TMVA and create a `TMVA.Reader` to evaluate the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82fe751b",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
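    "# Initialize TMVA's tools and its Python interface (needed for PyMVA methods such as PyTorch)\n",
    "TMVA.Tools.Instance()\n",
    "TMVA.PyMethodBase.PyInitialize()\n",
    "# Options: colored, non-silent log output\n",
    "reader = TMVA.Reader(\"Color:!Silent\")"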
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4838ba9a",
   "metadata": {},
   "source": [
    "Load the example data and register each branch as an input variable of the reader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0a68a16",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Example signal and background trees shipped with the ROOT tutorials\n",
    "fname = str(gROOT.GetTutorialDir()) + \"/machine_learning/data/tmva_class_example.root\"\n",
    "data = TFile.Open(fname)\n",
    "signal = data.Get('TreeS')\n",
    "background = data.Get('TreeB')\n",
    "\n",
    "# Bind every branch to one float buffer shared by the reader and both trees;\n",
    "# the variables must be added in the same order as during training\n",
    "branches = {}\n",
    "for branch in signal.GetListOfBranches():\n",
    "    branchName = branch.GetName()\n",
    "    branches[branchName] = array('f', [-999])\n",
    "    reader.AddVariable(branchName, branches[branchName])\n",
    "    signal.SetBranchAddress(branchName, branches[branchName])\n",
    "    background.SetBranchAddress(branchName, branches[branchName])"
   ]
  },
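  {
   "cell_type": "markdown",
   "id": "shared-buffer-sketch-md",
   "metadata": {},
   "source": [
    "In the loop above, each branch is bound to a one-element `array('f', ...)` buffer that is passed to both `reader.AddVariable` and `SetBranchAddress`. The pure-Python sketch below mimics this shared-buffer pattern with hypothetical `ToyTree` and `ToyReader` classes (illustrative stand-ins, not ROOT or TMVA APIs): reading an entry overwrites the buffers in place, and the reader sees the new values without any copying. A mutable array is used because a plain Python float cannot be updated in place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "shared-buffer-sketch",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from array import array\n",
    "\n",
    "\n",
    "# Hypothetical stand-ins for a TTree and TMVA.Reader (not ROOT APIs)\n",
    "class ToyTree:\n",
    "    def __init__(self, rows):\n",
    "        self.rows = rows  # one {branch_name: value} dict per entry\n",
    "        self.addresses = {}\n",
    "\n",
    "    def set_branch_address(self, name, buf):\n",
    "        self.addresses[name] = buf  # keep a reference, not a copy\n",
    "\n",
    "    def get_entry(self, i):\n",
    "        # Write entry i into every registered buffer, in place\n",
    "        for name, buf in self.addresses.items():\n",
    "            buf[0] = self.rows[i][name]\n",
    "\n",
    "\n",
    "class ToyReader:\n",
    "    def __init__(self):\n",
    "        self.variables = []\n",
    "\n",
    "    def add_variable(self, name, buf):\n",
    "        self.variables.append((name, buf))\n",
    "\n",
    "    def current_inputs(self):\n",
    "        # Read whatever the shared buffers hold right now\n",
    "        return [buf[0] for _, buf in self.variables]\n",
    "\n",
    "\n",
    "toy_tree = ToyTree([{'var1': 1.5, 'var2': -0.25}, {'var1': 3.0, 'var2': 2.0}])\n",
    "toy_reader = ToyReader()\n",
    "toy_buffers = {}\n",
    "for name in ('var1', 'var2'):\n",
    "    toy_buffers[name] = array('f', [-999])  # single-element float buffer\n",
    "    toy_reader.add_variable(name, toy_buffers[name])\n",
    "    toy_tree.set_branch_address(name, toy_buffers[name])\n",
    "\n",
    "print(toy_reader.current_inputs())  # [-999.0, -999.0]: nothing read yet\n",
    "toy_tree.get_entry(1)\n",
    "print(toy_reader.current_inputs())  # [3.0, 2.0]"
   ]
  },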
  {
   "cell_type": "markdown",
   "id": "1cef6d5b",
   "metadata": {},
   "source": [
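    "Define the prediction function used to apply the PyTorch model. Only the prediction step is needed here, so the training-related entries of `load_model_custom_objects` are set to `None`."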
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "198b2584",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def predict(model, test_X, batch_size=32):\n",
    "    # Evaluation mode disables training-only behavior such as dropout\n",
    "    model.eval()\n",
    "\n",
    "    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))\n",
    "    # shuffle=False keeps the predictions in the same order as the inputs\n",
    "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    predictions = []\n",
    "    # No gradients are needed for inference\n",
    "    with torch.no_grad():\n",
    "        for batch in test_loader:\n",
    "            X = batch[0]\n",
    "            predictions.append(model(X))\n",
    "        preds = torch.cat(predictions)\n",
    "\n",
    "    return preds.numpy()\n",
    "\n",
    "\n",
    "load_model_custom_objects = {\"optimizer\": None, \"criterion\": None, \"train_func\": None, \"predict_func\": predict}"
   ]
  },
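  {
   "cell_type": "markdown",
   "id": "inference-mode-sketch-md",
   "metadata": {},
   "source": [
    "The sketch below illustrates the inference settings used in `predict`, assuming only that PyTorch is installed. It uses a hypothetical, untrained toy network with a dropout layer (not the tutorial's trained model): in evaluation mode under `torch.no_grad()` the output is deterministic and carries no gradient history, and evaluating in unshuffled mini-batches gives the same result as a single full-batch pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inference-mode-sketch",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Hypothetical toy classifier with dropout (not the trained tutorial model)\n",
    "toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 1), nn.Sigmoid())\n",
    "toy_X = torch.randn(10, 4)\n",
    "\n",
    "# Training mode: dropout is active and autograd records the computation\n",
    "toy_model.train()\n",
    "train_out = toy_model(toy_X)\n",
    "\n",
    "# Evaluation mode without gradients, as in predict()\n",
    "toy_model.eval()\n",
    "with torch.no_grad():\n",
    "    eval_a = toy_model(toy_X)\n",
    "    eval_b = toy_model(toy_X)\n",
    "    # Unshuffled mini-batches, concatenated in input order\n",
    "    toy_loader = DataLoader(TensorDataset(toy_X), batch_size=3, shuffle=False)\n",
    "    batched = torch.cat([toy_model(batch[0]) for batch in toy_loader])\n",
    "\n",
    "print(train_out.requires_grad)          # True\n",
    "print(eval_a.requires_grad)             # False\n",
    "print(torch.equal(eval_a, eval_b))      # True\n",
    "print(torch.allclose(batched, eval_a))  # True"
   ]
  },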
  {
   "cell_type": "markdown",
   "id": "3b42e404",
   "metadata": {},
   "source": [
    "Book the PyTorch method using the weight file written during TMVA training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b43d69b8",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
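    "weights_file = 'dataset/weights/TMVAClassification_PyTorch.weights.xml'\n",
    "# The weight file only exists after the model has been trained with TMVA\n",
    "if not isfile(weights_file):\n",
    "    raise FileNotFoundError('Weight file ' + weights_file + ' not found; train the PyTorch model with TMVA first.')\n",
    "reader.BookMVA('PyTorch', TString(weights_file))"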
   ]
  },
  {
   "cell_type": "markdown",
   "id": "619195fe",
   "metadata": {},
   "source": [
    "Print the classifier response for the first 20 signal and the first 20 background events"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07ae62e2",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "print('Some signal example classifications:')\n",
    "for i in range(20):\n",
    "    # GetEntry fills the shared branch buffers; EvaluateMVA reads them\n",
    "    signal.GetEntry(i)\n",
    "    print(reader.EvaluateMVA('PyTorch'))\n",
    "print()\n",
    "\n",
    "print('Some background example classifications:')\n",
    "for i in range(20):\n",
    "    background.GetEntry(i)\n",
    "    print(reader.EvaluateMVA('PyTorch'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
