Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooONNXFunction.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 04/2026
5 *
6 * Copyright (c) 2026, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#include <RooONNXFunction.h>
14
15#include <TInterpreter.h>
16#include <TSystem.h>
17
18#include <fstream>
19#include <mutex>
20
21/**
22 \file RooONNXFunction.cxx
23 \class RooONNXFunction
24 \ingroup Roofit
25
26 RooONNXFunction wraps an ONNX model as a RooAbsReal, allowing it to be used as
27 a building block in likelihoods, fits, and statistical analyses without
28 additional boilerplate code. The class supports models with **one or more
29 statically-shaped input tensors** and a **single scalar output**. The class
30 was designed to share workspaces with neural functions for combined fits in
31 RooFit-based frameworks written in C++. Therefore, the RooONNXFunction doesn't
32 depend on any Python packages and fully supports ROOT I/O.
33
34 The ONNX model is evaluated through compiled C++ code generated at runtime
35 using **TMVA SOFIE**. Automatic differentiation is supported via **Clad**,
36 allowing RooFit to access analytical gradients for fast minimization with
37 Minuit 2.
38
39 The ONNX model is stored internally as a byte payload and serialized together
40 with the RooONNXFunction object using ROOT I/O. Upon reading from a file or
41 workspace, the runtime backend is rebuilt automatically.
42
43 ### Input handling
44
45 The model inputs are provided as a list of tensors, where each tensor is
46 represented by a RooArgList of RooAbsReal objects. The order of the inputs
47 defines the feature ordering passed to the ONNX model.
48 Optionally, users can validate that the ONNX model has the expected input names and tensor shapes (see the optional constructor arguments).
49
50 ### Example (C++)
51
52 \code
53 // Define input variables
54 RooRealVar x{"x", "x", 0.0};
55 RooRealVar y{"y", "y", 0.0};
56 RooRealVar z{"z", "z", 0.0};
57
58 // Construct ONNX function, building the std::vector<RooArgList> in-place
59 RooONNXFunction func{
60 "func", "func",
61 {{x, y}, {z}},
62 "model.onnx"
63 };
64
65 // Evaluate
66 double val = func.getVal();
67 std::cout << "Model output: " << val << std::endl;
68 \endcode
69
70 ### Example (Python)
71
72 \code{.py}
73 import ROOT
74
75 # Define variables
76 x = ROOT.RooRealVar("x", "x", 0.0)
77 y = ROOT.RooRealVar("y", "y", 0.0)
78 z = ROOT.RooRealVar("z", "z", 0.0)
79
80 # Create ONNX function
81 func = ROOT.RooONNXFunction(
82 "func", "func",
83 [[x, y], [z]],
84 "model.onnx"
85 )
86
87 # Evaluate
88 print("Model output:", func.getVal())
89 \endcode
90
91 */
92
93namespace {
94
/// Read the complete contents of a file into a byte vector.
///
/// \param filePath Path of the file to read. The file is opened in binary mode.
/// \return The file contents, one vector element per byte.
/// \throws std::runtime_error if the file cannot be opened, if its size cannot
///         be determined, if it is empty, or if reading the contents fails.
std::vector<std::uint8_t> fileToBytes(std::string const &filePath)
{
   const std::string quotedPath = "'" + filePath + "'";

   std::ifstream file(filePath, std::ios::binary);
   if (!file) {
      throw std::runtime_error("failed to open file " + quotedPath);
   }

   // Determine the file size by seeking to the end of the stream.
   file.seekg(0, std::ios::end);
   const std::streamsize size = file.tellg();

   // tellg() reports failure as -1. Keep this apart from the "empty file"
   // case so that the error message points at the actual problem.
   if (!file || size < 0) {
      throw std::runtime_error("failed to determine size of file " + quotedPath);
   }
   if (size == 0) {
      throw std::runtime_error("file " + quotedPath + " is empty");
   }

   file.seekg(0, std::ios::beg);

   std::vector<std::uint8_t> bytes(static_cast<std::size_t>(size));
   file.read(reinterpret_cast<char *>(bytes.data()), size);

   // read() sets failbit when fewer than `size` characters could be extracted.
   if (!file) {
      throw std::runtime_error("error while reading file " + quotedPath);
   }

   return bytes;
}
126
127template <typename Fn>
128Fn resolveLazy(std::string const &name, const char *code)
129{
130 static Fn fn = nullptr;
131 static std::once_flag flag;
132
133 std::call_once(flag, [&] {
134 // Try to declare the code
135 if (!gInterpreter->Declare(code)) {
136 throw std::runtime_error(std::string("ROOT JIT Declare failed for code defining ") + name);
137 }
138
139 // Try to resolve the symbol
140 void *symbol = reinterpret_cast<void *>(gInterpreter->ProcessLine((name + ";").c_str()));
141
142 if (!symbol) {
143 throw std::runtime_error(std::string("ROOT JIT failed to resolve symbol: ") + name);
144 }
145
146 fn = reinterpret_cast<Fn>(symbol);
147
148 if (!fn) {
149 throw std::runtime_error(std::string("ROOT JIT produced null function pointer for: ") + name);
150 }
151 });
152
153 return fn;
154}
155
156template <typename T>
157std::string toPtrString(T *ptr, std::string const &castType)
158{
159 return TString::Format("reinterpret_cast<%s>(0x%zx)", (castType + "*").c_str(), reinterpret_cast<std::size_t>(ptr))
160 .Data();
161}
162
163} // namespace
164
165void RooFit::Detail::AnyWithVoidPtr::emplace(std::string const &typeName)
166{
167 auto anyPtrSession = toPtrString(this, "RooFit::Detail::AnyWithVoidPtr");
168 gInterpreter->ProcessLine((anyPtrSession + "->emplace<" + typeName + ">();").c_str());
169}
170
178
179/**
180 Construct a RooONNXFunction from an ONNX model file.
181
182 \param name Name of the RooFit object
183 \param title Title of the RooFit object
184 \param inputTensors Vector of RooArgList, each representing one input tensor.
185 The variables in each RooArgList match to each flattened input tensor.
186 \param onnxFile Path to the ONNX model file. The file is read and stored
187 internally as a byte payload for persistence with RooWorkspace.
188 \param inputNames Optional list of ONNX input node names. If provided, these
189 are used to validate that the ONNX model has the structure expected by
190 your RooFit code.
191 \param inputShapes Optional list of tensor shapes corresponding to each input
192 tensor. If provided, these are used to validate that the ONNX models
193 input tensors have the shape that you expect. If omitted, only the
194 total size of each tensor is checked.
195 */
196RooONNXFunction::RooONNXFunction(const char *name, const char *title, const std::vector<RooArgList> &inputTensors,
197 const std::string &onnxFile, const std::vector<std::string> & /*inputNames*/,
198 const std::vector<std::vector<int>> & /*inputShapes*/)
199 : RooAbsReal{name, title}, _onnxBytes{fileToBytes(onnxFile)}
200{
201 for (std::size_t i = 0; i < inputTensors.size(); ++i) {
202 std::string istr = std::to_string(i);
203 _inputTensors.emplace_back(
204 std::make_unique<RooListProxy>(("!inputs_" + istr).c_str(), ("Input tensor " + istr).c_str(), this));
205 _inputTensors.back()->addTyped<RooAbsReal>(inputTensors[i]);
206 }
207}
208
210 : RooAbsReal{other, newName}, _onnxBytes{other._onnxBytes}, _runtime{other._runtime}
211{
212 for (std::size_t i = 0; i < other._inputTensors.size(); ++i) {
213 _inputTensors.emplace_back(std::make_unique<RooListProxy>("!inputs", this, *other._inputTensors[i]));
214 }
215}
216
218{
219 _inputBuffer.clear();
220 _inputBuffer.reserve(_inputTensors.size());
221
222 for (auto const &tensorList : _inputTensors) {
224 _inputBuffer.push_back(static_cast<float>(real->getVal(tensorList->nset())));
225 }
226 }
227}
228
230{
231 if (_runtime) {
232 return;
233 }
234
235 _runtime = std::make_unique<RuntimeCache>();
236
237 // We are jitting the SOFIE invocation lazily at runtime, to avoid the
238 // link-time dependency to the SOFIE parser library.
239 if (gSystem->Load("libROOTTMVASofieParser") < 0) {
240 throw std::runtime_error("RooONNXFunction: cannot load ONNX file since SOFIE ONNX parser is missing."
241 " Please build ROOT with tmva-sofie=ON.");
242 }
243 using OnnxToCpp = std::string (*)(std::uint8_t const *, std::size_t, const char *);
244 auto onnxToCppWithSofie = resolveLazy<OnnxToCpp>("_RooONNXFunction_onnxToCppWithSofie",
245 R"(
246#include "TMVA/RModelParser_ONNX.hxx"
247
248std::string _RooONNXFunction_onnxToCppWithSofie(std::uint8_t const *onnxBytes, std::size_t onnxBytesSize, const char *outputName)
249{
250 namespace SOFIE = TMVA::Experimental::SOFIE;
251
252 std::string buffer{reinterpret_cast<const char *>(onnxBytes), onnxBytesSize};
253 std::istringstream stream{buffer};
254
255 SOFIE::RModel rmodel = SOFIE::RModelParser_ONNX{}.Parse(stream, outputName);
256 rmodel.SetOptimizationLevel(SOFIE::OptimizationLevel::kBasic);
257 rmodel.Generate(SOFIE::Options::kNoWeightFile);
258
259 std::stringstream ss{};
260 rmodel.PrintGenerated(ss);
261 return ss.str();
262}
263)");
264
265 static int counter = 0;
266 _funcName = "roo_onnx_func_" + std::to_string(counter);
267 std::string namespaceName = "TMVA_SOFIE_" + _funcName + "";
268 counter++;
269
270 std::string modelCode = onnxToCppWithSofie(_onnxBytes.data(), _onnxBytes.size(), _funcName.c_str());
271 gInterpreter->Declare(modelCode.c_str());
272
273 // Declare string to the interpreter, where the %%NAMESPACE%% placeholder
274 // will first be replaced by the namespace for the emitted code.
275 auto declareWithNamespace = [&](std::string codeTemplate) {
276 const std::string placeholder = "%%NAMESPACE%%";
277 size_t pos = 0;
278
279 while ((pos = codeTemplate.find(placeholder, pos)) != std::string::npos) {
280 codeTemplate.replace(pos, placeholder.length(), namespaceName);
281 pos += namespaceName.length();
282 }
283
284 gInterpreter->Declare(codeTemplate.c_str());
285 };
286
288
289namespace %%NAMESPACE%% {
290
291float roo_inner_wrapper(Session const &session, float const *input)
292{
293 float out = 0.;
294 doInfer(session, input, &out);
295 return out;
296}
297
298float roo_wrapper(Session const &session, float const *input)
299{
300 return roo_inner_wrapper(session, input);
301}
302
303} // namespace %%NAMESPACE%%
304
305)");
306
307 std::string sessionName = "::TMVA_SOFIE_" + _funcName + "::Session";
308
309 _runtime->_session.emplace(sessionName);
310 auto ptrSession = toPtrString(_runtime->_session.ptr, sessionName);
311
312 std::stringstream ss2;
313 ss2 << "static_cast<void (*)(void *, float const *, float *)>(RooFit::Detail::doInferWithSessionVoidPtr<"
314 << sessionName << ">" << ");";
315 _runtime->_func = reinterpret_cast<RuntimeCache::Func>(gInterpreter->ProcessLine(ss2.str().c_str()));
316
317 // hardcode the gradient for now
318 _runtime->_d_session.emplace(sessionName);
319 auto ptrDSession = toPtrString(_runtime->_d_session.ptr, sessionName);
320
321 gInterpreter->Declare("#include <Math/CladDerivator.h>");
322
323 gInterpreter->ProcessLine(("clad::gradient(" + namespaceName + "::roo_wrapper, \"input\");").c_str());
324
326namespace %%NAMESPACE%% {
327
328double roo_outer_wrapper(double const *input) {
329 auto &session = *)" +
330 ptrSession + R"(;
331 float inputFlt[inputTensorDims[0].total_size()];
332 for (std::size_t i = 0; i < std::size(inputFlt); ++i) {
333 inputFlt[i] = input[i];
334 }
335 return roo_inner_wrapper(session, inputFlt);
336}
337
338} // namespace %%NAMESPACE%%
339
340namespace clad::custom_derivatives {
341
342namespace %%NAMESPACE%% {
343
344void roo_outer_wrapper_pullback(double const *input, double d_y, double *d_input) {
345
346 using namespace ::%%NAMESPACE%%;
347
348 float inputFlt[inputTensorDims[0].total_size()];
349 float d_inputFlt[::std::size(inputFlt)];
350 for (::std::size_t i = 0; i < ::std::size(inputFlt); ++i) {
351 inputFlt[i] = input[i];
352 d_inputFlt[i] = d_input[i];
353 }
354 auto *session = )" + ptrSession +
355 R"(;
356 auto *d_session = )" +
357 ptrDSession + R"(;
358 roo_inner_wrapper_pullback(*session, inputFlt, d_y, d_session, d_inputFlt);
359 for (::std::size_t i = 0; i < ::std::size(inputFlt); ++i) {
360 d_input[i] += d_inputFlt[i];
361 }
362}
363
364} // namespace %%NAMESPACE%%
365
366} // namespace clad::custom_derivatives
367
368)");
369}
370
371double RooONNXFunction::evaluate() const
372{
373 initialize();
375
376 float out = 0.f;
377 _runtime->_func(_runtime->_session.ptr, _inputBuffer.data(), &out);
378 return static_cast<double>(out);
379}
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t bytes
char name[80]
Definition TGX11.cxx:145
#define gInterpreter
R__EXTERN TSystem * gSystem
Definition TSystem.h:582
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:63
RooONNXFunction wraps an ONNX model as a RooAbsReal, allowing it to be used as a building block in li...
std::shared_ptr< RuntimeCache > _runtime
! Transient runtime information.
std::string _funcName
!
void initialize() const
Build transient runtime backend on first use.
std::vector< std::unique_ptr< RooListProxy > > _inputTensors
Inputs mapping to flattened input tensors.
RooONNXFunction()=default
std::vector< std::uint8_t > _onnxBytes
Persisted ONNX model bytes.
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
std::vector< float > _inputBuffer
!
void fillInputBuffer() const
Gather current RooFit inputs into a contiguous feature buffer.
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2384
virtual int Load(const char *module, const char *entry="", Bool_t system=kFALSE)
Load a shared library.
Definition TSystem.cxx:1868
RooFit::Detail::AnyWithVoidPtr _session
RooFit::Detail::AnyWithVoidPtr _d_session