{
"cells": [
{
"cell_type": "markdown",
"id": "69a65833",
"metadata": {},
"source": [
"# tmva003_RReader\n",
"This tutorial shows how to apply with the modern interfaces models saved in\n",
"TMVA XML files.\n",
"\n",
"\n",
"\n",
"\n",
"**Author:** Stefan Wunsch \n",
"This notebook tutorial was automatically generated with ROOTBOOK-izer from the macro found in the ROOT repository on Tuesday, May 19, 2026 at 08:22 PM."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a887c6d6",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:33.617480Z",
"iopub.status.busy": "2026-05-19T20:22:33.617379Z",
"iopub.status.idle": "2026-05-19T20:22:33.943867Z",
"shell.execute_reply": "2026-05-19T20:22:33.943128Z"
}
},
"outputs": [],
"source": [
"using namespace TMVA::Experimental;"
]
},
{
"cell_type": "markdown",
"id": "b4bd0c07",
"metadata": {},
"source": [
" Definition of a helper function: "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3f232ddf",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:33.946113Z",
"iopub.status.busy": "2026-05-19T20:22:33.945991Z",
"iopub.status.idle": "2026-05-19T20:22:33.972707Z",
"shell.execute_reply": "2026-05-19T20:22:33.972106Z"
}
},
"outputs": [],
"source": [
"%%cpp -d\n",
"void train(const std::string &filename)\n",
"{\n",
" // Create factory\n",
" auto output = TFile::Open(\"TMVARR.root\", \"RECREATE\");\n",
" auto factory = new TMVA::Factory(\"tmva003\",\n",
" output, \"!V:!DrawProgressBar:AnalysisType=Classification\");\n",
"\n",
" // Open trees with signal and background events\n",
" auto data = TFile::Open(filename.c_str());\n",
" auto signal = (TTree *)data->Get(\"TreeS\");\n",
" auto background = (TTree *)data->Get(\"TreeB\");\n",
"\n",
" // Add variables and register the trees with the dataloader\n",
" auto dataloader = new TMVA::DataLoader(\"tmva003_BDT\");\n",
" const std::vector variables = {\"var1\", \"var2\", \"var3\", \"var4\"};\n",
" for (const auto &var : variables) {\n",
" dataloader->AddVariable(var);\n",
" }\n",
" dataloader->AddSignalTree(signal, 1.0);\n",
" dataloader->AddBackgroundTree(background, 1.0);\n",
" dataloader->PrepareTrainingAndTestTree(\"\", \"\");\n",
"\n",
" // Train a TMVA method\n",
" factory->BookMethod(dataloader, TMVA::Types::kBDT, \"BDT\", \"!V:!H:NTrees=300:MaxDepth=2\");\n",
" factory->TrainAllMethods();\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "82f10281",
"metadata": {},
"source": [
"First, let's train a model with TMVA."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d3363fd7",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:33.974340Z",
"iopub.status.busy": "2026-05-19T20:22:33.974218Z",
"iopub.status.idle": "2026-05-19T20:22:34.698862Z",
"shell.execute_reply": "2026-05-19T20:22:34.698077Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" DataSetInfo : [tmva003_BDT] : Added class \"Signal\"\n",
" : Add Tree TreeS of type Signal with 6000 events\n",
" DataSetInfo : [tmva003_BDT] : Added class \"Background\"\n",
" : Add Tree TreeB of type Background with 6000 events\n",
" : Dataset[tmva003_BDT] : Class index : 0 name : Signal\n",
" : Dataset[tmva003_BDT] : Class index : 1 name : Background\n",
" Factory : Booking method: BDT\n",
" : \n",
" : Rebuilding Dataset tmva003_BDT\n",
" : Building event vectors for type 2 Signal\n",
" : Dataset[tmva003_BDT] : create input formulas for tree TreeS\n",
" : Building event vectors for type 2 Background\n",
" : Dataset[tmva003_BDT] : create input formulas for tree TreeB\n",
" DataSetFactory : [tmva003_BDT] : Number of events in input trees\n",
" : \n",
" : \n",
" : Dataset[tmva003_BDT] : Weight renormalisation mode: \"EqualNumEvents\": renormalises all event classes ...\n",
" : Dataset[tmva003_BDT] : such that the effective (weighted) number of events in each class is the same \n",
" : Dataset[tmva003_BDT] : (and equals the number of events (entries) given for class=0 )\n",
" : Dataset[tmva003_BDT] : ... i.e. such that Sum[i=1..N_j]{w_i} = N_classA, j=classA, classB, ...\n",
" : Dataset[tmva003_BDT] : ... (note that N_j is the sum of TRAINING events\n",
" : Dataset[tmva003_BDT] : ..... Testing events are not renormalised nor included in the renormalisation factor!)\n",
" : Number of training and testing events\n",
" : ---------------------------------------------------------------------------\n",
" : Signal -- training events : 3000\n",
" : Signal -- testing events : 3000\n",
" : Signal -- training and testing events: 6000\n",
" : Background -- training events : 3000\n",
" : Background -- testing events : 3000\n",
" : Background -- training and testing events: 6000\n",
" : \n",
" DataSetInfo : Correlation matrix (Signal):\n",
" : ----------------------------------------\n",
" : var1 var2 var3 var4\n",
" : var1: +1.000 +0.392 +0.592 +0.822\n",
" : var2: +0.392 +1.000 +0.680 +0.720\n",
" : var3: +0.592 +0.680 +1.000 +0.844\n",
" : var4: +0.822 +0.720 +0.844 +1.000\n",
" : ----------------------------------------\n",
" DataSetInfo : Correlation matrix (Background):\n",
" : ----------------------------------------\n",
" : var1 var2 var3 var4\n",
" : var1: +1.000 +0.854 +0.913 +0.964\n",
" : var2: +0.854 +1.000 +0.925 +0.936\n",
" : var3: +0.913 +0.925 +1.000 +0.970\n",
" : var4: +0.964 +0.936 +0.970 +1.000\n",
" : ----------------------------------------\n",
" DataSetFactory : [tmva003_BDT] : \n",
" : \n",
" Factory : Train all methods\n",
" Factory : [tmva003_BDT] : Create Transformation \"I\" with events from all classes.\n",
" : \n",
" : Transformation, Variable selection : \n",
" : Input : variable 'var1' <---> Output : variable 'var1'\n",
" : Input : variable 'var2' <---> Output : variable 'var2'\n",
" : Input : variable 'var3' <---> Output : variable 'var3'\n",
" : Input : variable 'var4' <---> Output : variable 'var4'\n",
" TFHandler_Factory : Variable Mean RMS [ Min Max ]\n",
" : -----------------------------------------------------------\n",
" : var1: -0.025840 1.6640 [ -4.8874 4.7639 ]\n",
" : var2: -0.018356 1.5781 [ -5.2407 4.5241 ]\n",
" : var3: -0.034388 1.7365 [ -5.3563 4.6430 ]\n",
" : var4: 0.12114 2.1646 [ -6.3160 4.9600 ]\n",
" : -----------------------------------------------------------\n",
" : Ranking input variables (method unspecific)...\n",
" IdTransformation : Ranking result (top variable is best ranked)\n",
" : -----------------------------\n",
" : Rank : Variable : Separation\n",
" : -----------------------------\n",
" : 1 : var4 : 3.521e-01\n",
" : 2 : var3 : 2.907e-01\n",
" : 3 : var1 : 2.648e-01\n",
" : 4 : var2 : 2.255e-01\n",
" : -----------------------------\n",
" Factory : Train method: BDT for Classification\n",
" : \n",
" BDT : #events: (reweighted) sig: 3000 bkg: 3000\n",
" : #events: (unweighted) sig: 3000 bkg: 3000\n",
" : Training 300 Decision Trees ... patience please\n",
" : Elapsed time for training with 6000 events: 0.304 sec \n",
" BDT : [tmva003_BDT] : Evaluation of BDT on training sample (6000 events)\n",
" BDT : [tmva003_BDT] : Evaluation of BDT on training sample (6000 events)\n",
" : Elapsed time for evaluation of 6000 events: 0.0257 sec \n",
" : Elapsed time for evaluation of 6000 events: 0.0258 sec \n",
" : Creating xml weight file: tmva003_BDT/weights/tmva003_BDT.weights.xml\n",
" : Creating standalone class: tmva003_BDT/weights/tmva003_BDT.class.C\n",
" : TMVARR.root:/tmva003_BDT/Method_BDT/BDT\n",
" Factory : Training finished\n",
" : \n",
" : Ranking input variables (method specific)...\n",
" BDT : Ranking result (top variable is best ranked)\n",
" : --------------------------------------\n",
" : Rank : Variable : Variable Importance\n",
" : --------------------------------------\n",
" : 1 : var4 : 4.144e-01\n",
" : 2 : var1 : 2.524e-01\n",
" : 3 : var2 : 1.727e-01\n",
" : 4 : var3 : 1.606e-01\n",
" : --------------------------------------\n",
" Factory : === Destroy and recreate all methods via weight files for testing ===\n",
" : \n",
" : Reading weight file: tmva003_BDT/weights/tmva003_BDT.weights.xml\n"
]
}
],
"source": [
"const std::string filename = std::string(gROOT->GetTutorialDir()) + \"/machine_learning/data/tmva_class_example.root\";\n",
"train(filename);"
]
},
{
"cell_type": "markdown",
"id": "e63724d3",
"metadata": {},
"source": [
"Next, we load the model from the TMVA XML file."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e0797660",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:34.700403Z",
"iopub.status.busy": "2026-05-19T20:22:34.700281Z",
"iopub.status.idle": "2026-05-19T20:22:35.021346Z",
"shell.execute_reply": "2026-05-19T20:22:35.020765Z"
}
},
"outputs": [],
"source": [
"RReader model(\"tmva003_BDT/weights/tmva003_BDT.weights.xml\");"
]
},
{
"cell_type": "markdown",
"id": "3ef53077",
"metadata": {},
"source": [
"In case you need a reminder of the names and order of the variables during\n",
"training, you can ask the model for it."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bfd9e075",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:35.025460Z",
"iopub.status.busy": "2026-05-19T20:22:35.025334Z",
"iopub.status.idle": "2026-05-19T20:22:35.250830Z",
"shell.execute_reply": "2026-05-19T20:22:35.240791Z"
}
},
"outputs": [],
"source": [
"auto variables = model.GetVariableNames();"
]
},
{
"cell_type": "markdown",
"id": "2290a157",
"metadata": {},
"source": [
"The model can now be applied in different scenarios:\n",
"1) Event-by-event inference\n",
"2) Batch inference on data of multiple events\n",
"3) Inference as part of an RDataFrame graph"
]
},
{
"cell_type": "markdown",
"id": "c7303cbe",
"metadata": {},
"source": [
"1) Event-by-event inference\n",
"The event-by-event inference takes the values of the variables as a std::vector.\n",
"Note that the return value is as well a std::vector since the reader\n",
"is also capable to process models with multiple outputs."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3151c705",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:35.252483Z",
"iopub.status.busy": "2026-05-19T20:22:35.252346Z",
"iopub.status.idle": "2026-05-19T20:22:35.457248Z",
"shell.execute_reply": "2026-05-19T20:22:35.456739Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Single-event inference: 0.236928\n",
"\n"
]
}
],
"source": [
"auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});\n",
"std::cout << \"Single-event inference: \" << prediction[0] << \"\\n\\n\";"
]
},
{
"cell_type": "markdown",
"id": "bcd953cd",
"metadata": {},
"source": [
"2) Batch inference on data of multiple events\n",
"For batch inference, the data needs to be structured as a matrix. For this\n",
"purpose, TMVA makes use of the RTensor class. For convenience, we use RDataFrame\n",
"and the AsTensor utility to make the read-out from the ROOT file."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5f939465",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:35.458963Z",
"iopub.status.busy": "2026-05-19T20:22:35.458842Z",
"iopub.status.idle": "2026-05-19T20:22:36.733366Z",
"shell.execute_reply": "2026-05-19T20:22:36.732764Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RTensor input for inference on data of multiple events:\n",
"{ { -1.14361, -0.822373, -0.495426, -0.629427 } { 2.14344, -0.0189228, 0.26703, 1.26749 } { -0.443913, 0.486827, 0.139535, 0.611483 } }\n",
"\n",
"Prediction performed on multiple events: { 0.139826, -0.0423391, 0.224947 }\n",
"\n"
]
}
],
"source": [
"ROOT::RDataFrame df(\"TreeS\", filename);\n",
"auto df2 = df.Range(3); // Read only a small subset of the dataset\n",
"auto x = AsTensor(df2, variables);\n",
"auto y = model.Compute(x);\n",
"\n",
"std::cout << \"RTensor input for inference on data of multiple events:\\n\" << x << \"\\n\\n\";\n",
"std::cout << \"Prediction performed on multiple events: \" << y << \"\\n\\n\";"
]
},