{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f92e89f8",
   "metadata": {},
   "source": [
    "# TMVA_SOFIE_GNN_Parser\n",
    "\n",
    "Tutorial showing how to parse a GNN from GraphNet and make a SOFIE model.\n",
    "The tutorial also generates some data which can serve as input for the tutorial TMVA_SOFIE_GNN_Application.C.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:**   \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:23 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3a128e0",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cd22466",
   "metadata": {},
   "source": [
    "Imports for getting time and memory usage, and for building the graph network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011b0a0e",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import graph_nets as gn\n",
    "import numpy as np\n",
    "import psutil\n",
    "import ROOT\n",
    "import sonnet as snt\n",
    "from graph_nets import utils_tf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04501e47",
   "metadata": {},
   "source": [
    "Defining the graph properties. The numbers of nodes/edges are the maximum values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21544c62",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "num_max_nodes=100\n",
    "num_max_edges=300\n",
    "node_size=4\n",
    "edge_size=4\n",
    "global_size=1\n",
    "LATENT_SIZE = 100\n",
    "NUM_LAYERS = 4\n",
    "processing_steps = 5\n",
    "numevts = 100\n",
    "\n",
    "verbose = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6a4ffb6",
   "metadata": {},
   "source": [
    "Print the used memory in MB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93e596ef",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def printMemory(s = \"\") :\n",
    "    #get memory of current process\n",
    "    pid = os.getpid()\n",
    "    python_process = psutil.Process(pid)\n",
    "    memoryUse = python_process.memory_info()[0]/(1024.*1024.)    #divide by 1024 * 1024 to get memory in MB\n",
    "    print(s,\"memory:\",memoryUse,\"(MB)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3ecd88a",
   "metadata": {},
   "source": [
    "Method returning a dictionary of graph data with a random number of nodes and edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdb1c08d",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def get_dynamic_graph_data_dict(NODE_FEATURE_SIZE=2, EDGE_FEATURE_SIZE=2, GLOBAL_FEATURE_SIZE=1):\n",
    "   num_nodes = np.random.randint(num_max_nodes-2, size=1)[0] + 2\n",
    "   num_edges = np.random.randint(num_max_edges-1, size=1)[0] + 1\n",
    "   return {\n",
    "      \"globals\": 10*np.random.rand(GLOBAL_FEATURE_SIZE).astype(np.float32)-5.,\n",
    "      \"nodes\": 10*np.random.rand(num_nodes, NODE_FEATURE_SIZE).astype(np.float32)-5.,\n",
    "      \"edges\": 10*np.random.rand(num_edges, EDGE_FEATURE_SIZE).astype(np.float32)-5.,\n",
    "      \"senders\": np.random.randint(num_nodes, size=num_edges, dtype=np.int32),\n",
    "      \"receivers\": np.random.randint(num_nodes, size=num_edges, dtype=np.int32)\n",
    "   }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cbe1b99",
   "metadata": {},
   "source": [
    "Generate graph data with a fixed number of nodes/edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c990644",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def get_fix_graph_data_dict(num_nodes, num_edges, NODE_FEATURE_SIZE=2, EDGE_FEATURE_SIZE=2, GLOBAL_FEATURE_SIZE=1):\n",
    "   return {\n",
    "      \"globals\": np.ones((GLOBAL_FEATURE_SIZE),dtype=np.float32),\n",
    "      \"nodes\": np.ones((num_nodes, NODE_FEATURE_SIZE), dtype = np.float32),\n",
    "      \"edges\": np.ones((num_edges, EDGE_FEATURE_SIZE), dtype = np.float32),\n",
    "      \"senders\":  np.random.randint(num_nodes, size=num_edges, dtype=np.int32),\n",
    "      \"receivers\": np.random.randint(num_nodes, size=num_edges, dtype=np.int32)\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb55b208",
   "metadata": {},
   "source": [
    "Method to instantiate the MLP model to be added to the GNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2851d406",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def make_mlp_model():\n",
    "  return snt.Sequential([\n",
    "      snt.nets.MLP([LATENT_SIZE]*NUM_LAYERS, activate_final=True),\n",
    "      snt.LayerNorm(axis=-1, create_offset=True, create_scale=True)\n",
    "  ])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c67c19a",
   "metadata": {},
   "source": [
    "Defining the GraphIndependent class with MLP edge, node, and global models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f154e14",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class MLPGraphIndependent(snt.Module):\n",
    "  def __init__(self, name=\"MLPGraphIndependent\"):\n",
    "    super(MLPGraphIndependent, self).__init__(name=name)\n",
    "    self._network = gn.modules.GraphIndependent(\n",
    "        edge_model_fn = lambda: snt.nets.MLP([LATENT_SIZE]*NUM_LAYERS, activate_final=True),\n",
    "        node_model_fn = lambda: snt.nets.MLP([LATENT_SIZE]*NUM_LAYERS, activate_final=True),\n",
    "        global_model_fn = lambda: snt.nets.MLP([LATENT_SIZE]*NUM_LAYERS, activate_final=True))\n",
    "\n",
    "  def __call__(self, inputs):\n",
    "    return self._network(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25061527",
   "metadata": {},
   "source": [
    "Defining the Graph network class with MLP edge, node, and global models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "670b99ad",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class MLPGraphNetwork(snt.Module):\n",
    "  def __init__(self, name=\"MLPGraphNetwork\"):\n",
    "    super(MLPGraphNetwork, self).__init__(name=name)\n",
    "    self._network = gn.modules.GraphNetwork(\n",
    "            edge_model_fn=make_mlp_model,\n",
    "            node_model_fn=make_mlp_model,\n",
    "            global_model_fn=make_mlp_model)\n",
    "\n",
    "  def __call__(self, inputs):\n",
    "    return self._network(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28f68619",
   "metadata": {},
   "source": [
    "Defining an Encode-Process-Decode module for the LHCb toy model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f66ffb4",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class EncodeProcessDecode(snt.Module):\n",
    "\n",
    "  def __init__(self,\n",
    "               name=\"EncodeProcessDecode\"):\n",
    "    super(EncodeProcessDecode, self).__init__(name=name)\n",
    "    self._encoder = MLPGraphIndependent()\n",
    "    self._core = MLPGraphNetwork()\n",
    "    self._decoder = MLPGraphIndependent()\n",
    "    self._output_transform = MLPGraphIndependent()\n",
    "\n",
    "  def __call__(self, input_op, num_processing_steps):\n",
    "    latent = self._encoder(input_op)\n",
    "    latent0 = latent\n",
    "    output_ops = []\n",
    "    for _ in range(num_processing_steps):\n",
    "      core_input = utils_tf.concat([latent0, latent], axis=1)\n",
    "      latent = self._core(core_input)\n",
    "      decoded_op = self._decoder(latent)\n",
    "      output_ops.append(self._output_transform(decoded_op))\n",
    "    return output_ops"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca2fc690",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "878b6925",
   "metadata": {},
   "source": [
    "Instantiating EncodeProcessDecode Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f1fe84",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "printMemory(\"before instantiating\")\n",
    "ep_model = EncodeProcessDecode()\n",
    "printMemory(\"after instantiating\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70623dc7",
   "metadata": {},
   "source": [
    "Initializing randomized input data with maximum number of nodes/edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80b8e17",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "GraphData = get_fix_graph_data_dict(num_max_nodes, num_max_edges, node_size, edge_size, global_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba6820a9",
   "metadata": {},
   "source": [
    "`input_graph_data` is a graphs tuple representing the initial data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f39ed36c",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "input_graph_data = utils_tf.data_dicts_to_graphs_tuple([GraphData])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b814cc5",
   "metadata": {},
   "source": [
    "Initializing input data for the core network.\n",
    "Note that the core network takes twice the number of features (2*LATENT_SIZE) as input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e74b47a0",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "CoreGraphData = get_fix_graph_data_dict(num_max_nodes, num_max_edges, 2*LATENT_SIZE, 2*LATENT_SIZE, 2*LATENT_SIZE)\n",
    "input_core_graph_data = utils_tf.data_dicts_to_graphs_tuple([CoreGraphData])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d84e16e",
   "metadata": {},
   "source": [
    "Initialize graph data for the decoder (input is LATENT_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d65ce0bc",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "DecodeGraphData = get_fix_graph_data_dict(num_max_nodes, num_max_edges, LATENT_SIZE, LATENT_SIZE, LATENT_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f2e3a17",
   "metadata": {},
   "source": [
    "Make prediction of GNN. This will initialize the GNN with weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32e95bbd",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "printMemory(\"before first eval\")\n",
    "output_gn = ep_model(input_graph_data, processing_steps)\n",
    "printMemory(\"after first eval\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e366e830",
   "metadata": {},
   "source": [
    "Optionally print the input and output data (disabled by default):\n",
    "\n",
    "```python\n",
    "print(\"---> Input:\\n\",input_graph_data)\n",
    "print(\"\\n\\n------> Input core data:\\n\",input_core_graph_data)\n",
    "print(\"\\n\\n---> Output:\\n\",output_gn)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ba974dd",
   "metadata": {},
   "source": [
    "Make the SOFIE models. Each model is made using the maximum number of nodes/edges, as given by the graph data passed to the parser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b540238",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "encoder = ROOT.TMVA.Experimental.SOFIE.RModel_GraphIndependent.ParseFromMemory(ep_model._encoder._network, GraphData, filename = \"encoder\")\n",
    "encoder.Generate()\n",
    "encoder.OutputGenerated()\n",
    "\n",
    "core = ROOT.TMVA.Experimental.SOFIE.RModel_GNN.ParseFromMemory(ep_model._core._network, CoreGraphData, filename = \"core\")\n",
    "core.Generate()\n",
    "core.OutputGenerated()\n",
    "\n",
    "decoder = ROOT.TMVA.Experimental.SOFIE.RModel_GraphIndependent.ParseFromMemory(ep_model._decoder._network, DecodeGraphData, filename = \"decoder\")\n",
    "decoder.Generate()\n",
    "decoder.OutputGenerated()\n",
    "\n",
    "output_transform = ROOT.TMVA.Experimental.SOFIE.RModel_GraphIndependent.ParseFromMemory(ep_model._output_transform._network, DecodeGraphData, filename = \"output_transform\")\n",
    "output_transform.Generate()\n",
    "output_transform.OutputGenerated()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c0fd191",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac7b6ed4",
   "metadata": {},
   "source": [
    "Generate data and save it in a ROOT TTree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f93eb8",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fileOut = ROOT.TFile.Open(\"graph_data.root\",\"RECREATE\")\n",
    "tree = ROOT.TTree(\"gdata\",\"GNN data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e1300b6",
   "metadata": {},
   "source": [
    "Need to store each element separately since we cannot store an RTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4b30fcb",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "node_data = ROOT.std.vector['float'](num_max_nodes*node_size)\n",
    "edge_data = ROOT.std.vector['float'](num_max_edges*edge_size)\n",
    "global_data = ROOT.std.vector['float'](global_size)\n",
    "receivers =  ROOT.std.vector['int'](num_max_edges)\n",
    "senders = ROOT.std.vector['int'](num_max_edges)\n",
    "outgnn = ROOT.std.vector['float'](3)\n",
    "\n",
    "tree.Branch(\"node_data\", \"std::vector<float>\" , node_data)\n",
    "tree.Branch(\"edge_data\", \"std::vector<float>\" ,  edge_data)\n",
    "tree.Branch(\"global_data\", \"std::vector<float>\" ,  global_data)\n",
    "tree.Branch(\"receivers\", \"std::vector<int>\" ,  receivers)\n",
    "tree.Branch(\"senders\", \"std::vector<int>\" ,  senders)\n",
    "\n",
    "\n",
    "print(\"\\n\\nSaving data in a ROOT File:\")\n",
    "h1 = ROOT.TH1D(\"h1\",\"GraphNet nodes output\",40,1,0)\n",
    "h2 = ROOT.TH1D(\"h2\",\"GraphNet edges output\",40,1,0)\n",
    "h3 = ROOT.TH1D(\"h3\",\"GraphNet global output\",40,1,0)\n",
    "dataset = []\n",
    "for i in range(0,numevts):\n",
    "    graphData = get_dynamic_graph_data_dict(node_size, edge_size, global_size)\n",
    "    s_nodes = graphData['nodes'].size\n",
    "    s_edges = graphData['edges'].size\n",
    "    num_edges = graphData['edges'].shape[0]\n",
    "    tmp = ROOT.std.vector['float'](graphData['nodes'].reshape((graphData['nodes'].size)))\n",
    "    node_data.assign(tmp.begin(),tmp.end())\n",
    "    tmp = ROOT.std.vector['float'](graphData['edges'].reshape((graphData['edges'].size)))\n",
    "    edge_data.assign(tmp.begin(),tmp.end())\n",
    "    tmp = ROOT.std.vector['float'](graphData['globals'].reshape((graphData['globals'].size)))\n",
    "    global_data.assign(tmp.begin(),tmp.end())\n",
    "    #make sure dtype of graphData['receivers'] and senders is int32\n",
    "    tmp = ROOT.std.vector['int'](graphData['receivers'])\n",
    "    receivers.assign(tmp.begin(),tmp.end())\n",
    "    tmp = ROOT.std.vector['int'](graphData['senders'])\n",
    "    senders.assign(tmp.begin(),tmp.end())\n",
    "    if (i < 1 and verbose) :\n",
    "      print(\"Nodes - shape:\",int(node_data.size()/node_size),node_size,\"data: \",node_data)\n",
    "      print(\"Edges - shape:\",num_edges, edge_size,\"data: \", edge_data)\n",
    "      print(\"Globals : \",global_data)\n",
    "      print(\"Receivers : \",receivers)\n",
    "      print(\"Senders   : \",senders)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8809c20",
   "metadata": {},
   "source": [
    "Evaluate the graph net on these events"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac3a38ca",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "    tree.Fill()\n",
    "    tf_graph_data = utils_tf.data_dicts_to_graphs_tuple([graphData])\n",
    "    dataset.append(tf_graph_data)\n",
    "\n",
    "tree.Print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19c5b306",
   "metadata": {},
   "source": [
    "Do a first evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e510d6bf",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "printMemory(\"before eval1\")\n",
    "output_gnn = ep_model(dataset[0], processing_steps)\n",
    "printMemory(\"after eval1\")\n",
    "\n",
    "start = time.time()\n",
    "firstEvent = True\n",
    "for tf_graph_data in dataset:\n",
    "    output_gnn = ep_model(tf_graph_data, processing_steps)\n",
    "    output_nodes = output_gnn[-1].nodes.numpy()\n",
    "    output_edges = output_gnn[-1].edges.numpy()\n",
    "    output_globals = output_gnn[-1].globals.numpy()\n",
    "    outgnn[0] = np.mean(output_nodes)\n",
    "    outgnn[1] = np.mean(output_edges)\n",
    "    outgnn[2] = np.mean(output_globals)\n",
    "    h1.Fill(outgnn[0])\n",
    "    h2.Fill(outgnn[1])\n",
    "    h3.Fill(outgnn[2])\n",
    "    if (firstEvent and verbose) :\n",
    "      print(\"Output of first event\")\n",
    "      print(\"nodes data\", output_gnn[-1].nodes.numpy())\n",
    "      print(\"edge data\", output_gnn[-1].edges.numpy())\n",
    "      print(\"global data\", output_gnn[-1].globals.numpy())\n",
    "      firstEvent = False\n",
    "\n",
    "\n",
    "end = time.time()\n",
    "\n",
    "print(\"time to evaluate events\",end-start)\n",
    "printMemory(\"after eval Nevts\")\n",
    "\n",
    "c1 = ROOT.TCanvas()\n",
    "c1.Divide(1,3)\n",
    "c1.cd(1)\n",
    "h1.DrawCopy()\n",
    "c1.cd(2)\n",
    "h2.DrawCopy()\n",
    "c1.cd(3)\n",
    "h3.DrawCopy()\n",
    "\n",
    "tree.Write()\n",
    "h1.Write()\n",
    "h2.Write()\n",
    "h3.Write()\n",
    "fileOut.Close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
