{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "46e807d0",
   "metadata": {},
   "source": [
    "# RegressionPyTorch\n",
    "This tutorial shows how to do regression in TMVA with neural networks\n",
    "trained with PyTorch.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** Anirudh Dagar <anirudhdagar6@gmail.com> - IIT, Roorkee  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:22 PM.</small></i>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ee27059",
   "metadata": {},
   "source": [
    "PyTorch has to be imported before ROOT to avoid crashes because of clashing\n",
    "std::regexp symbols that are exported by cppyy.\n",
    "See also: https://github.com/wlav/cppyy/issues/227"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d50ce43",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from ROOT import TMVA, TFile, TCut, gROOT\n",
    "from subprocess import call\n",
    "from os.path import isfile"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "978e6d8e",
   "metadata": {},
   "source": [
    "Setup TMVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c1c7ce4",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "TMVA.Tools.Instance()\n",
    "TMVA.PyMethodBase.PyInitialize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0960a85",
   "metadata": {},
   "source": [
    "create factory without output file since it is not needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45be37d3",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "factory = TMVA.Factory('TMVARegression',\n",
    "        '!V:!Silent:Color:!DrawProgressBar:Transformations=D,G:AnalysisType=Regression')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d6ae63f",
   "metadata": {},
   "source": [
    "Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89574749",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "data = TFile.Open(str(gROOT.GetTutorialDir()) + '/machine_learning/data/tmva_reg_example.root')\n",
    "tree = data.Get('TreeR')\n",
    "\n",
    "dataloader = TMVA.DataLoader('dataset')\n",
    "for branch in tree.GetListOfBranches():\n",
    "    name = branch.GetName()\n",
    "    if name != 'fvalue':\n",
    "        dataloader.AddVariable(name)\n",
    "dataloader.AddTarget('fvalue')\n",
    "\n",
    "dataloader.AddRegressionTree(tree, 1.0)\n",
    "dataloader.PrepareTrainingAndTestTree(TCut(''),\n",
    "        'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32de967b",
   "metadata": {},
   "source": [
    "Generate model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52ef6876",
   "metadata": {},
   "source": [
    "Define model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08c37c33",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = nn.Sequential()\n",
    "model.add_module('linear_1', nn.Linear(in_features=2, out_features=64))\n",
    "model.add_module('relu', nn.Tanh())\n",
    "model.add_module('linear_2', nn.Linear(in_features=64, out_features=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1c3c3cc",
   "metadata": {},
   "source": [
    "Construct loss function and Optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf003e81",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "loss = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.SGD"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c9812f1",
   "metadata": {},
   "source": [
    "Define train function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "076af9b1",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):\n",
    "    trainer = optimizer(model.parameters(), lr=0.01)\n",
    "    schedule, schedulerSteps = scheduler\n",
    "    best_val = None\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        # Training Loop\n",
    "        # Set to train mode\n",
    "        model.train()\n",
    "        running_train_loss = 0.0\n",
    "        running_val_loss = 0.0\n",
    "        for i, (X, y) in enumerate(train_loader):\n",
    "            trainer.zero_grad()\n",
    "            output = model(X)\n",
    "            train_loss = criterion(output, y)\n",
    "            train_loss.backward()\n",
    "            trainer.step()\n",
    "\n",
    "            # print train statistics\n",
    "            running_train_loss += train_loss.item()\n",
    "            if i % 32 == 31:    # print every 32 mini-batches\n",
    "                print(\"[{}, {}] train loss: {:.3f}\".format(epoch+1, i+1, running_train_loss / 32))\n",
    "                running_train_loss = 0.0\n",
    "\n",
    "        if schedule:\n",
    "            schedule(optimizer, epoch, schedulerSteps)\n",
    "\n",
    "        # Validation Loop\n",
    "        # Set to eval mode\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            for i, (X, y) in enumerate(val_loader):\n",
    "                output = model(X)\n",
    "                val_loss = criterion(output, y)\n",
    "                running_val_loss += val_loss.item()\n",
    "\n",
    "            curr_val = running_val_loss / len(val_loader)\n",
    "            if save_best:\n",
    "               if best_val==None:\n",
    "                   best_val = curr_val\n",
    "               best_val = save_best(model, curr_val, best_val)\n",
    "\n",
    "            # print val statistics per epoch\n",
    "            print(\"[{}] val loss: {:.3f}\".format(epoch+1, curr_val))\n",
    "            running_val_loss = 0.0\n",
    "\n",
    "    print(\"Finished Training on {} Epochs!\".format(epoch+1))\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "121a8518",
   "metadata": {},
   "source": [
    "Define predict function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7016025",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def predict(model, test_X, batch_size=32):\n",
    "    # Set to eval mode\n",
    "    model.eval()\n",
    "\n",
    "    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))\n",
    "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    predictions = []\n",
    "    with torch.no_grad():\n",
    "        for i, data in enumerate(test_loader):\n",
    "            X = data[0]\n",
    "            outputs = model(X)\n",
    "            predictions.append(outputs)\n",
    "        preds = torch.cat(predictions)\n",
    "\n",
    "    return preds.numpy()\n",
    "\n",
    "\n",
    "load_model_custom_objects = {\"optimizer\": optimizer, \"criterion\": loss, \"train_func\": train, \"predict_func\": predict}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44bc78b3",
   "metadata": {},
   "source": [
    "Store model to file\n",
    "Convert the model to torchscript before saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "575f061e",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "m = torch.jit.script(model)\n",
    "torch.jit.save(m, \"modelRegression.pt\")\n",
    "print(m)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ab1d6e0",
   "metadata": {},
   "source": [
    "Book methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85db2321",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "factory.BookMethod(dataloader, TMVA.Types.kPyTorch, 'PyTorch',\n",
    "        'H:!V:VarTransform=D,G:FilenameModel=modelRegression.pt:FilenameTrainedModel=trainedModelRegression.pt:NumEpochs=20:BatchSize=32')\n",
    "factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',\n",
    "        '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31d89d02",
   "metadata": {},
   "source": [
    "Run TMVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f62ccb",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "factory.TrainAllMethods()\n",
    "factory.TestAllMethods()\n",
    "factory.EvaluateAllMethods()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}