{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "28873689",
   "metadata": {},
   "source": [
    "# MulticlassKeras\n",
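    "This tutorial shows how to perform multiclass classification (one signal and\n",
    "three background classes) in TMVA with a neural network trained with Keras,\n",
    "booked alongside a Fisher discriminant for comparison.\n",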
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** TMVA Team  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:21 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b76b6c07",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from os.path import isfile\n",
    "\n",
    "from ROOT import TMVA, TFile, TCut, gROOT\n",
    "\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Input\n",
    "from tensorflow.keras.optimizers import SGD\n",
    "\n",
    "\n",
    "def create_model(num_inputs=4, num_classes=4, filename='modelMultiClass.keras'):\n",
    "    \"\"\"Build, compile and save the untrained Keras model booked as 'PyKeras'.\n",
    "\n",
    "    Args:\n",
    "        num_inputs: number of input variables; must equal the number of\n",
    "            variables registered on the TMVA DataLoader.\n",
    "        num_classes: number of output classes (one softmax unit per class).\n",
    "        filename: path the compiled model is saved to; it must match the\n",
    "            FilenameModel option used when booking the PyKeras method.\n",
    "\n",
    "    Returns:\n",
    "        The compiled (untrained) Sequential model.\n",
    "    \"\"\"\n",
    "    # Define model. The input shape is declared with an explicit Input layer:\n",
    "    # passing input_dim to Dense is deprecated in Keras 3.\n",
    "    model = Sequential()\n",
    "    model.add(Input(shape=(num_inputs,)))\n",
    "    model.add(Dense(32, activation='relu'))\n",
    "    model.add(Dense(num_classes, activation='softmax'))\n",
    "\n",
    "    # Set loss and optimizer\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=SGD(\n",
    "        learning_rate=0.01), weighted_metrics=['accuracy',])\n",
    "\n",
    "    # Store model to file so TMVA can load it for training\n",
    "    model.save(filename)\n",
    "    model.summary()\n",
    "    return model\n",
    "\n",
    "\n",
    "def run(data_file='tmva_example_multiple_background.root', output_file='TMVA.root'):\n",
    "    \"\"\"Book, train, test and evaluate the TMVA multiclass methods.\n",
    "\n",
    "    Args:\n",
    "        data_file: ROOT file holding the TreeS, TreeB0, TreeB1 and TreeB2 trees.\n",
    "        output_file: ROOT file (recreated) that receives the TMVA results.\n",
    "\n",
    "    Raises:\n",
    "        OSError: if data_file cannot be opened.\n",
    "        KeyError: if one of the expected trees is missing from data_file.\n",
    "    \"\"\"\n",
    "    # Open and validate the input first: TFile.Open returns a null object on\n",
    "    # failure, and checking before the RECREATE below avoids truncating an\n",
    "    # existing output file when the input is unusable.\n",
    "    data = TFile.Open(data_file)\n",
    "    if not data or data.IsZombie():\n",
    "        raise OSError('Cannot open input file {}'.format(data_file))\n",
    "\n",
    "    with TFile.Open(output_file, 'RECREATE') as output, data:\n",
    "        factory = TMVA.Factory('TMVAClassification', output,\n",
    "                               '!V:!Silent:Color:!DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')\n",
    "\n",
    "        # (TMVA class name, tree name in data_file), in registration order\n",
    "        class_tree_names = [('Signal', 'TreeS'), ('Background_0', 'TreeB0'),\n",
    "                            ('Background_1', 'TreeB1'), ('Background_2', 'TreeB2')]\n",
    "        trees = {}\n",
    "        for class_name, tree_name in class_tree_names:\n",
    "            tree = data.Get(tree_name)\n",
    "            if not tree:\n",
    "                raise KeyError('Tree {} not found in {}'.format(tree_name, data_file))\n",
    "            trees[class_name] = tree\n",
    "\n",
    "        # Every branch of the signal tree becomes an input variable.\n",
    "        # NOTE: the Keras model's input size must equal the number of variables.\n",
    "        dataloader = TMVA.DataLoader('dataset')\n",
    "        for branch in trees['Signal'].GetListOfBranches():\n",
    "            dataloader.AddVariable(branch.GetName())\n",
    "\n",
    "        for class_name, tree in trees.items():\n",
    "            dataloader.AddTree(tree, class_name)\n",
    "        dataloader.PrepareTrainingAndTestTree(TCut(''),\n",
    "                                              'SplitMode=Random:NormMode=NumEvents:!V')\n",
    "\n",
    "        # Book methods\n",
    "        factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',\n",
    "                           '!H:!V:Fisher:VarTransform=D,G')\n",
    "        factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',\n",
    "                           'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.keras:FilenameTrainedModel=trainedModelMultiClass.keras:NumEpochs=20:BatchSize=32')\n",
    "\n",
    "        # Run TMVA\n",
    "        factory.TrainAllMethods()\n",
    "        factory.TestAllMethods()\n",
    "        factory.EvaluateAllMethods()\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Build and save the untrained Keras model before TMVA is set up\n",
    "    create_model()\n",
    "\n",
    "    # Set up TMVA and initialise the Python interface used by PyKeras\n",
    "    TMVA.Tools.Instance()\n",
    "    TMVA.PyMethodBase.PyInitialize()\n",
    "\n",
    "    # Generate the example dataset with ROOT's createData.C macro if it is absent\n",
    "    dataset_path = 'tmva_example_multiple_background.root'\n",
    "    if not isfile(dataset_path):\n",
    "        macro_path = str(gROOT.GetTutorialDir()) + '/machine_learning/createData.C'\n",
    "        print(macro_path)\n",
    "        gROOT.ProcessLine('.L ' + macro_path)\n",
    "        gROOT.ProcessLine('create_MultipleBackground(4000)')\n",
    "\n",
    "    # Train, test and evaluate all booked methods\n",
    "    run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}