Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooONNXFunc.h
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 04/2026
5 *
6 * Copyright (c) 2026, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#ifndef RooFit_RooONNXFunc_h
14#define RooFit_RooONNXFunc_h
15
#include <RooAbsReal.h>
#include <RooListProxy.h>

#include <any>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
20
21class RooONNXFunc : public RooAbsReal {
22public:
23 RooONNXFunc() = default;
24
25 RooONNXFunc(const char *name, const char *title, const std::vector<RooArgList> &inputTensors,
26 const std::string &onnxFile, const std::vector<std::string> &inputNames = {},
27 const std::vector<std::vector<int>> &inputShapes = {});
28
29 RooONNXFunc(const RooONNXFunc &other, const char *newName = nullptr);
30
31 TObject *clone(const char *newName) const override { return new RooONNXFunc(*this, newName); }
32
33 std::size_t nInputTensors() const { return _inputTensors.size(); }
34 RooArgList const &inputTensorList(int iTensor) const { return *(_inputTensors[iTensor]); }
35
36 std::string funcName() const { return _funcName; }
37 std::string outerWrapperName() const { return "TMVA_SOFIE_" + funcName() + "::roo_outer_wrapper"; }
38
39protected:
40 double evaluate() const override;
41
42private:
43 /// Build transient runtime backend on first use.
44 void initialize();
45
46 /// Gather current RooFit inputs into a contiguous feature buffer.
47 void fillInputBuffer() const;
48
49 struct RuntimeCache;
50
51 std::vector<std::unique_ptr<RooListProxy>> _inputTensors; ///< Inputs mapping to flattened input tensors.
52 std::vector<std::uint8_t> _onnxBytes; ///< Persisted ONNX model bytes.
53 std::shared_ptr<RuntimeCache> _runtime; ///<! Transient runtime information.
54 mutable std::vector<float> _inputBuffer; ///<!
55 std::string _funcName; ///<!
56
58};
59
60namespace RooFit::Detail {
61
/// Holds an object inside a `std::any` together with an untyped pointer to that object.
/// NOTE(review): `ptr` points into this instance's own `any`. If the holder is copied, or moved
/// while `std::any` stores the object inline, the new holder's `ptr` still refers to the source
/// holder's object rather than its own. Confirm that instances are never copied or moved after emplace.
63 std::any any; ///< Owns the held object.
64 void *ptr = nullptr; ///< Untyped pointer to the object held in `any`; null by default, set by emplace<T>().
65
/// Replace the held object with a value-initialized `T` and point `ptr` at it.
/// `T` must be copy-constructible, as `std::any` requires.
66 template <class T>
67 void emplace()
68 {
69 any = std::make_any<T>();
70 ptr = std::any_cast<T>(&any); // pointer overload of any_cast: address of the object stored in `any`
71 }
72
/// Run-time overload that presumably constructs an object of the type named `typeName`.
/// Its definition is not visible here; TODO confirm that it also updates `ptr`.
73 void emplace(std::string const &typeName);
74};
75
/// Run inference through a session that is only available as an untyped pointer.
///
/// Restores the concrete type of `session` and forwards to the unqualified call
/// `doInfer(Session_t &, inputs..., out)`. Because the call is unqualified and
/// dependent, it is resolved when the template is instantiated, e.g. by
/// argument-dependent lookup on `Session_t`.
///
/// @tparam Session_t Type of the object that `session` points to.
/// @tparam Inputs    Element types of the input buffers.
/// @param session    Pointer to a live `Session_t`; must not be null, since it is dereferenced.
/// @param out        Output buffer, passed as the last argument of `doInfer`.
/// @param inputs     Input buffers, passed to `doInfer` in order, before `out`.
template <class Session_t, class... Inputs>
void doInferWithSessionVoidPtr(void *session, float *out, Inputs const *...inputs)
{
   // static_cast is the named cast for converting void* back to the original
   // object pointer type; reinterpret_cast is unnecessary here and obscures intent.
   doInfer(*static_cast<Session_t *>(session), inputs..., out);
}
81
82} // namespace RooFit::Detail
83
84#endif
#define ClassDefOverride(name, id)
Definition Rtypes.h:348
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
char name[80]
Definition TGX11.cxx:148
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:63
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooONNXFunc wraps an ONNX model as a RooAbsReal, allowing it to be used as a building block in likeli...
Definition RooONNXFunc.h:21
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
std::vector< std::uint8_t > _onnxBytes
Persisted ONNX model bytes.
Definition RooONNXFunc.h:52
std::string outerWrapperName() const
Definition RooONNXFunc.h:37
std::shared_ptr< RuntimeCache > _runtime
Transient runtime information (not persisted).
Definition RooONNXFunc.h:53
std::vector< float > _inputBuffer
Transient member (not persisted).
Definition RooONNXFunc.h:54
TObject * clone(const char *newName) const override
Definition RooONNXFunc.h:31
std::size_t nInputTensors() const
Definition RooONNXFunc.h:33
std::vector< std::unique_ptr< RooListProxy > > _inputTensors
Inputs mapping to flattened input tensors.
Definition RooONNXFunc.h:51
RooArgList const & inputTensorList(int iTensor) const
Definition RooONNXFunc.h:34
RooONNXFunc()=default
void initialize()
Build transient runtime backend on first use.
std::string funcName() const
Definition RooONNXFunc.h:36
void fillInputBuffer() const
Gather current RooFit inputs into a contiguous feature buffer.
std::string _funcName
Transient member (not persisted).
Definition RooONNXFunc.h:55
Mother of all ROOT objects.
Definition TObject.h:42
void doInferWithSessionVoidPtr(void *session, float *out, Inputs const *...inputs)
Definition RooONNXFunc.h:77