{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "36c6a997",
   "metadata": {},
   "source": [
    "# tmva102_Testing\n",
    "This tutorial illustrates how you can test a trained BDT model using the fast\n",
    "tree inference engine offered by TMVA and external tools such as scikit-learn.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** Stefan Wunsch  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:22 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b99b7329",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import ROOT\n",
    "from tmva101_Training import load_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c01ac7c",
   "metadata": {},
   "source": [
    "Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d20a5156",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "x, y_true, w = load_data(\"test_signal.root\", \"test_background.root\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b07b7544",
   "metadata": {},
   "source": [
    "Load trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "349d3fef",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "File = \"tmva101.root\"\n",
    "\n",
    "bdt = ROOT.TMVA.Experimental.RBDT(\"myBDT\", File)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9645370a",
   "metadata": {},
   "source": [
    "Make prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3510744",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "y_pred = bdt.Compute(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2348389c",
   "metadata": {},
   "source": [
    "Compute ROC using sklearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "034c7736",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import auc, roc_curve\n",
    "\n",
    "false_positive_rate, true_positive_rate, _ = roc_curve(y_true, y_pred, sample_weight=w)\n",
    "score = auc(false_positive_rate, true_positive_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08f4648d",
   "metadata": {},
   "source": [
    "Plot ROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a766b0ff",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "c = ROOT.TCanvas(\"roc\", \"\", 600, 600)\n",
    "g = ROOT.TGraph(len(false_positive_rate), false_positive_rate, true_positive_rate)\n",
    "g.SetTitle(\"AUC = {:.2f}\".format(score))\n",
    "g.SetLineWidth(3)\n",
    "g.SetLineColor(\"kRed\")\n",
    "g.Draw(\"AC\")\n",
    "g.GetXaxis().SetRangeUser(0, 1)\n",
    "g.GetYaxis().SetRangeUser(0, 1)\n",
    "g.GetXaxis().SetTitle(\"False-positive rate\")\n",
    "g.GetYaxis().SetTitle(\"True-positive rate\")\n",
    "c.Draw()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}