{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c47da7a3",
   "metadata": {},
   "source": [
    "# RegressionKeras\n",
    "This tutorial shows how to do regression in TMVA with a neural network\n",
    "trained with Keras, compared against a gradient-boosted decision tree (BDTG).\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** TMVA Team  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 19, 2026 at 08:21 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "316bb838",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from ROOT import TMVA, TFile, TCut, gROOT\n",
    "from subprocess import call\n",
    "from os.path import isfile\n",
    "\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense\n",
    "from tensorflow.keras.optimizers import SGD\n",
    "\n",
    "\n",
    "def create_model():\n",
    "    \"\"\"Build, compile and save the Keras regression network used by TMVA.\n",
    "\n",
    "    The network maps the two input variables to a single output through one\n",
    "    hidden layer of 64 tanh units followed by a linear output unit. It is\n",
    "    compiled with a mean-squared-error loss and plain SGD (learning rate 0.01)\n",
    "    and written to 'modelRegression.keras', the file the PyKeras method booked\n",
    "    in run() loads through its FilenameModel option. A summary is printed last.\n",
    "    \"\"\"\n",
    "    # NOTE: input_dim must match the number of variables registered in run().\n",
    "    regressor = Sequential([\n",
    "        Dense(64, activation='tanh', input_dim=2),\n",
    "        Dense(1, activation='linear'),\n",
    "    ])\n",
    "\n",
    "    sgd = SGD(learning_rate=0.01)\n",
    "    regressor.compile(loss='mean_squared_error', optimizer=sgd,\n",
    "                      weighted_metrics=[])\n",
    "\n",
    "    # Persist the untrained model; TMVA trains it later from this file.\n",
    "    regressor.save('modelRegression.keras')\n",
    "    regressor.summary()\n",
    "\n",
    "\n",
    "def run(n_train=1000):\n",
    "    \"\"\"Train, test and evaluate a PyKeras network and a BDTG on the regression example.\n",
    "\n",
    "    Reads the TTree 'TreeR' from the tutorial file tmva_reg_example.root,\n",
    "    registers every branch except the target 'fvalue' as an input variable,\n",
    "    books the Keras model saved by create_model() alongside a gradient-boosted\n",
    "    BDT, and writes all TMVA output to TMVA_Regression_Keras.root.\n",
    "\n",
    "    Args:\n",
    "        n_train: number of events used for training (default 1000). Kept small\n",
    "            because evaluation is very slow (especially on MacOS); increase it\n",
    "            to get meaningful results.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if the input file does not provide the TTree 'TreeR'.\n",
    "    \"\"\"\n",
    "    data_path = str(gROOT.GetTutorialDir()) + '/machine_learning/data/tmva_reg_example.root'\n",
    "\n",
    "    with TFile.Open('TMVA_Regression_Keras.root', 'RECREATE') as output, TFile.Open(data_path) as data:\n",
    "        factory = TMVA.Factory('TMVARegression', output,\n",
    "                               '!V:!Silent:Color:!DrawProgressBar:Transformations=D,G:AnalysisType=Regression')\n",
    "\n",
    "        tree = data.Get('TreeR')\n",
    "        # Fail with a clear message instead of a null-pointer error further down.\n",
    "        if not tree:\n",
    "            raise RuntimeError(\"Could not read TTree 'TreeR' from \" + data_path)\n",
    "\n",
    "        dataloader = TMVA.DataLoader('dataset')\n",
    "        # Every branch except the regression target becomes an input variable.\n",
    "        # NOTE: create_model() hard-codes input_dim=2, so this must yield exactly two variables.\n",
    "        for branch in tree.GetListOfBranches():\n",
    "            name = branch.GetName()\n",
    "            if name != 'fvalue':\n",
    "                dataloader.AddVariable(name)\n",
    "        dataloader.AddTarget('fvalue')\n",
    "\n",
    "        # 1.0 is the global weight applied to all events of this tree.\n",
    "        dataloader.AddRegressionTree(tree, 1.0)\n",
    "        # Only n_train events are used for training since evaluation is very slow\n",
    "        # (especially on MacOS). Increase it to get meaningful results.\n",
    "        dataloader.PrepareTrainingAndTestTree(TCut(''),\n",
    "                                              'nTrain_Regression=' + str(n_train) +\n",
    "                                              ':SplitMode=Random:NormMode=NumEvents:!V')\n",
    "\n",
    "        # Book methods: the Keras network saved by create_model() and a BDTG for comparison\n",
    "        factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',\n",
    "                           'H:!V:VarTransform=D,G:FilenameModel=modelRegression.keras:FilenameTrainedModel=trainedModelRegression.keras:NumEpochs=20:BatchSize=32')\n",
    "        factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',\n",
    "                           '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')\n",
    "\n",
    "        # Run TMVA\n",
    "        factory.TrainAllMethods()\n",
    "        factory.TestAllMethods()\n",
    "        factory.EvaluateAllMethods()\n",
    "\n",
    "\n",
    "# Entry point: initialise TMVA, save the untrained Keras model, then run the\n",
    "# TMVA training/testing/evaluation. The order of these steps matters.\n",
    "if __name__ == \"__main__\":\n",
    "    # Setup TMVA (Tools singleton and the Python layer used by PyMVA methods such as PyKeras)\n",
    "    TMVA.Tools.Instance()\n",
    "    TMVA.PyMethodBase.PyInitialize()\n",
    "\n",
    "    # Generate model: must precede run(), which books PyKeras with FilenameModel=modelRegression.keras\n",
    "    create_model()\n",
    "\n",
    "    # Run TMVA\n",
    "    run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
