{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1f6cf41d",
   "metadata": {},
   "source": [
    "# ClassificationKeras\n",
    "This tutorial shows how to do classification in TMVA with neural networks\n",
    "trained with Keras, using a Fisher discriminant trained on the same data as a baseline for comparison.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** TMVA Team  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:21 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22c22f5b",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from ROOT import TMVA, TFile, TCut, gROOT\n",
    "from subprocess import call\n",
    "from os.path import isfile\n",
    "\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Input\n",
    "from tensorflow.keras.optimizers import SGD\n",
    "\n",
    "\n",
    "def create_model(num_inputs=4, filename='modelClassification.keras'):\n",
    "    \"\"\"Build, compile and save the Keras network trained by TMVA's PyKeras method.\n",
    "\n",
    "    Args:\n",
    "        num_inputs: number of input features. Must equal the number of variables\n",
    "            registered on the TMVA DataLoader; run() adds one per branch of the\n",
    "            signal tree (the default of 4 matches the example file -- TODO confirm\n",
    "            if the input data changes).\n",
    "        filename: path the untrained model is written to. Must match the\n",
    "            ``FilenameModel`` option of the PyKeras booking in run().\n",
    "\n",
    "    Returns:\n",
    "        The compiled, untrained Keras model (also saved to ``filename``).\n",
    "    \"\"\"\n",
    "    # One hidden ReLU layer and a 2-unit softmax output (one unit per class).\n",
    "    # The input size is declared with an explicit Input layer: passing\n",
    "    # ``input_dim`` to Dense is deprecated and warns under Keras 3.\n",
    "    model = Sequential()\n",
    "    model.add(Input(shape=(num_inputs,)))\n",
    "    model.add(Dense(64, activation='relu'))\n",
    "    model.add(Dense(2, activation='softmax'))\n",
    "\n",
    "    # Categorical cross-entropy pairs with the softmax output; weighted_metrics\n",
    "    # makes the reported accuracy honour any per-sample weights given to fit().\n",
    "    model.compile(loss='categorical_crossentropy',\n",
    "                  optimizer=SGD(learning_rate=0.01), weighted_metrics=['accuracy'])\n",
    "\n",
    "    # The PyKeras booking in run() loads the untrained model from this file.\n",
    "    model.save(filename)\n",
    "    model.summary()\n",
    "    return model\n",
    "\n",
    "\n",
    "def run():\n",
    "    \"\"\"Book, train, test and evaluate a Fisher discriminant and a Keras network.\n",
    "\n",
    "    Events are read from the ``TreeS`` (signal) and ``TreeB`` (background) trees of\n",
    "    ROOT's ``tmva_class_example.root`` tutorial file; TMVA output is written to\n",
    "    ``TMVA_Classification_Keras.root``. The PyKeras booking expects the untrained\n",
    "    model file written by ``create_model()`` to exist already.\n",
    "    \"\"\"\n",
    "    example_file = str(gROOT.GetTutorialDir()) + '/machine_learning/data/tmva_class_example.root'\n",
    "    factory_options = '!V:!Silent:Color:!DrawProgressBar:Transformations=D,G:AnalysisType=Classification'\n",
    "    # 4000 randomly selected training events per class.\n",
    "    split_options = 'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V'\n",
    "    # FilenameModel must name the file saved by create_model().\n",
    "    pykeras_options = 'H:!V:VarTransform=D,G:FilenameModel=modelClassification.keras:FilenameTrainedModel=trainedModelClassification.keras:NumEpochs=20:BatchSize=32:LearningRateSchedule=10,0.01;20,0.005'\n",
    "\n",
    "    # Nested form of a two-item with-statement: the output file is opened first\n",
    "    # and closed last.\n",
    "    with TFile.Open('TMVA_Classification_Keras.root', 'RECREATE') as output:\n",
    "        with TFile.Open(example_file) as data:\n",
    "            factory = TMVA.Factory('TMVAClassification', output, factory_options)\n",
    "\n",
    "            sig_tree = data.Get('TreeS')\n",
    "            bkg_tree = data.Get('TreeB')\n",
    "\n",
    "            # Every branch of the signal tree becomes a TMVA input variable.\n",
    "            loader = TMVA.DataLoader('dataset')\n",
    "            for var_name in (br.GetName() for br in sig_tree.GetListOfBranches()):\n",
    "                loader.AddVariable(var_name)\n",
    "\n",
    "            loader.AddSignalTree(sig_tree, 1.0)\n",
    "            loader.AddBackgroundTree(bkg_tree, 1.0)\n",
    "            loader.PrepareTrainingAndTestTree(TCut(''), split_options)\n",
    "\n",
    "            # Fisher discriminant as a baseline next to the Keras model.\n",
    "            factory.BookMethod(loader, TMVA.Types.kFisher, 'Fisher', '!H:!V:Fisher:VarTransform=D,G')\n",
    "            factory.BookMethod(loader, TMVA.Types.kPyKeras, 'PyKeras', pykeras_options)\n",
    "\n",
    "            factory.TrainAllMethods()\n",
    "            factory.TestAllMethods()\n",
    "            factory.EvaluateAllMethods()\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # In a notebook kernel __name__ is \"__main__\", so this block also runs when\n",
    "    # the cell is executed. Order matters: create_model() must run before run(),\n",
    "    # which books PyKeras against the model file create_model() saves.\n",
    "\n",
    "    # Setup TMVA: initialise the TMVA.Tools instance and the Python support used\n",
    "    # by the PyMVA methods (presumably required before PyKeras is booked in run()).\n",
    "    TMVA.Tools.Instance()\n",
    "    TMVA.PyMethodBase.PyInitialize()\n",
    "\n",
    "    # Create and store the ML model as modelClassification.keras, the file named\n",
    "    # by the FilenameModel option of the PyKeras booking.\n",
    "    create_model()\n",
    "\n",
    "    # Run TMVA: book, train, test and evaluate the methods.\n",
    "    run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}