Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooFuncWrapper.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#include <RooFuncWrapper.h>
14
15#include <RooAbsData.h>
16#include <RooGlobalFunc.h>
17#include <RooMsgService.h>
18#include <RooRealVar.h>
19#include <RooHelpers.h>
22
23#include <TROOT.h>
24#include <TSystem.h>
25
26RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal const &obj, RooArgSet const &normSet,
27 const RooAbsData *data, RooSimultaneous const *simPdf, bool createGradient)
28 : RooAbsReal{name, title}, _params{"!params", "List of parameters", this}, _hasGradient{createGradient}
29{
30 std::string func;
31
32 // Compile the computation graph for the norm set, such that we also get the
33 // integrals explicitly in the graph.
34 std::unique_ptr<RooAbsReal> pdf{RooFit::Detail::compileForNormSet(obj, normSet)};
35 // Get the parameters.
36 RooArgSet paramSet;
37 obj.getParameters(data ? data->get() : nullptr, paramSet);
38 RooArgSet floatingParamSet;
39 for (RooAbsArg *param : paramSet) {
40 if (!param->isConstant()) {
41 floatingParamSet.add(*param);
42 }
43 }
44
45 // Load the parameters and observables.
46 loadParamsAndData(pdf.get(), floatingParamSet, data, simPdf);
47
48 func = buildCode(*pdf);
49
50 // Declare the function and create its derivative.
51 declareAndDiffFunction(func, createGradient);
52}
53
// Copy constructor. Its signature (original line 54) is missing from this extract;
// presumably `RooFuncWrapper::RooFuncWrapper(const RooFuncWrapper &other, const char *name)` — confirm.
// The copy shares the JIT-compiled function and gradient pointers with `other`, and copies
// the parameter list, the parameter-value buffer and the flattened observable values.
// NOTE(review): `_funcName`, `_obsInfos` and `_nodeOutputSizes` are not copied here, so
// dumpCode()/dumpGradient() on a copy would presumably see an empty `_funcName` — confirm
// whether this is intended.
55 : RooAbsReal(other, name),
56 _params("!params", this, other._params),
57 _func(other._func),
58 _grad(other._grad),
59 _hasGradient(other._hasGradient),
60 _gradientVarBuffer(other._gradientVarBuffer),
61 _observables(other._observables)
62{
63}
64
65void RooFuncWrapper::loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data,
66 RooSimultaneous const *simPdf)
67{
68 // Extract observables
69 std::stack<std::vector<double>> vectorBuffers; // for data loading
70 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
71
72 if (data) {
73 spans = RooFit::Detail::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
74 }
75
76 std::size_t idx = 0;
77 for (auto const &item : spans) {
78 std::size_t n = item.second.size();
79 _obsInfos.emplace(item.first, ObsInfo{idx, n});
80 _observables.reserve(_observables.size() + n);
81 for (std::size_t i = 0; i < n; ++i) {
82 _observables.push_back(item.second[i]);
83 }
84 idx += n;
85 }
86
87 // Extract parameters
88 for (auto *param : paramSet) {
89 if (!dynamic_cast<RooAbsReal *>(param)) {
90 std::stringstream errorMsg;
91 errorMsg << "In creation of function " << GetName()
92 << " wrapper: input param expected to be of type RooAbsReal.";
93 coutE(InputArguments) << errorMsg.str() << std::endl;
94 throw std::runtime_error(errorMsg.str().c_str());
95 }
96 if (spans.find(param) == spans.end()) {
97 _params.add(*param);
98 }
99 }
100 _gradientVarBuffer.resize(_params.size());
101
102 if (head) {
103 _nodeOutputSizes =
104 RooFit::Detail::BatchModeDataHelpers::determineOutputSizes(*head, [&spans](RooFit::Detail::DataKey key) {
105 auto found = spans.find(key);
106 return found != spans.end() ? found->second.size() : 0;
107 });
108 }
109}
110
111void RooFuncWrapper::declareAndDiffFunction(std::string const &funcBody, bool createGradient)
112{
113 static int iFuncWrapper = 0;
114 _funcName = "roo_func_wrapper_" + std::to_string(iFuncWrapper++);
115
116 std::string gradName = _funcName + "_grad_0";
117 std::string requestName = _funcName + "_req";
118 std::string wrapperName = _funcName + "_derivativeWrapper";
119
120 gInterpreter->Declare("#pragma cling optimize(2)");
121
122 // Declare the function
123 std::stringstream bodyWithSigStrm;
124 bodyWithSigStrm << "double " << _funcName << "(double* params, double const* obs) {\n" << funcBody << "\n}";
125 bool comp = gInterpreter->Declare(bodyWithSigStrm.str().c_str());
126 if (!comp) {
127 std::stringstream errorMsg;
128 errorMsg << "Function " << _funcName << " could not be compiled. See above for details.";
129 coutE(InputArguments) << errorMsg.str() << std::endl;
130 throw std::runtime_error(errorMsg.str().c_str());
131 }
132 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
133
134 if (!createGradient)
135 return;
136
137 // Calculate gradient
138 gInterpreter->ProcessLine("#include <Math/CladDerivator.h>");
139 // disable clang-format for making the following code unreadable.
140 // clang-format off
141 std::stringstream requestFuncStrm;
142 requestFuncStrm << "#pragma clad ON\n"
143 "void " << requestName << "() {\n"
144 " clad::gradient(" << _funcName << ", \"params\");\n"
145 "}\n"
146 "#pragma clad OFF";
147 // clang-format on
148 comp = gInterpreter->Declare(requestFuncStrm.str().c_str());
149 if (!comp) {
150 std::stringstream errorMsg;
151 errorMsg << "Function " << GetName() << " could not be differentiated. See above for details.";
152 coutE(InputArguments) << errorMsg.str() << std::endl;
153 throw std::runtime_error(errorMsg.str().c_str());
154 }
155
156 // Build a wrapper over the derivative to hide clad specific types such as 'array_ref'.
157 // disable clang-format for making the following code unreadable.
158 // clang-format off
159 std::stringstream dWrapperStrm;
160 dWrapperStrm << "void " << wrapperName << "(double* params, double const* obs, double* out) {\n"
161 " clad::array_ref<double> cladOut(out, " << _params.size() << ");\n"
162 " " << gradName << "(params, obs, cladOut);\n"
163 "}";
164 // clang-format on
165 gInterpreter->Declare(dWrapperStrm.str().c_str());
166 _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine((wrapperName + ";").c_str()));
167}
168
/// Compute the gradient of the wrapped function, writing one entry per parameter
/// (in `_params` order) into `out`.
/// @param out Output buffer; must hold at least `_params.size()` values.
/// NOTE(review): original line 171 is missing from this extract. It presumably refreshes
/// `_gradientVarBuffer` from the current parameter values (e.g. via
/// updateGradientVarBuffer()) — confirm.
/// NOTE(review): `_grad` is only set by declareAndDiffFunction when the wrapper was created
/// with createGradient = true; calling this otherwise is not guarded here.
169void RooFuncWrapper::gradient(double *out) const
170{
// Zero the output first; presumably the generated derivative accumulates into it.
172 std::fill(out, out + _params.size(), 0.0);
173
174 _grad(_gradientVarBuffer.data(), _observables.data(), out);
175}
176
// NOTE(review): the signature (original line 177) is missing from this extract. This body
// presumably belongs to `void RooFuncWrapper::updateGradientVarBuffer() const` — confirm.
// Copies the current value of every parameter into `_gradientVarBuffer`, in `_params` order.
// The static_cast relies on loadParamsAndData only adding RooAbsReal parameters to `_params`
// (it throws for any other type).
178{
179 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(),
180 [](RooAbsArg *obj) { return static_cast<RooAbsReal *>(obj)->getVal(); });
181}
182
// NOTE(review): the signature (original line 183) and original line 185 are missing from
// this extract. This body presumably belongs to `double RooFuncWrapper::evaluate() const`,
// and line 185 presumably refreshes `_gradientVarBuffer` (e.g. via updateGradientVarBuffer())
// — confirm.
// Evaluates the JIT-compiled function on the parameter buffer and the flattened observables.
184{
186
187 return _func(_gradientVarBuffer.data(), _observables.data());
188}
189
190void RooFuncWrapper::gradient(const double *x, double *g) const
191{
192 std::fill(g, g + _params.size(), 0.0);
193
194 _grad(const_cast<double *>(x), _observables.data(), g);
195}
196
// Generate the C++ function body for the computation graph rooted at `head`.
// Each parameter is mapped to `params[i]`, where i is its position in `_params`.
// Observables are mapped into the flat `obs` array using the offsets in `_obsInfos`.
// Returns the assembled code, whose return expression is the result of `head`.
// NOTE(review): original line 199 is missing from this extract. It presumably declares the
// local `ctx` (a code-squashing context providing addResult/addVecObs/getResult/assembleCode)
// — confirm.
197std::string RooFuncWrapper::buildCode(RooAbsReal const &head)
198{
200
201 // First map the result of each parameter in the compute graph to params[<position>].
202 int idx = 0;
203 for (RooAbsArg *param : _params) {
204 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
205 idx++;
206 }
207
208 for (auto const &item : _obsInfos) {
209 const char *name = item.first->GetName();
210 // A scalar observable maps directly to obs[start_idx]. Otherwise, map it to obs, store
211 // start_idx, and later resolve the name to obs[start_idx + curr_idx],
212 // where curr_idx is the index of the loop generated by the parent node.
213 if (item.second.size == 1) {
214 ctx.addResult(name, "obs[" + std::to_string(item.second.idx) + "]");
215 } else {
216 ctx.addResult(name, "obs");
217 ctx.addVecObs(name, item.second.idx);
218 }
219 }
220
221 return ctx.assembleCode(ctx.getResult(head));
222}
223
224/// @brief Prints the squashed code body to console.
// NOTE(review): the signature (original line 225) is missing from this extract; presumably
// `void RooFuncWrapper::dumpCode()` — confirm.
// Processes the bare name of the generated function in the interpreter, which presumably
// makes cling print it.
226{
227 gInterpreter->ProcessLine(_funcName.c_str());
228}
229
230/// @brief Prints the derivative code body to console.
// NOTE(review): the signature (original line 231) is missing from this extract; presumably
// `void RooFuncWrapper::dumpGradient()` — confirm.
// Uses the same "<funcName>_grad_0" name that declareAndDiffFunction passes to the
// derivative wrapper. That function only exists if the wrapper was created with
// createGradient = true.
232{
233 gInterpreter->ProcessLine((_funcName + "_grad_0").c_str());
234}
#define g(i)
Definition RSha256.hxx:105
RooAbsReal * _func
Pointer to original input function.
#define coutE(a)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
RooFit::OwningPtr< RooArgSet > getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
Storage_t::size_type size() const
TIterator begin()
TIterator end() and range-based for loops.")
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
A class to maintain the context for squashing of RooFit models into code.
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
A wrapper class to store a C++ function of type 'double (*)(double*, double const*)'.
std::vector< double > _gradientVarBuffer
void dumpGradient()
Prints the derivative code body to console.
void declareAndDiffFunction(std::string const &funcBody, bool createGradient)
std::map< RooFit::Detail::DataKey, std::size_t > _nodeOutputSizes
std::map< RooFit::Detail::DataKey, ObsInfo > _obsInfos
std::string buildCode(RooAbsReal const &head)
void gradient(double *out) const override
double(*)(double *, double const *) Func
void loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
std::vector< double > _observables
void(*)(double *, double const *, double *) Grad
std::string _funcName
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
RooListProxy _params
void updateGradientVarBuffer() const
void dumpCode()
Prints the squashed code body to console.
RooFuncWrapper(const char *name, const char *title, RooAbsReal const &obj, RooArgSet const &normSet, const RooAbsData *data, RooSimultaneous const *simPdf, bool createGradient)
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
std::unique_ptr< T > compileForNormSet(T const &arg, RooArgSet const &normSet)