{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5a25fd97",
   "metadata": {},
   "source": [
    "# TMVA_SOFIE_ONNX\n",
    "This macro provides a simple example for:\n",
    " - creating a model with PyTorch and exporting it to ONNX\n",
    " - parsing the ONNX file with SOFIE and generating C++ code\n",
    " - compiling the generated code using ROOT Cling\n",
    " - running the code and optionally comparing the result with ONNXRuntime\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** Lorenzo Moneta  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:23 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "809160a0",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:23:32.161530Z",
     "iopub.status.busy": "2026-05-19T20:23:32.161412Z",
     "iopub.status.idle": "2026-05-19T20:23:34.146285Z",
     "shell.execute_reply": "2026-05-19T20:23:34.145814Z"
    }
   },
   "outputs": [],
   "source": [
    "import inspect\n",
    "\n",
    "import numpy as np\n",
    "import ROOT\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def CreateAndTrainModel(modelName, seed=None):\n",
    "   \"\"\"Build a small fully-connected PyTorch model, train it on random data\n",
    "   and export it to ONNX.\n",
    "\n",
    "   Parameters:\n",
    "      modelName : base name of the model; the ONNX file is written to\n",
    "                  modelName + \".onnx\"\n",
    "      seed      : optional integer; when given, torch.manual_seed(seed) is called\n",
    "                  first so that the weights, the training data and the exported\n",
    "                  model are reproducible. With the default (None) the torch\n",
    "                  random state is left untouched.\n",
    "\n",
    "   Returns:\n",
    "      the name of the exported ONNX file. If torch.onnx.export raises TypeError\n",
    "      (e.g. with an incompatible PyTorch version) the tutorial is stopped via exit().\n",
    "   \"\"\"\n",
    "   if seed is not None:\n",
    "      torch.manual_seed(seed)\n",
    "\n",
    "   # 32 input features -> 2 outputs normalized with Softmax\n",
    "   model = nn.Sequential(\n",
    "           nn.Linear(32,16),\n",
    "           nn.ReLU(),\n",
    "           nn.Linear(16,8),\n",
    "           nn.ReLU(),\n",
    "           nn.Linear(8,2),\n",
    "           nn.Softmax(dim=1)\n",
    "           )\n",
    "\n",
    "   criterion = nn.MSELoss()\n",
    "   optimizer = torch.optim.SGD(model.parameters(),lr=0.01)\n",
    "\n",
    "   # train the model with random inputs and random targets: the result is not a\n",
    "   # meaningful classifier, it only provides trained weights for the example\n",
    "   for _ in range(500):\n",
    "      x = torch.randn(2,32)\n",
    "      y = torch.randn(2,2)\n",
    "      y_pred = model(x)\n",
    "      loss = criterion(y_pred,y)\n",
    "      optimizer.zero_grad()\n",
    "      loss.backward()\n",
    "      optimizer.step()\n",
    "\n",
    "   #*******************************************************\n",
    "   ##  EXPORT to ONNX\n",
    "   #\n",
    "   #  switch the model to evaluation mode before exporting to ONNX\n",
    "   #  and provide a dummy input tensor to set the input model shape (1,32)\n",
    "   model.eval()\n",
    "\n",
    "   modelFile = modelName + \".onnx\"\n",
    "   dummy_x = torch.randn(1,32)\n",
    "   model(dummy_x)\n",
    "\n",
    "   def filtered_kwargs(func, **candidate_kwargs):\n",
    "      \"\"\"Return only the keyword arguments that func accepts: the available\n",
    "      torch.onnx.export parameters depend on the installed PyTorch version.\"\"\"\n",
    "      sig = inspect.signature(func)\n",
    "      return {\n",
    "         k: v for k, v in candidate_kwargs.items()\n",
    "         if k in sig.parameters\n",
    "      }\n",
    "\n",
    "   kwargs = filtered_kwargs(\n",
    "      torch.onnx.export,\n",
    "      input_names=[\"input\"],\n",
    "      output_names=[\"output\"],\n",
    "      external_data=False,  # may not exist\n",
    "      dynamo=True           # may not exist\n",
    "   )\n",
    "   print(\"calling torch.onnx.export with parameters\",kwargs)\n",
    "\n",
    "   try:\n",
    "      torch.onnx.export(model, dummy_x, modelFile, **kwargs)\n",
    "      print(\"model exported to ONNX as\",modelFile)\n",
    "      return modelFile\n",
    "   except TypeError:\n",
    "      print(\"Cannot export model from pytorch to ONNX - with version \",torch.__version__)\n",
    "      print(\"Skip tutorial execution\")\n",
    "      exit()\n",
    "\n",
    "\n",
    "def ParseModel(modelFile, verbose=False):\n",
    "   \"\"\"Parse an ONNX file with SOFIE and generate the C++ inference code.\n",
    "\n",
    "   Parameters:\n",
    "      modelFile : name of the ONNX file to parse\n",
    "      verbose   : forwarded to the parser; if True, also print the initialized\n",
    "                  tensors, two weight tensors and the generated code\n",
    "\n",
    "   Returns:\n",
    "      the name of the generated header file: modelFile with its trailing\n",
    "      \".onnx\" extension (if any) removed and \".hxx\" appended\n",
    "   \"\"\"\n",
    "   parser = ROOT.TMVA.Experimental.SOFIE.RModelParser_ONNX()\n",
    "   model = parser.Parse(modelFile,verbose)\n",
    "\n",
    "   # print model weights\n",
    "   if verbose:\n",
    "      model.PrintInitializedTensors()\n",
    "      # NOTE: these tensor names presumably refer to the weights of the Linear layers\n",
    "      # 0 and 2 of the nn.Sequential model; they depend on the exporter/parser naming\n",
    "      data = model.GetTensorData['float']('0weight')\n",
    "      print(\"0weight\",data)\n",
    "      data = model.GetTensorData['float']('2weight')\n",
    "      print(\"2weight\",data)\n",
    "\n",
    "   # Generating inference code\n",
    "   model.Generate()\n",
    "   # generate header file (and .dat file) with modelName+.hxx\n",
    "   model.OutputGenerated()\n",
    "   if verbose:\n",
    "      model.PrintGenerated()\n",
    "\n",
    "   # replace only the trailing extension: str.replace would also rewrite a\n",
    "   # \".onnx\" occurring earlier in the path (e.g. in a directory name)\n",
    "   suffix = \".onnx\"\n",
    "   stem = modelFile[:-len(suffix)] if modelFile.endswith(suffix) else modelFile\n",
    "   modelCode = stem + \".hxx\"\n",
    "   print(\"Generated model header file \",modelCode)\n",
    "   return modelCode"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3733939c",
   "metadata": {},
   "source": [
    "## Step 1: Create and train the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a69dd0e",
   "metadata": {},
   "source": [
    "Use an arbitrary `modelName`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d35bb14d",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:23:34.153588Z",
     "iopub.status.busy": "2026-05-19T20:23:34.153359Z",
     "iopub.status.idle": "2026-05-19T20:24:05.500626Z",
     "shell.execute_reply": "2026-05-19T20:24:05.499981Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "calling torch.onnx.export with parameters {'input_names': ['input'], 'output_names': ['output'], 'external_data': False, 'dynamo': True}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0519 20:24:04.782000 571221 torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::nms\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0519 20:24:04.785000 571221 torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::roi_align\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0519 20:24:04.789000 571221 torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::roi_pool\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[torch.onnx] Obtain model graph for `Sequential([...]` with `torch.export.export(..., strict=False)`...\n",
      "[torch.onnx] Obtain model graph for `Sequential([...]` with `torch.export.export(..., strict=False)`... ✅\n",
      "[torch.onnx] Run decompositions...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[torch.onnx] Run decompositions... ✅\n",
      "[torch.onnx] Translate the graph into ONNX...\n",
      "[torch.onnx] Translate the graph into ONNX... ✅\n",
      "[torch.onnx] Optimize the ONNX graph...\n",
      "[torch.onnx] Optimize the ONNX graph... ✅\n",
      "model exported to ONNX as LinearModel.onnx\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib64/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.\n",
      "  return cls.__new__(cls, *args)\n"
     ]
    }
   ],
   "source": [
    "# the model name fixes the ONNX file name (modelName + \".onnx\") and the\n",
    "# namespace of the generated SOFIE code (\"TMVA_SOFIE_\" + modelName)\n",
    "modelName = \"LinearModel\"\n",
    "modelFile = CreateAndTrainModel(modelName=modelName)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cf46fc9",
   "metadata": {},
   "source": [
    "## Step 2: Parse the model and generate the inference code with SOFIE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb015b5b",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:24:05.516091Z",
     "iopub.status.busy": "2026-05-19T20:24:05.515729Z",
     "iopub.status.idle": "2026-05-19T20:24:05.844277Z",
     "shell.execute_reply": "2026-05-19T20:24:05.843562Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated model header file  LinearModel.hxx\n"
     ]
    }
   ],
   "source": [
    "# verbose=False: do not print the parsed tensors nor the generated code\n",
    "modelCode = ParseModel(modelFile, verbose=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "250e8330",
   "metadata": {},
   "source": [
    "## Step 3: Compile the generated C++ model code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "52dd0d41",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:24:05.860432Z",
     "iopub.status.busy": "2026-05-19T20:24:05.860252Z",
     "iopub.status.idle": "2026-05-19T20:24:05.979351Z",
     "shell.execute_reply": "2026-05-19T20:24:05.978627Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compile the generated header with Cling; the displayed value is True on success\n",
    "ROOT.gInterpreter.Declare(f'#include \"{modelCode}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78a35a02",
   "metadata": {},
   "source": [
    "## Step 4: Evaluate the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "336acfa9",
   "metadata": {},
   "source": [
    "First, get the SOFIE session namespace, then run inference on a random input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "07f56bc6",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:24:05.984311Z",
     "iopub.status.busy": "2026-05-19T20:24:05.984148Z",
     "iopub.status.idle": "2026-05-19T20:24:06.188654Z",
     "shell.execute_reply": "2026-05-19T20:24:06.187570Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "************************************************************\n",
      "Running inference with SOFIE \n",
      "\n",
      "input to model is  [[ 0.20145239  0.33877108  1.2063044   0.30718967 -1.3833711   0.24564853\n",
      "   0.9694334   1.8081402   0.21905755  1.2830325  -0.6999619  -0.88241714\n",
      "   1.0476316  -0.8708211  -1.7136286   0.6148492  -1.0894874   2.3423162\n",
      "  -0.6901215  -0.10551826 -0.31149137  0.5782308   0.31861544 -1.5678713\n",
      "   0.21121716  0.89408857  0.53305787  0.30537423  0.6382517   0.68332654\n",
      "   0.8655148  -2.026557  ]]\n"
     ]
    }
   ],
   "source": [
    "# the generated code is declared in the namespace TMVA_SOFIE_<modelName>\n",
    "sofie = getattr(ROOT, f\"TMVA_SOFIE_{modelName}\")\n",
    "session = sofie.Session()\n",
    "\n",
    "# one random event with 32 standard-normal features, i.e. the (1,32) input shape\n",
    "x = np.random.normal(loc=0, scale=1, size=(1, 32)).astype(np.float32)\n",
    "print(\"\\n************************************************************\", \"Running inference with SOFIE \", sep=\"\\n\")\n",
    "print(\"\\ninput to model is \",x)\n",
    "y = session.infer(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e20739d6",
   "metadata": {},
   "source": [
    "The model output has shape (1,2); SOFIE returns it as a flat array of 2 values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1261834f",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:24:06.190068Z",
     "iopub.status.busy": "2026-05-19T20:24:06.189924Z",
     "iopub.status.idle": "2026-05-19T20:24:06.301379Z",
     "shell.execute_reply": "2026-05-19T20:24:06.300239Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-> output using SOFIE =  [0.461803 0.538197]\n"
     ]
    }
   ],
   "source": [
    "# view the SOFIE inference result (through its .data() accessor) as a numpy array\n",
    "y_sofie = np.asarray(y.data())\n",
    "print(\"-> output using SOFIE = \", y_sofie, sep=\" \")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eeac8cee",
   "metadata": {},
   "source": [
    "Check the inference result against ONNXRuntime (skipped if `onnxruntime` is not installed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6645ad8a",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2026-05-19T20:24:06.302928Z",
     "iopub.status.busy": "2026-05-19T20:24:06.302770Z",
     "iopub.status.idle": "2026-05-19T20:24:06.408020Z",
     "shell.execute_reply": "2026-05-19T20:24:06.406969Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Missing ONNXRuntime: skipping comparison test\n"
     ]
    }
   ],
   "source": [
    "# ONNXRuntime is optional: only the import is guarded, so that an error raised\n",
    "# during the comparison itself is not reported as a missing package\n",
    "try:\n",
    "   import onnxruntime as ort\n",
    "except ImportError:\n",
    "   ort = None\n",
    "   print(\"Missing ONNXRuntime: skipping comparison test\")\n",
    "\n",
    "if ort is not None:\n",
    "   print(\"Running inference with ONNXRuntime \")\n",
    "   # Load model\n",
    "   ort_session = ort.InferenceSession(modelFile)\n",
    "\n",
    "   # Run inference (\"input\" is the input name given to torch.onnx.export)\n",
    "   outputs = ort_session.run(None, {\"input\": x})\n",
    "   y_ort = outputs[0]\n",
    "   print(\"-> output using ORT =\", y_ort)\n",
    "\n",
    "   # element-wise check with absolute tolerance 0.01 (|difference| <= 0.01 passes);\n",
    "   # np.allclose also treats a NaN in either result as a mismatch\n",
    "   if not np.allclose(y_sofie, np.ravel(y_ort), rtol=0, atol=0.01):\n",
    "      raise RuntimeError('Result is different between SOFIE and ONNXRT')\n",
    "   print(\"OK\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
