Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooONNXFunc.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 04/2026
5 *
6 * Copyright (c) 2026, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
#include <RooONNXFunc.h>

#include <TBuffer.h>
#include <TInterpreter.h>
#include <TSystem.h>

#include <cstdint>
#include <fstream>
#include <map>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
22/**
23 \file RooONNXFunc.cxx
24 \class RooONNXFunc
25 \ingroup Roofit
26
27 RooONNXFunc wraps an ONNX model as a RooAbsReal, allowing it to be used as
28 a building block in likelihoods, fits, and statistical analyses without
29 additional boilerplate code. The class supports models with **one or more
30 statically-shaped input tensors** and a **single scalar output**. The class
31 was designed to share workspaces with neural functions for combined fits in
32 RooFit-based frameworks written in C++. Therefore, the RooONNXFunc doesn't
33 depend on any Python packages and fully supports ROOT IO,
34
35 The ONNX model is evaluated through compiled C++ code generated at runtime
36 using **TMVA SOFIE**. Automatic differentiation is supported via
37 [Clad](https://github.com/vgvassilev/clad), allowing RooFit to access
38 analytical gradients for fast minimization with Minuit 2.
39
40 The ONNX model is stored internally as a byte payload and serialized together
41 with the RooONNXFunc object using ROOT I/O. Upon reading from a file or
42 workspace, the runtime backend is rebuilt automatically.
43
44 ### Input handling
45
46 The model inputs are provided as a list of tensors, where each tensor is
47 represented by a RooArgList of RooAbsReal objects. The order of the inputs
48 defines the feature ordering passed to the ONNX model.
49 Optionally, users can validate that the ONNX model has the expected input
50
51 ### Example (C++)
52
53 \code
54 // Define input variables
55 RooRealVar x{"x", "x", 0.0};
56 RooRealVar y{"y", "y", 0.0};
57 RooRealVar z{"z", "z", 0.0};
58
59 // Construct ONNX function, building the std::vector<RooArgList> in-place
60 RooONNXFunc func{
61 "func", "func",
62 {{x, y}, {z}},
63 "model.onnx"
64 };
65
66 // Evaluate
67 double val = func.getVal();
68 std::cout << "Model output: " << val << std::endl;
69 \endcode
70
71 ### Example (Python)
72
73 \code{.py}
74 import ROOT
75
76 # Define variables
77 x = ROOT.RooRealVar("x", "x", 0.0)
78 y = ROOT.RooRealVar("y", "y", 0.0)
79 z = ROOT.RooRealVar("z", "z", 0.0)
80
81 # Create ONNX function
82 func = ROOT.RooONNXFunc(
83 "func", "func",
84 [[x, y], [z]],
85 "model.onnx"
86 )
87
88 # Evaluate
89 print("Model output:", func.getVal())
90 \endcode
91
92 */
93
94namespace {
95
/// Slurp the file at \p filePath into a byte vector.
/// Throws std::runtime_error if the file cannot be opened, is empty, or a
/// read error occurs.
std::vector<std::uint8_t> fileToBytes(std::string const &filePath)
{
   std::ifstream in{filePath, std::ios::binary};
   if (!in) {
      throw std::runtime_error("failed to open file '" + filePath + "'");
   }

   // Determine the file size by seeking to the end.
   in.seekg(0, std::ios::end);
   const std::streamsize nBytes = in.tellg();
   in.seekg(0, std::ios::beg);

   // A zero-byte payload (or a failed tellg) cannot be a valid ONNX model.
   if (nBytes <= 0) {
      throw std::runtime_error("file '" + filePath + "' is empty");
   }

   std::vector<std::uint8_t> payload(static_cast<std::size_t>(nBytes));
   in.read(reinterpret_cast<char *>(payload.data()), nBytes);

   if (!in) {
      throw std::runtime_error("error while reading file '" + filePath + "'");
   }

   return payload;
}
127
128template <typename Fn>
129Fn resolveLazy(std::string const &name, const char *code)
130{
131 static Fn fn = nullptr;
132 static std::once_flag flag;
133
134 std::call_once(flag, [&] {
135 // Try to declare the code
136 if (!gInterpreter->Declare(code)) {
137 throw std::runtime_error(std::string("ROOT JIT Declare failed for code defining ") + name);
138 }
139
140 // Try to resolve the symbol
141 void *symbol = reinterpret_cast<void *>(gInterpreter->ProcessLine((name + ";").c_str()));
142
143 if (!symbol) {
144 throw std::runtime_error(std::string("ROOT JIT failed to resolve symbol: ") + name);
145 }
146
147 fn = reinterpret_cast<Fn>(symbol);
148
149 if (!fn) {
150 throw std::runtime_error(std::string("ROOT JIT produced null function pointer for: ") + name);
151 }
152 });
153
154 return fn;
155}
156
157template <typename T>
158std::string toPtrString(T *ptr, std::string const &castType)
159{
160 return TString::Format("reinterpret_cast<%s>(0x%zx)", (castType + "*").c_str(), reinterpret_cast<std::size_t>(ptr))
161 .Data();
162}
163
// Build the C++ expression for the offset of the i-th tensor inside a flat
// input buffer, as a sum of the preceding tensors' sizes:
//   i=0: "0"
//   i=1: "inputTensorDims[0].total_size()"
//   i=2: "inputTensorDims[0].total_size() + inputTensorDims[1].total_size()"
std::string flatOffsetExpr(std::size_t i)
{
   std::string expr;
   for (std::size_t k = 0; k < i; ++k) {
      if (!expr.empty()) {
         expr += " + ";
      }
      expr += "inputTensorDims[" + std::to_string(k) + "].total_size()";
   }
   // The first tensor starts at the beginning of the buffer.
   return expr.empty() ? "0" : expr;
}
179
180} // namespace
181
182void RooFit::Detail::AnyWithVoidPtr::emplace(std::string const &typeName)
183{
184 auto anyPtrSession = toPtrString(this, "RooFit::Detail::AnyWithVoidPtr");
185 gInterpreter->ProcessLine((anyPtrSession + "->emplace<" + typeName + ">();").c_str());
186}
187
189 /// Uniform thunk signature regardless of input-tensor count.
190 /// Args: (Session*, output, flat input buffer).
191 using Func = void (*)(void *, float *, float const *);
192
196};
197
198/**
199 Construct a RooONNXFunc from an ONNX model file.
200
201 \param name Name of the RooFit object
202 \param title Title of the RooFit object
203 \param inputTensors Vector of RooArgList, each representing one input tensor.
204 The variables in each RooArgList match to each flattened input tensor.
205 \param onnxFile Path to the ONNX model file. The file is read and stored
206 internally as a byte payload for persistence with RooWorkspace.
207 \param inputNames Optional list of ONNX input node names. If provided, these
208 are used to validate that the ONNX model has the structure expected by
209 your RooFit code.
210 \param inputShapes Optional list of tensor shapes corresponding to each input
211 tensor. If provided, these are used to validate that the ONNX models
212 input tensors have the shape that you expect. If omitted, only the
213 total size of each tensor is checked.
214 */
215RooONNXFunc::RooONNXFunc(const char *name, const char *title, const std::vector<RooArgList> &inputTensors,
216 const std::string &onnxFile, const std::vector<std::string> & /*inputNames*/,
217 const std::vector<std::vector<int>> & /*inputShapes*/)
218 : RooAbsReal{name, title}, _onnxBytes{fileToBytes(onnxFile)}
219{
220 initialize();
221
222 for (std::size_t i = 0; i < inputTensors.size(); ++i) {
223 std::string istr = std::to_string(i);
224 _inputTensors.emplace_back(
225 std::make_unique<RooListProxy>(("!inputs_" + istr).c_str(), ("Input tensor " + istr).c_str(), this));
226 _inputTensors.back()->addTyped<RooAbsReal>(inputTensors[i]);
227 }
228}
229
231 : RooAbsReal{other, newName}, _onnxBytes{other._onnxBytes}, _runtime{other._runtime}
232{
233 for (std::size_t i = 0; i < other._inputTensors.size(); ++i) {
234 _inputTensors.emplace_back(std::make_unique<RooListProxy>("!inputs", this, *other._inputTensors[i]));
235 }
236}
237
239{
240 _inputBuffer.clear();
241 _inputBuffer.reserve(_inputTensors.size());
242
243 for (auto const &tensorList : _inputTensors) {
245 _inputBuffer.push_back(static_cast<float>(real->getVal(tensorList->nset())));
246 }
247 }
248}
249
251{
252 if (_runtime) {
253 return;
254 }
255
256 _runtime = std::make_unique<RuntimeCache>();
257
258 // We are jitting the SOFIE invocation lazily at runtime, to avoid the
259 // link-time dependency to the SOFIE parser library.
260 if (gSystem->Load("libROOTTMVASofieParser") < 0) {
261 throw std::runtime_error("RooONNXFunc: cannot load ONNX file since SOFIE ONNX parser is missing."
262 " Please build ROOT with tmva-sofie=ON.");
263 }
264 using OnnxToCpp = std::string (*)(std::uint8_t const *, std::size_t, const char *);
265 auto onnxToCppWithSofie = resolveLazy<OnnxToCpp>("_RooONNXFunc_onnxToCppWithSofie",
266 R"(
267#include "TMVA/RModelParser_ONNX.hxx"
268
269std::string _RooONNXFunc_onnxToCppWithSofie(std::uint8_t const *onnxBytes, std::size_t onnxBytesSize, const char *outputName)
270{
271 namespace SOFIE = TMVA::Experimental::SOFIE;
272
273 std::string buffer{reinterpret_cast<const char *>(onnxBytes), onnxBytesSize};
274 std::istringstream stream{buffer};
275
276 SOFIE::RModel rmodel = SOFIE::RModelParser_ONNX{}.Parse(stream, outputName);
277 rmodel.SetOptimizationLevel(SOFIE::OptimizationLevel::kBasic);
278 rmodel.Generate(SOFIE::Options::kNoWeightFile);
279
280 std::stringstream ss{};
281 rmodel.PrintGenerated(ss);
282 return ss.str();
283}
284)");
285
286 static int counter = 0;
287 _funcName = "roo_onnx_func_" + std::to_string(counter);
288 std::string namespaceName = "TMVA_SOFIE_" + _funcName + "";
289 counter++;
290
291 std::string modelCode = onnxToCppWithSofie(_onnxBytes.data(), _onnxBytes.size(), _funcName.c_str());
292 gInterpreter->Declare(modelCode.c_str());
293
294 auto nInputTensors = static_cast<unsigned long>(
295 gInterpreter->ProcessLine(("std::size(" + namespaceName + "::inputTensorDims);").c_str()));
296
297 // Per-input-tensor parameter / argument lists used by the JIT'd code below.
298 std::string innerParams; // "float const *input0, float const *input1, ..."
299 std::string innerArgs; // "input0, input1, ..."
300 std::string outerDoubleParams; // "double const *input0, double const *input1, ..."
301 std::string cladInputs; // "input0, input1, ..." (for clad::gradient param spec)
302 for (std::size_t i = 0; i < nInputTensors; ++i) {
303 std::string istr = std::to_string(i);
304 if (i > 0) {
305 innerParams += ", ";
306 innerArgs += ", ";
307 outerDoubleParams += ", ";
308 cladInputs += ", ";
309 }
310 innerParams += "float const *input" + istr;
311 innerArgs += "input" + istr;
312 outerDoubleParams += "double const *input" + istr;
313 cladInputs += "input" + istr;
314 }
315
316 // Non-template inner / wrapper functions, generated with the right number of inputs.
317 {
318 std::ostringstream ss;
319 ss << "namespace " << namespaceName << " {\n\n"
320 << "float roo_inner_wrapper(Session const &session, " << innerParams << ") {\n"
321 << " float out = 0.;\n"
322 << " doInfer(session, " << innerArgs << ", &out);\n"
323 << " return out;\n"
324 << "}\n\n"
325 << "float roo_wrapper(Session const &session, " << innerParams << ") {\n"
326 << " return roo_inner_wrapper(session, " << innerArgs << ");\n"
327 << "}\n\n"
328 << "} // namespace " << namespaceName << "\n";
329 gInterpreter->Declare(ss.str().c_str());
330 }
331
332 // Evaluation thunk with a uniform signature regardless of the input-tensor count:
333 // takes a flat float buffer and splits it into the per-tensor pointers expected by
334 // SOFIE's doInfer.
335 {
336 std::ostringstream ss;
337 ss << "namespace " << namespaceName << " {\n"
338 << "void roo_eval_thunk(void *session_void, float *out, float const *flat_input) {\n"
339 << " auto *session = reinterpret_cast<Session *>(session_void);\n"
340 << " doInfer(*session";
341 for (std::size_t i = 0; i < nInputTensors; ++i) {
342 ss << ", flat_input + (" << flatOffsetExpr(i) << ")";
343 }
344 ss << ", out);\n"
345 << "}\n"
346 << "} // namespace " << namespaceName << "\n";
347 gInterpreter->Declare(ss.str().c_str());
348 }
349
350 std::string sessionName = "::TMVA_SOFIE_" + _funcName + "::Session";
351
352 _runtime->_session.emplace(sessionName);
353 auto ptrSession = toPtrString(_runtime->_session.ptr, sessionName);
354
355 _runtime->_func = reinterpret_cast<RuntimeCache::Func>(gInterpreter->ProcessLine(
356 ("static_cast<void(*)(void *, float *, float const *)>(" + namespaceName + "::roo_eval_thunk);").c_str()));
357
358 // hardcode the gradient for now
359 _runtime->_d_session.emplace(sessionName);
360 auto ptrDSession = toPtrString(_runtime->_d_session.ptr, sessionName);
361
362 gInterpreter->Declare("#include <Math/CladDerivator.h>");
363
364 gInterpreter->ProcessLine(("clad::gradient(" + namespaceName + "::roo_wrapper, \"" + cladInputs + "\");").c_str());
365
366 // The codegen call site (CodegenImpl::codegenImpl(RooONNXFunc)) passes one
367 // double-array argument per input tensor. Emit roo_outer_wrapper and the matching
368 // custom-derivative pullback with the corresponding number of parameters.
369 {
370 std::ostringstream ss;
371 ss << "namespace " << namespaceName << " {\n\n"
372 << "double roo_outer_wrapper(" << outerDoubleParams << ") {\n"
373 << " auto &session = *" << ptrSession << ";\n";
374 for (std::size_t i = 0; i < nInputTensors; ++i) {
375 ss << " float inputFlt" << i << "[inputTensorDims[" << i << "].total_size()];\n"
376 << " for (std::size_t i = 0; i < std::size(inputFlt" << i << "); ++i) {\n"
377 << " inputFlt" << i << "[i] = input" << i << "[i];\n"
378 << " }\n";
379 }
380 ss << " return roo_inner_wrapper(session";
381 for (std::size_t i = 0; i < nInputTensors; ++i) {
382 ss << ", inputFlt" << i;
383 }
384 ss << ");\n"
385 << "}\n\n"
386 << "} // namespace " << namespaceName << "\n\n"
387 << "namespace clad::custom_derivatives {\n"
388 << "namespace " << namespaceName << " {\n\n"
389 << "void roo_outer_wrapper_pullback(" << outerDoubleParams << ", double d_y";
390 for (std::size_t i = 0; i < nInputTensors; ++i) {
391 ss << ", double *d_input" << i;
392 }
393 ss << ") {\n"
394 << " using namespace ::" << namespaceName << ";\n";
395 for (std::size_t i = 0; i < nInputTensors; ++i) {
396 ss << " float inputFlt" << i << "[inputTensorDims[" << i << "].total_size()];\n"
397 << " float d_inputFlt" << i << "[::std::size(inputFlt" << i << ")];\n"
398 << " for (::std::size_t i = 0; i < ::std::size(inputFlt" << i << "); ++i) {\n"
399 << " inputFlt" << i << "[i] = input" << i << "[i];\n"
400 << " d_inputFlt" << i << "[i] = d_input" << i << "[i];\n"
401 << " }\n";
402 }
403 ss << " auto *session = " << ptrSession << ";\n"
404 << " auto *d_session = " << ptrDSession << ";\n"
405 << " roo_inner_wrapper_pullback(*session";
406 for (std::size_t i = 0; i < nInputTensors; ++i) {
407 ss << ", inputFlt" << i;
409 ss << ", d_y, d_session";
410 for (std::size_t i = 0; i < nInputTensors; ++i) {
411 ss << ", d_inputFlt" << i;
412 }
413 ss << ");\n";
414 for (std::size_t i = 0; i < nInputTensors; ++i) {
415 ss << " for (::std::size_t i = 0; i < ::std::size(inputFlt" << i << "); ++i) {\n"
416 << " d_input" << i << "[i] += d_inputFlt" << i << "[i];\n"
417 << " }\n";
418 }
419 ss << "}\n\n"
420 << "} // namespace " << namespaceName << "\n"
421 << "} // namespace clad::custom_derivatives\n";
422 gInterpreter->Declare(ss.str().c_str());
423 }
424}
425
426double RooONNXFunc::evaluate() const
427{
429
430 float out = 0.f;
431 _runtime->_func(_runtime->_session.ptr, &out, _inputBuffer.data());
432 return static_cast<double>(out);
433}
434
436{
437 if (R__b.IsReading()) {
438 R__b.ReadClassBuffer(RooONNXFunc::Class(), this);
439 this->initialize();
440 } else {
442 }
443}
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t bytes
char name[80]
Definition TGX11.cxx:148
#define gInterpreter
R__EXTERN TSystem * gSystem
Definition TSystem.h:582
friend void RooRefArray::Streamer(TBuffer &)
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:63
RooONNXFunc wraps an ONNX model as a RooAbsReal, allowing it to be used as a building block in likeli...
Definition RooONNXFunc.h:21
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
std::vector< std::uint8_t > _onnxBytes
Persisted ONNX model bytes.
Definition RooONNXFunc.h:52
std::shared_ptr< RuntimeCache > _runtime
! Transient runtime information.
Definition RooONNXFunc.h:53
std::vector< float > _inputBuffer
!
Definition RooONNXFunc.h:54
std::size_t nInputTensors() const
Definition RooONNXFunc.h:33
static TClass * Class()
std::vector< std::unique_ptr< RooListProxy > > _inputTensors
Inputs mapping to flattened input tensors.
Definition RooONNXFunc.h:51
RooONNXFunc()=default
void initialize()
Build transient runtime backend on first use.
void fillInputBuffer() const
Gather current RooFit inputs into a contiguous feature buffer.
std::string _funcName
!
Definition RooONNXFunc.h:55
Buffer base class used for serializing objects.
Definition TBuffer.h:43
virtual Int_t WriteClassBuffer(const TClass *cl, void *pointer)=0
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2384
virtual int Load(const char *module, const char *entry="", Bool_t system=kFALSE)
Load a shared library.
Definition TSystem.cxx:1872
RooFit::Detail::AnyWithVoidPtr _d_session
RooFit::Detail::AnyWithVoidPtr _session