
{
"cell_type": "markdown",
"id": "caed95f8",
"metadata": {},
"source": [
"3) Perform inference as part of an RDataFrame graph\n",
"We write a small lambda function that performs for us the inference on\n",
"a dataframe to omit code duplication."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6235c2a8",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:36.734907Z",
"iopub.status.busy": "2026-05-19T20:22:36.734789Z",
"iopub.status.idle": "2026-05-19T20:22:37.252139Z",
"shell.execute_reply": "2026-05-19T20:22:37.251405Z"
}
},
"outputs": [],
"source": [
"auto make_histo = [&](const std::string &treename) {\n",
" ROOT::RDataFrame df(treename, filename);\n",
" auto df2 = df.Define(\"y\", Compute<4, float>(model), variables);\n",
" return df2.Histo1D({treename.c_str(), \";BDT score;N_{Events}\", 30, -0.5, 0.5}, \"y\");\n",
"};\n",
"\n",
"auto sig = make_histo(\"TreeS\");\n",
"auto bkg = make_histo(\"TreeB\");"
]
},
{
"cell_type": "markdown",
"id": "49e0f885",
"metadata": {},
"source": [
"Make plot"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4fae78c1",
"metadata": {
"collapsed": false,
"execution": {
"iopub.execute_input": "2026-05-19T20:22:37.254138Z",
"iopub.status.busy": "2026-05-19T20:22:37.254014Z",
"iopub.status.idle": "2026-05-19T20:22:38.281361Z",
"shell.execute_reply": "2026-05-19T20:22:38.280995Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"\n",
"
\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"gStyle->SetOptStat(0);\n",
"auto c = new TCanvas(\"\", \"\", 800, 800);\n",
"\n",
"sig->SetLineColor(kRed);\n",
"bkg->SetLineColor(kBlue);\n",
"sig->SetLineWidth(2);\n",
"bkg->SetLineWidth(2);\n",
"bkg->Draw(\"HIST\");\n",
"sig->Draw(\"HIST SAME\");\n",
"\n",
"TLegend legend(0.7, 0.7, 0.89, 0.89);\n",
"legend.SetBorderSize(0);\n",
"legend.AddEntry(\"TreeS\", \"Signal\", \"l\");\n",
"legend.AddEntry(\"TreeB\", \"Background\", \"l\");\n",
"legend.Draw();\n",
"\n",
"c->DrawClone();"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ROOT C++",
"language": "c++",
"name": "root"
},
"language_info": {
"codemirror_mode": "text/x-c++src",
"file_extension": ".C",
"mimetype": " text/x-c++src",
"name": "c++"
}
},
"nbformat": 4,
"nbformat_minor": 5
}