{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "10a9ad64",
   "metadata": {},
   "source": [
    "# TMVA_SOFIE_GNN\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:**   \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:23 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09cc577c",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import ROOT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17a1a208",
   "metadata": {},
   "source": [
    "Load the system OpenBLAS library explicitly if available. This avoids pulling in\n",
    "NumPy's built-in OpenBLAS later, which would conflict with the system OpenBLAS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ab63117",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "ROOT.gInterpreter.Load(\"libopenblaso.so\")\n",
    "\n",
    "import time\n",
    "\n",
    "import graph_nets as gn\n",
    "import numpy as np\n",
    "import sonnet as snt\n",
    "from graph_nets import utils_tf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5763951",
   "metadata": {},
   "source": [
    "defining graph properties"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c369be59",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "num_nodes = 5\n",
    "num_edges = 20\n",
    "snd = np.array([1, 2, 3, 4, 2, 3, 4, 3, 4, 4, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3], dtype=\"int32\")\n",
    "rec = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 1, 2, 3, 4, 2, 3, 4, 3, 4, 4], dtype=\"int32\")\n",
    "node_size = 4\n",
    "edge_size = 4\n",
    "global_size = 1\n",
    "LATENT_SIZE = 100\n",
    "NUM_LAYERS = 4\n",
    "processing_steps = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "630e1c6c",
   "metadata": {},
   "source": [
    "method for returning dictionary of graph data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d36c4582",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def get_graph_data_dict(num_nodes, num_edges, NODE_FEATURE_SIZE=2, EDGE_FEATURE_SIZE=2, GLOBAL_FEATURE_SIZE=1):\n",
    "    \"\"\"Build a graph data dictionary filled with random float32 features.\n",
    "\n",
    "    Feature values are random numbers scaled to lie between -5 and 5. The\n",
    "    connectivity is taken from the module-level `snd` and `rec` arrays, so\n",
    "    `num_edges` only sets the number of rows of the edge feature array.\n",
    "    \"\"\"\n",
    "\n",
    "    def random_features(*shape):\n",
    "        return 10 * np.random.rand(*shape).astype(np.float32) - 5.0\n",
    "\n",
    "    # Draw in the same order as the dictionary entries: globals, nodes, edges.\n",
    "    global_features = random_features(GLOBAL_FEATURE_SIZE)\n",
    "    node_features = random_features(num_nodes, NODE_FEATURE_SIZE)\n",
    "    edge_features = random_features(num_edges, EDGE_FEATURE_SIZE)\n",
    "    return {\n",
    "        \"globals\": global_features,\n",
    "        \"nodes\": node_features,\n",
    "        \"edges\": edge_features,\n",
    "        \"senders\": snd,\n",
    "        \"receivers\": rec,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d7cc43e",
   "metadata": {},
   "source": [
    "method to instantiate mlp model to be added in GNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00e84102",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def make_mlp_model():\n",
    "    \"\"\"Create an MLP followed by a layer normalisation.\n",
    "\n",
    "    The MLP has NUM_LAYERS layers of LATENT_SIZE units each, with the final\n",
    "    layer activated; the LayerNorm acts on the last axis.\n",
    "    \"\"\"\n",
    "    layer_sizes = [LATENT_SIZE] * NUM_LAYERS\n",
    "    mlp = snt.nets.MLP(layer_sizes, activate_final=True)\n",
    "    layer_norm = snt.LayerNorm(axis=-1, create_offset=True, create_scale=True)\n",
    "    return snt.Sequential([mlp, layer_norm])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae51b14f",
   "metadata": {},
   "source": [
    "defining GraphIndependent class with MLP edge, node, and global models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "766ea63d",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class MLPGraphIndependent(snt.Module):\n",
    "    \"\"\"GraphIndependent block with separate MLPs for edges, nodes and globals.\"\"\"\n",
    "\n",
    "    def __init__(self, name=\"MLPGraphIndependent\"):\n",
    "        super().__init__(name=name)\n",
    "\n",
    "        def build_mlp():\n",
    "            # Plain MLP (no LayerNorm); a new instance is created on every call.\n",
    "            return snt.nets.MLP([LATENT_SIZE] * NUM_LAYERS, activate_final=True)\n",
    "\n",
    "        self._network = gn.modules.GraphIndependent(\n",
    "            edge_model_fn=build_mlp,\n",
    "            node_model_fn=build_mlp,\n",
    "            global_model_fn=build_mlp,\n",
    "        )\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        \"\"\"Apply the wrapped GraphIndependent network to `inputs`.\"\"\"\n",
    "        return self._network(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "549005f0",
   "metadata": {},
   "source": [
    "defining Graph network class with MLP edge, node, and global models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eb90bf3",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class MLPGraphNetwork(snt.Module):\n",
    "    \"\"\"GraphNetwork block whose edge, node and global models come from `make_mlp_model`.\"\"\"\n",
    "\n",
    "    def __init__(self, name=\"MLPGraphNetwork\"):\n",
    "        super().__init__(name=name)\n",
    "        self._network = gn.modules.GraphNetwork(\n",
    "            edge_model_fn=make_mlp_model,\n",
    "            node_model_fn=make_mlp_model,\n",
    "            global_model_fn=make_mlp_model,\n",
    "        )\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        \"\"\"Apply the wrapped GraphNetwork to `inputs`.\"\"\"\n",
    "        return self._network(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "325fd73e",
   "metadata": {},
   "source": [
    "Defining an Encode-Process-Decode module for the LHCb toy model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b580f5e",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class EncodeProcessDecode(snt.Module):\n",
    "    \"\"\"Encode-process-decode model built from the MLP graph blocks.\n",
    "\n",
    "    The input is encoded once; the core network is then applied repeatedly to\n",
    "    the concatenation of the encoded graph and the current latent graph, and\n",
    "    every step is decoded and passed through the output transform.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, name=\"EncodeProcessDecode\"):\n",
    "        super().__init__(name=name)\n",
    "        self._encoder = MLPGraphIndependent()\n",
    "        self._core = MLPGraphNetwork()\n",
    "        self._decoder = MLPGraphIndependent()\n",
    "        self._output_transform = MLPGraphIndependent()\n",
    "\n",
    "    def __call__(self, input_op, num_processing_steps):\n",
    "        \"\"\"Return a list with one transformed output per processing step.\"\"\"\n",
    "        encoded = self._encoder(input_op)\n",
    "        state = encoded\n",
    "        step_outputs = []\n",
    "        for _step in range(num_processing_steps):\n",
    "            # The core always sees the encoder output next to the current state.\n",
    "            state = self._core(utils_tf.concat([encoded, state], axis=1))\n",
    "            step_outputs.append(self._output_transform(self._decoder(state)))\n",
    "        return step_outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03e0bb78",
   "metadata": {},
   "source": [
    "Instantiating EncodeProcessDecode Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5544099",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "ep_model = EncodeProcessDecode()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19f8e774",
   "metadata": {},
   "source": [
    "Initializing randomized input data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d69eea5",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "GraphData = get_graph_data_dict(num_nodes, num_edges, node_size, edge_size, global_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03ca3f14",
   "metadata": {},
   "source": [
    "`input_graph_data` is a graph_nets `GraphsTuple` representing the initial data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "735fff91",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "input_graph_data = utils_tf.data_dicts_to_graphs_tuple([GraphData])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c3ccf71",
   "metadata": {},
   "source": [
    "Initializing randomized input data for the core network.\n",
    "Note that the core network takes twice as many input features, because its input is the concatenation of the encoder output and the current latent state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36a633ed",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "CoreGraphData = get_graph_data_dict(num_nodes, num_edges, 2 * LATENT_SIZE, 2 * LATENT_SIZE, 2 * LATENT_SIZE)\n",
    "input_core_graph_data = utils_tf.data_dicts_to_graphs_tuple([CoreGraphData])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96e5dff0",
   "metadata": {},
   "source": [
    "initialize graph data for decoder (input is LATENT_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb0f3b24",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "DecodeGraphData = get_graph_data_dict(num_nodes, num_edges, LATENT_SIZE, LATENT_SIZE, LATENT_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a128bf89",
   "metadata": {},
   "source": [
    "Make prediction of GNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e04e1e6",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "output_gn = ep_model(input_graph_data, processing_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6994e6f",
   "metadata": {},
   "source": [
    "Optionally, print the input and output data for inspection:\n",
    "```python\n",
    "print(\"---> Input:\\n\",input_graph_data)\n",
    "print(\"\\n\\n------> Input core data:\\n\",input_core_graph_data)\n",
    "print(\"\\n\\n---> Output:\\n\",output_gn)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0d51636",
   "metadata": {},
   "source": [
    "Make SOFIE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc9771fb",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def _generate_sofie_model(model_class, network, graph_data, filename):\n",
    "    \"\"\"Parse `network` with `model_class.ParseFromMemory`, then generate and write out its code.\"\"\"\n",
    "    model = model_class.ParseFromMemory(network, graph_data, filename=filename)\n",
    "    model.Generate()\n",
    "    model.OutputGenerated()\n",
    "    return model\n",
    "\n",
    "\n",
    "_sofie = ROOT.TMVA.Experimental.SOFIE\n",
    "\n",
    "encoder = _generate_sofie_model(\n",
    "    _sofie.RModel_GraphIndependent, ep_model._encoder._network, GraphData, \"gnn_encoder\"\n",
    ")\n",
    "\n",
    "# The core network is a full GraphNetwork, hence the RModel_GNN parser.\n",
    "core = _generate_sofie_model(_sofie.RModel_GNN, ep_model._core._network, CoreGraphData, \"gnn_core\")\n",
    "\n",
    "decoder = _generate_sofie_model(\n",
    "    _sofie.RModel_GraphIndependent, ep_model._decoder._network, DecodeGraphData, \"gnn_decoder\"\n",
    ")\n",
    "\n",
    "output_transform = _generate_sofie_model(\n",
    "    _sofie.RModel_GraphIndependent,\n",
    "    ep_model._output_transform._network,\n",
    "    DecodeGraphData,\n",
    "    \"gnn_output_transform\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c312f627",
   "metadata": {},
   "source": [
    "Compile now the generated C++ code from SOFIE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "880318e4",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# C++ snippet handed to the interpreter: enable optimisation and include the\n",
    "# four headers written by the OutputGenerated() calls above.\n",
    "gen_code = '''#pragma cling optimize(2)\n",
    "#include \"gnn_encoder.hxx\"\n",
    "#include \"gnn_core.hxx\"\n",
    "#include \"gnn_decoder.hxx\"\n",
    "#include \"gnn_output_transform.hxx\"'''"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99edaa8f",
   "metadata": {},
   "source": [
    "Declare the generated headers to the ROOT interpreter so that the SOFIE session classes become available"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0823f702",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "ROOT.gInterpreter.Declare(gen_code)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b355364a",
   "metadata": {},
   "source": [
    "helper function to print SOFIE GNN data structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a5d862d",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def PrintSofie(output, printShape=False):\n",
    "    \"\"\"Print the node, edge and global data of `output` as flattened arrays.\n",
    "\n",
    "    With `printShape=True` the shapes of the three arrays are printed first.\n",
    "    \"\"\"\n",
    "    labelled = {\n",
    "        \" node data\": np.asarray(output.node_data),\n",
    "        \" edge data\": np.asarray(output.edge_data),\n",
    "        \" global data\": np.asarray(output.global_data),\n",
    "    }\n",
    "    if printShape:\n",
    "        print(\"SOFIE data ... shapes\", *(tensor.shape for tensor in labelled.values()))\n",
    "    for label, tensor in labelled.items():\n",
    "        print(label, tensor.reshape(tensor.size))\n",
    "\n",
    "\n",
    "def CopyData(input_data):\n",
    "    \"\"\"Return the result of SOFIE's `Copy` helper applied to `input_data`.\"\"\"\n",
    "    return ROOT.TMVA.Experimental.SOFIE.Copy(input_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37aac795",
   "metadata": {},
   "source": [
    "Build  SOFIE GNN Model and run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5a7b8bd",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class SofieGNN:\n",
    "    \"\"\"Encode-process-decode chain evaluated with the SOFIE-generated sessions.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # One session object per generated SOFIE model.\n",
    "        self.encoder_session = ROOT.TMVA_SOFIE_gnn_encoder.Session()\n",
    "        self.core_session = ROOT.TMVA_SOFIE_gnn_core.Session()\n",
    "        self.decoder_session = ROOT.TMVA_SOFIE_gnn_decoder.Session()\n",
    "        self.output_transform_session = ROOT.TMVA_SOFIE_gnn_output_transform.Session()\n",
    "\n",
    "    def infer(self, graphData):\n",
    "        \"\"\"Run the chain on a copy of `graphData`.\n",
    "\n",
    "        Returns a list with one output per processing step; the number of\n",
    "        steps is the module-level `processing_steps`.\n",
    "        \"\"\"\n",
    "        sofie = ROOT.TMVA.Experimental.SOFIE\n",
    "\n",
    "        # infer() is used as an in-place update here, so work on a copy of the input.\n",
    "        encoded = CopyData(graphData)\n",
    "        self.encoder_session.infer(encoded)\n",
    "\n",
    "        latent0 = CopyData(encoded)\n",
    "        latent = encoded\n",
    "        step_outputs = []\n",
    "        for _step in range(processing_steps):\n",
    "            work = sofie.Concatenate(latent0, latent, axis=1)\n",
    "            self.core_session.infer(work)\n",
    "            # Keep the core result as the latent state of the next step.\n",
    "            latent = CopyData(work)\n",
    "            self.decoder_session.infer(work)\n",
    "            self.output_transform_session.infer(work)\n",
    "            step_outputs.append(CopyData(work))\n",
    "\n",
    "        return step_outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df80186d",
   "metadata": {},
   "source": [
    "Test both GNN implementations on some simulated events"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2235e648",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def GenerateData():\n",
    "    \"\"\"Return one random graph data dictionary with the notebook's graph sizes.\"\"\"\n",
    "    return get_graph_data_dict(num_nodes, num_edges, node_size, edge_size, global_size)\n",
    "\n",
    "\n",
    "numevts = 40\n",
    "# One randomly generated graph per simulated event.\n",
    "dataSet = [GenerateData() for _ in range(numevts)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8334698b",
   "metadata": {},
   "source": [
    "Run the graph_nets model.\n",
    "First we convert the input data to the required input format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc739711",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Convert the first `numevts` data dictionaries to single-graph GraphsTuple objects.\n",
    "gnetData = [utils_tf.data_dicts_to_graphs_tuple([graph_dict]) for graph_dict in dataSet[:numevts]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e60cfc0",
   "metadata": {},
   "source": [
    "Function to run the graph net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb0b5bc",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def RunGNet(inputGraphData):\n",
    "    \"\"\"Evaluate `ep_model` on `inputGraphData` for `processing_steps` steps.\"\"\"\n",
    "    return ep_model(inputGraphData, processing_steps)\n",
    "\n",
    "\n",
    "start = time.time()\n",
    "hG = ROOT.TH1D(\"hG\", \"Result from graphnet\", 20, 1, 0)\n",
    "for i in range(0, numevts):\n",
    "    out = RunGNet(gnetData[i])\n",
    "    g = out[1].globals.numpy()\n",
    "    hG.Fill(np.mean(g))\n",
    "\n",
    "end = time.time()\n",
    "print(\"elapsed time for \", numevts, \"events = \", end - start)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3fd3d15",
   "metadata": {},
   "source": [
    "running SOFIE-GNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05272b01",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Time the conversion on its own: the `end` timestamp of the previous cell\n",
    "# would also count the idle time between the two cell executions.\n",
    "startSC = time.time()\n",
    "sofieData = []\n",
    "for i in range(0, numevts):\n",
    "    graphData = dataSet[i]\n",
    "    input_data = ROOT.TMVA.Experimental.SOFIE.GNN_Data()\n",
    "    input_data.node_data = ROOT.TMVA.Experimental.AsRTensor(graphData[\"nodes\"])\n",
    "    input_data.edge_data = ROOT.TMVA.Experimental.AsRTensor(graphData[\"edges\"])\n",
    "    input_data.global_data = ROOT.TMVA.Experimental.AsRTensor(graphData[\"globals\"])\n",
    "    # edge_index stacks the receivers row first, then the senders row.\n",
    "    input_data.edge_index = ROOT.TMVA.Experimental.AsRTensor(np.stack((graphData[\"receivers\"], graphData[\"senders\"])))\n",
    "    sofieData.append(input_data)\n",
    "\n",
    "\n",
    "endSC = time.time()\n",
    "print(\"time to convert data to SOFIE format\", endSC - startSC)\n",
    "\n",
    "hS = ROOT.TH1D(\"hS\", \"Result from SOFIE\", 20, 1, 0)\n",
    "start0 = time.time()\n",
    "gnn = SofieGNN()\n",
    "start = time.time()\n",
    "print(\"time to create SOFIE GNN class\", start - start0)\n",
    "for i in range(0, numevts):\n",
    "    # print(\"inference event....\",i)\n",
    "    out = gnn.infer(sofieData[i])\n",
    "    g = np.asarray(out[1].global_data)\n",
    "    hS.Fill(np.mean(g))\n",
    "\n",
    "end = time.time()\n",
    "print(\"elapsed time for \", numevts, \"events = \", end - start)\n",
    "\n",
    "c0 = ROOT.TCanvas()\n",
    "c0.Divide(1, 2)\n",
    "c1 = c0.cd(1)\n",
    "c1.Divide(2, 1)\n",
    "c1.cd(1)\n",
    "hG.Draw()\n",
    "c1.cd(2)\n",
    "hS.Draw()\n",
    "\n",
    "hDe = ROOT.TH1D(\"hDe\", \"Difference for edge data\", 40, 1, 0)\n",
    "hDn = ROOT.TH1D(\"hDn\", \"Difference for node data\", 40, 1, 0)\n",
    "hDg = ROOT.TH1D(\"hDg\", \"Difference for global data\", 40, 1, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "485511b3",
   "metadata": {},
   "source": [
    "compute differences between SOFIE and GNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67533c7d",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def _fill_differences(hist, reference, candidate):\n",
    "    \"\"\"Fill `hist` with `reference[j, k] - candidate[j, k]` over the first two axes of `reference`.\"\"\"\n",
    "    for j, k in np.ndindex(reference.shape[0], reference.shape[1]):\n",
    "        hist.Fill(reference[j, k] - candidate[j, k])\n",
    "\n",
    "\n",
    "for i in range(0, numevts):\n",
    "    outSofie = gnn.infer(sofieData[i])\n",
    "    outGnet = RunGNet(gnetData[i])\n",
    "\n",
    "    # Compare the output of the second processing step (index 1) of both models.\n",
    "    edgesG = outGnet[1].edges.numpy()\n",
    "    edgesS = np.asarray(outSofie[1].edge_data)\n",
    "    if i == 0:\n",
    "        print(edgesG.shape)\n",
    "    _fill_differences(hDe, edgesG, edgesS)\n",
    "\n",
    "    nodesG = outGnet[1].nodes.numpy()\n",
    "    nodesS = np.asarray(outSofie[1].node_data)\n",
    "    _fill_differences(hDn, nodesG, nodesS)\n",
    "\n",
    "    # The graph_nets globals are indexed as [0, j], the SOFIE ones as [j].\n",
    "    globG = outGnet[1].globals.numpy()\n",
    "    globS = np.asarray(outSofie[1].global_data)\n",
    "    for j in range(0, globG.shape[1]):\n",
    "        hDg.Fill(globG[0, j] - globS[j])\n",
    "\n",
    "\n",
    "c2 = c0.cd(2)\n",
    "c2.Divide(3, 1)\n",
    "c2.cd(1)\n",
    "hDe.Draw()\n",
    "c2.cd(2)\n",
    "hDn.Draw()\n",
    "c2.cd(3)\n",
    "hDg.Draw()\n",
    "\n",
    "c0.Draw()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
