{ "cells": [ { "cell_type": "markdown", "id": "69a65833", "metadata": {}, "source": [ "# tmva003_RReader\n", "This tutorial shows how to apply models saved in TMVA XML files\n", "with the modern interfaces.\n", "\n", "\n", "\n", "\n", "**Author:** Stefan Wunsch \n", "This notebook tutorial was automatically generated with ROOTBOOK-izer from the macro found in the ROOT repository on Tuesday, May 19, 2026 at 08:22 PM." ] }, { "cell_type": "code", "execution_count": 1, "id": "a887c6d6", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:33.617480Z", "iopub.status.busy": "2026-05-19T20:22:33.617379Z", "iopub.status.idle": "2026-05-19T20:22:33.943867Z", "shell.execute_reply": "2026-05-19T20:22:33.943128Z" } }, "outputs": [], "source": [ "using namespace TMVA::Experimental;" ] }, { "cell_type": "markdown", "id": "b4bd0c07", "metadata": {}, "source": [ " Definition of a helper function: " ] }, { "cell_type": "code", "execution_count": 2, "id": "3f232ddf", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:33.946113Z", "iopub.status.busy": "2026-05-19T20:22:33.945991Z", "iopub.status.idle": "2026-05-19T20:22:33.972707Z", "shell.execute_reply": "2026-05-19T20:22:33.972106Z" } }, "outputs": [], "source": [ "%%cpp -d\n", "void train(const std::string &filename)\n", "{\n", " // Create factory\n", " auto output = TFile::Open(\"TMVARR.root\", \"RECREATE\");\n", " auto factory = new TMVA::Factory(\"tmva003\",\n", " output, \"!V:!DrawProgressBar:AnalysisType=Classification\");\n", "\n", " // Open trees with signal and background events\n", " auto data = TFile::Open(filename.c_str());\n", " auto signal = (TTree *)data->Get(\"TreeS\");\n", " auto background = (TTree *)data->Get(\"TreeB\");\n", "\n", " // Add variables and register the trees with the dataloader\n", " auto dataloader = new TMVA::DataLoader(\"tmva003_BDT\");\n", " const std::vector<std::string> variables = {\"var1\", \"var2\", \"var3\", \"var4\"};\n", " 
for (const auto &var : variables) {\n", " dataloader->AddVariable(var);\n", " }\n", " dataloader->AddSignalTree(signal, 1.0);\n", " dataloader->AddBackgroundTree(background, 1.0);\n", " dataloader->PrepareTrainingAndTestTree(\"\", \"\");\n", "\n", " // Train a TMVA method\n", " factory->BookMethod(dataloader, TMVA::Types::kBDT, \"BDT\", \"!V:!H:NTrees=300:MaxDepth=2\");\n", " factory->TrainAllMethods();\n", "}" ] }, { "cell_type": "markdown", "id": "82f10281", "metadata": {}, "source": [ "First, let's train a model with TMVA." ] }, { "cell_type": "code", "execution_count": 3, "id": "d3363fd7", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:33.974340Z", "iopub.status.busy": "2026-05-19T20:22:33.974218Z", "iopub.status.idle": "2026-05-19T20:22:34.698862Z", "shell.execute_reply": "2026-05-19T20:22:34.698077Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "
DataSetInfo : [tmva003_BDT] : Added class \"Signal\"\n", " : Add Tree TreeS of type Signal with 6000 events\n", "
DataSetInfo : [tmva003_BDT] : Added class \"Background\"\n", " : Add Tree TreeB of type Background with 6000 events\n", " : Dataset[tmva003_BDT] : Class index : 0 name : Signal\n", " : Dataset[tmva003_BDT] : Class index : 1 name : Background\n", "
Factory : Booking method: BDT\n", " : \n", " : Rebuilding Dataset tmva003_BDT\n", " : Building event vectors for type 2 Signal\n", " : Dataset[tmva003_BDT] : create input formulas for tree TreeS\n", " : Building event vectors for type 2 Background\n", " : Dataset[tmva003_BDT] : create input formulas for tree TreeB\n", "
DataSetFactory : [tmva003_BDT] : Number of events in input trees\n", " : \n", " : \n", " : Dataset[tmva003_BDT] : Weight renormalisation mode: \"EqualNumEvents\": renormalises all event classes ...\n", " : Dataset[tmva003_BDT] : such that the effective (weighted) number of events in each class is the same \n", " : Dataset[tmva003_BDT] : (and equals the number of events (entries) given for class=0 )\n", " : Dataset[tmva003_BDT] : ... i.e. such that Sum[i=1..N_j]{w_i} = N_classA, j=classA, classB, ...\n", " : Dataset[tmva003_BDT] : ... (note that N_j is the sum of TRAINING events\n", " : Dataset[tmva003_BDT] : ..... Testing events are not renormalised nor included in the renormalisation factor!)\n", " : Number of training and testing events\n", " : ---------------------------------------------------------------------------\n", " : Signal -- training events : 3000\n", " : Signal -- testing events : 3000\n", " : Signal -- training and testing events: 6000\n", " : Background -- training events : 3000\n", " : Background -- testing events : 3000\n", " : Background -- training and testing events: 6000\n", " : \n", "
DataSetInfo : Correlation matrix (Signal):\n", " : ----------------------------------------\n", " : var1 var2 var3 var4\n", " : var1: +1.000 +0.392 +0.592 +0.822\n", " : var2: +0.392 +1.000 +0.680 +0.720\n", " : var3: +0.592 +0.680 +1.000 +0.844\n", " : var4: +0.822 +0.720 +0.844 +1.000\n", " : ----------------------------------------\n", "
DataSetInfo : Correlation matrix (Background):\n", " : ----------------------------------------\n", " : var1 var2 var3 var4\n", " : var1: +1.000 +0.854 +0.913 +0.964\n", " : var2: +0.854 +1.000 +0.925 +0.936\n", " : var3: +0.913 +0.925 +1.000 +0.970\n", " : var4: +0.964 +0.936 +0.970 +1.000\n", " : ----------------------------------------\n", "
DataSetFactory : [tmva003_BDT] : \n", " : \n", "
Factory : Train all methods\n", "
Factory : [tmva003_BDT] : Create Transformation \"I\" with events from all classes.\n", " : \n", "
: Transformation, Variable selection : \n", " : Input : variable 'var1' <---> Output : variable 'var1'\n", " : Input : variable 'var2' <---> Output : variable 'var2'\n", " : Input : variable 'var3' <---> Output : variable 'var3'\n", " : Input : variable 'var4' <---> Output : variable 'var4'\n", "
TFHandler_Factory : Variable Mean RMS [ Min Max ]\n", " : -----------------------------------------------------------\n", " : var1: -0.025840 1.6640 [ -4.8874 4.7639 ]\n", " : var2: -0.018356 1.5781 [ -5.2407 4.5241 ]\n", " : var3: -0.034388 1.7365 [ -5.3563 4.6430 ]\n", " : var4: 0.12114 2.1646 [ -6.3160 4.9600 ]\n", " : -----------------------------------------------------------\n", " : Ranking input variables (method unspecific)...\n", "
IdTransformation : Ranking result (top variable is best ranked)\n", " : -----------------------------\n", " : Rank : Variable : Separation\n", " : -----------------------------\n", " : 1 : var4 : 3.521e-01\n", " : 2 : var3 : 2.907e-01\n", " : 3 : var1 : 2.648e-01\n", " : 4 : var2 : 2.255e-01\n", " : -----------------------------\n", "
Factory : Train method: BDT for Classification\n", " : \n", "
BDT : #events: (reweighted) sig: 3000 bkg: 3000\n", " : #events: (unweighted) sig: 3000 bkg: 3000\n", " : Training 300 Decision Trees ... patience please\n", " : Elapsed time for training with 6000 events: 0.304 sec \n", "
BDT : [tmva003_BDT] : Evaluation of BDT on training sample (6000 events)\n", "
BDT : [tmva003_BDT] : Evaluation of BDT on training sample (6000 events)\n", " : Elapsed time for evaluation of 6000 events: 0.0257 sec \n", " : Elapsed time for evaluation of 6000 events: 0.0258 sec \n", " : Creating xml weight file: tmva003_BDT/weights/tmva003_BDT.weights.xml\n", " : Creating standalone class: tmva003_BDT/weights/tmva003_BDT.class.C\n", " : TMVARR.root:/tmva003_BDT/Method_BDT/BDT\n", "
Factory : Training finished\n", " : \n", " : Ranking input variables (method specific)...\n", "
BDT : Ranking result (top variable is best ranked)\n", " : --------------------------------------\n", " : Rank : Variable : Variable Importance\n", " : --------------------------------------\n", " : 1 : var4 : 4.144e-01\n", " : 2 : var1 : 2.524e-01\n", " : 3 : var2 : 1.727e-01\n", " : 4 : var3 : 1.606e-01\n", " : --------------------------------------\n", "
Factory : === Destroy and recreate all methods via weight files for testing ===\n", " : \n", " : Reading weight file: tmva003_BDT/weights/tmva003_BDT.weights.xml\n" ] } ], "source": [ "const std::string filename = std::string(gROOT->GetTutorialDir()) + \"/machine_learning/data/tmva_class_example.root\";\n", "train(filename);" ] }, { "cell_type": "markdown", "id": "e63724d3", "metadata": {}, "source": [ "Next, we load the model from the TMVA XML file." ] }, { "cell_type": "code", "execution_count": 4, "id": "e0797660", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:34.700403Z", "iopub.status.busy": "2026-05-19T20:22:34.700281Z", "iopub.status.idle": "2026-05-19T20:22:35.021346Z", "shell.execute_reply": "2026-05-19T20:22:35.020765Z" } }, "outputs": [], "source": [ "RReader model(\"tmva003_BDT/weights/tmva003_BDT.weights.xml\");" ] }, { "cell_type": "markdown", "id": "3ef53077", "metadata": {}, "source": [ "If you need a reminder of the names and order of the variables used during\n", "training, you can ask the model for them." 
] }, { "cell_type": "code", "execution_count": 5, "id": "bfd9e075", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:35.025460Z", "iopub.status.busy": "2026-05-19T20:22:35.025334Z", "iopub.status.idle": "2026-05-19T20:22:35.250830Z", "shell.execute_reply": "2026-05-19T20:22:35.240791Z" } }, "outputs": [], "source": [ "auto variables = model.GetVariableNames();" ] }, { "cell_type": "markdown", "id": "2290a157", "metadata": {}, "source": [ "The model can now be applied in different scenarios:\n", "1) Event-by-event inference\n", "2) Batch inference on data of multiple events\n", "3) Inference as part of an RDataFrame graph" ] }, { "cell_type": "markdown", "id": "c7303cbe", "metadata": {}, "source": [ "1) Event-by-event inference\n", "Event-by-event inference takes the values of the input variables as a std::vector.\n", "Note that the return value is also a std::vector, since the reader\n", "can also process models with multiple outputs." ] }, { "cell_type": "code", "execution_count": 6, "id": "3151c705", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:35.252483Z", "iopub.status.busy": "2026-05-19T20:22:35.252346Z", "iopub.status.idle": "2026-05-19T20:22:35.457248Z", "shell.execute_reply": "2026-05-19T20:22:35.456739Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Single-event inference: 0.236928\n", "\n" ] } ], "source": [ "auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});\n", "std::cout << \"Single-event inference: \" << prediction[0] << \"\\n\\n\";" ] }, { "cell_type": "markdown", "id": "bcd953cd", "metadata": {}, "source": [ "2) Batch inference on data of multiple events\n", "For batch inference, the data must be structured as a matrix, with one row per event\n", "and one column per variable. For this purpose, TMVA uses the RTensor class. For convenience,\n", "we use RDataFrame and the AsTensor utility to read the data from the ROOT file." 
] }, { "cell_type": "code", "execution_count": 7, "id": "5f939465", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:35.458963Z", "iopub.status.busy": "2026-05-19T20:22:35.458842Z", "iopub.status.idle": "2026-05-19T20:22:36.733366Z", "shell.execute_reply": "2026-05-19T20:22:36.732764Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RTensor input for inference on data of multiple events:\n", "{ { -1.14361, -0.822373, -0.495426, -0.629427 } { 2.14344, -0.0189228, 0.26703, 1.26749 } { -0.443913, 0.486827, 0.139535, 0.611483 } }\n", "\n", "Prediction performed on multiple events: { 0.139826, -0.0423391, 0.224947 }\n", "\n" ] } ], "source": [ "ROOT::RDataFrame df(\"TreeS\", filename);\n", "auto df2 = df.Range(3); // Read only a small subset of the dataset\n", "auto x = AsTensor<float>(df2, variables);\n", "auto y = model.Compute(x);\n", "\n", "std::cout << \"RTensor input for inference on data of multiple events:\\n\" << x << \"\\n\\n\";\n", "std::cout << \"Prediction performed on multiple events: \" << y << \"\\n\\n\";" ] }, { "cell_type": "markdown", "id": "caed95f8", "metadata": {}, "source": [ "3) Perform inference as part of an RDataFrame graph\n", "We write a small lambda function that performs the inference on\n", "a dataframe, so that the same code can be reused for the signal and background trees." 
] }, { "cell_type": "code", "execution_count": 8, "id": "6235c2a8", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:36.734907Z", "iopub.status.busy": "2026-05-19T20:22:36.734789Z", "iopub.status.idle": "2026-05-19T20:22:37.252139Z", "shell.execute_reply": "2026-05-19T20:22:37.251405Z" } }, "outputs": [], "source": [ "auto make_histo = [&](const std::string &treename) {\n", " ROOT::RDataFrame df(treename, filename);\n", " auto df2 = df.Define(\"y\", Compute<4, float>(model), variables);\n", " return df2.Histo1D({treename.c_str(), \";BDT score;N_{Events}\", 30, -0.5, 0.5}, \"y\");\n", "};\n", "\n", "auto sig = make_histo(\"TreeS\");\n", "auto bkg = make_histo(\"TreeB\");" ] }, { "cell_type": "markdown", "id": "49e0f885", "metadata": {}, "source": [ "Make plot" ] }, { "cell_type": "code", "execution_count": 9, "id": "4fae78c1", "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-05-19T20:22:37.254138Z", "iopub.status.busy": "2026-05-19T20:22:37.254014Z", "iopub.status.idle": "2026-05-19T20:22:38.281361Z", "shell.execute_reply": "2026-05-19T20:22:38.280995Z" } }, "outputs": [],
"source": [ "gStyle->SetOptStat(0);\n", "auto c = new TCanvas(\"\", \"\", 800, 800);\n", "\n", "sig->SetLineColor(kRed);\n", "bkg->SetLineColor(kBlue);\n", "sig->SetLineWidth(2);\n", "bkg->SetLineWidth(2);\n", "bkg->Draw(\"HIST\");\n", "sig->Draw(\"HIST SAME\");\n", "\n", "TLegend legend(0.7, 0.7, 0.89, 0.89);\n", "legend.SetBorderSize(0);\n", "legend.AddEntry(\"TreeS\", \"Signal\", \"l\");\n", "legend.AddEntry(\"TreeB\", \"Background\", \"l\");\n", "legend.Draw();\n", "\n", "c->DrawClone();" ] } ], "metadata": { "kernelspec": { "display_name": "ROOT C++", "language": "c++", "name": "root" }, "language_info": { "codemirror_mode": "text/x-c++src", "file_extension": ".C", "mimetype": " text/x-c++src", "name": "c++" } }, "nbformat": 4, "nbformat_minor": 5 }