Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooFuncWrapper.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#include <RooFuncWrapper.h>
14
15#include <RooAbsData.h>
16#include <RooGlobalFunc.h>
17#include <RooMsgService.h>
18#include <RooRealVar.h>
19#include <RooHelpers.h>
22
23#include <TROOT.h>
24#include <TSystem.h>
25
26RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, std::string const &funcBody,
27 RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf,
28 bool createGradient)
29 : RooAbsReal{name, title}, _params{"!params", "List of parameters", this}, _hasGradient{createGradient}
30{
31 // Declare the function and create its derivative.
32 declareAndDiffFunction(name, funcBody, createGradient);
33
34 // Load the parameters and observables.
35 loadParamsAndData(name, nullptr, paramSet, data, simPdf);
36}
37
38RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal const &obj, RooArgSet const &normSet,
39 const RooAbsData *data, RooSimultaneous const *simPdf, bool createGradient)
40 : RooAbsReal{name, title}, _params{"!params", "List of parameters", this}, _hasGradient{createGradient}
41{
42 std::string func;
43
44 // Compile the computation graph for the norm set, such that we also get the
45 // integrals explicitly in the graph.
46 std::unique_ptr<RooAbsReal> pdf{RooFit::Detail::compileForNormSet(obj, normSet)};
47 // Get the parameters.
48 RooArgSet paramSet;
49 obj.getParameters(data ? data->get() : nullptr, paramSet);
50 RooArgSet floatingParamSet;
51 for (RooAbsArg *param : paramSet) {
52 if (!param->isConstant()) {
53 floatingParamSet.add(*param);
54 }
55 }
56
57 // Load the parameters and observables.
58 loadParamsAndData(name, pdf.get(), floatingParamSet, data, simPdf);
59
60 func = buildCode(*pdf);
61
62 // Declare the function and create its derivative.
63 declareAndDiffFunction(name, func, createGradient);
64}
65
67 : RooAbsReal(other, name),
68 _params("!params", this, other._params),
69 _func(other._func),
70 _grad(other._grad),
71 _hasGradient(other._hasGradient),
72 _gradientVarBuffer(other._gradientVarBuffer),
73 _observables(other._observables)
74{
   // Copy constructor: reuses the already JIT-compiled function and gradient
   // pointers (nothing is re-declared in the interpreter) and copies the cached
   // parameter values and the flattened observables.
   // NOTE(review): _obsInfos and _nodeOutputSizes are not copied; they appear to be
   // needed only by buildCode() during construction — confirm.
75}
76
/// Register the observables from `data` and the parameters from `paramSet`.
///
/// All observable values are concatenated into the flat `_observables` array, and
/// `_obsInfos` records the start index and number of entries for each data key.
/// Parameters that are not data observables are added to `_params`, and
/// `_gradientVarBuffer` is resized to match. If `head` is given, `_nodeOutputSizes`
/// is filled for its graph. For that, data observables contribute their span length
/// and every other input contributes 0.
///
/// @param funcName Name used in error messages only.
/// @param head     Top node of the compiled graph, or nullptr to skip the output-size bookkeeping.
/// @param paramSet Candidate parameters; each must be a RooAbsReal.
/// @param data     Dataset providing observable values; may be nullptr (no observables).
/// @param simPdf   Forwarded to getDataSpans().
/// @throws std::runtime_error if a parameter is not a RooAbsReal.
77void RooFuncWrapper::loadParamsAndData(std::string funcName, RooAbsArg const *head, RooArgSet const &paramSet,
78 const RooAbsData *data, RooSimultaneous const *simPdf)
79{
80 // Extract observables
81 std::stack<std::vector<double>> vectorBuffers; // for data loading
82 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
83
84 if (data) {
   // No range name, skip zero-weight entries, do not take global observables from data.
   // The spans may point into vectorBuffers; their values are copied into
   // _observables below, while the buffers are still alive.
85 spans = RooFit::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
86 }
87
88 std::size_t idx = 0;
89 for (auto const &item : spans) {
90 std::size_t n = item.second.size();
   // Remember where this observable's values start in the flat array and how many there are.
91 _obsInfos.emplace(item.first, ObsInfo{idx, n});
   // NOTE(review): reserving exactly size()+n for every item defeats the vector's
   // geometric growth and can reallocate once per observable.
92 _observables.reserve(_observables.size() + n);
93 for (std::size_t i = 0; i < n; ++i) {
94 _observables.push_back(item.second[i]);
95 }
96 idx += n;
97 }
98
99 // Extract parameters
100 for (auto *param : paramSet) {
101 if (!dynamic_cast<RooAbsReal *>(param)) {
102 std::stringstream errorMsg;
103 errorMsg << "In creation of function " << funcName
104 << " wrapper: input param expected to be of type RooAbsReal.";
105 coutE(InputArguments) << errorMsg.str() << std::endl;
106 throw std::runtime_error(errorMsg.str().c_str());
107 }
   // Arguments that are provided by the dataset are observables, not parameters.
108 if (spans.find(param) == spans.end()) {
109 _params.add(*param);
110 }
111 }
   // One slot per parameter, in the same order as _params.
112 _gradientVarBuffer.resize(_params.size());
113
114 if (head) {
115 _nodeOutputSizes =
117 auto found = spans.find(key);
118 return found != spans.end() ? found->second.size() : 0;
119 });
120 }
121}
122
123void RooFuncWrapper::declareAndDiffFunction(std::string funcName, std::string const &funcBody, bool createGradient)
124{
125 std::string gradName = funcName + "_grad_0";
126 std::string requestName = funcName + "_req";
127 std::string wrapperName = funcName + "_derivativeWrapper";
128
129 gInterpreter->Declare("#pragma cling optimize(2)");
130
131 // Declare the function
132 std::stringstream bodyWithSigStrm;
133 bodyWithSigStrm << "double " << funcName << "(double* params, double const* obs) {\n" << funcBody << "\n}";
134 bool comp = gInterpreter->Declare(bodyWithSigStrm.str().c_str());
135 if (!comp) {
136 std::stringstream errorMsg;
137 errorMsg << "Function " << funcName << " could not be compiled. See above for details.";
138 coutE(InputArguments) << errorMsg.str() << std::endl;
139 throw std::runtime_error(errorMsg.str().c_str());
140 }
141 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((funcName + ";").c_str()));
142
143 if (!createGradient)
144 return;
145
146 // Calculate gradient
147 gInterpreter->ProcessLine("#include <Math/CladDerivator.h>");
148 // disable clang-format for making the following code unreadable.
149 // clang-format off
150 std::stringstream requestFuncStrm;
151 requestFuncStrm << "#pragma clad ON\n"
152 "void " << requestName << "() {\n"
153 " clad::gradient(" << funcName << ", \"params\");\n"
154 "}\n"
155 "#pragma clad OFF";
156 // clang-format on
157 comp = gInterpreter->Declare(requestFuncStrm.str().c_str());
158 if (!comp) {
159 std::stringstream errorMsg;
160 errorMsg << "Function " << funcName << " could not be differentiated. See above for details.";
161 coutE(InputArguments) << errorMsg.str() << std::endl;
162 throw std::runtime_error(errorMsg.str().c_str());
163 }
164
165 // Build a wrapper over the derivative to hide clad specific types such as 'array_ref'.
166 // disable clang-format for making the following code unreadable.
167 // clang-format off
168 std::stringstream dWrapperStrm;
169 dWrapperStrm << "void " << wrapperName << "(double* params, double const* obs, double* out) {\n"
170 " clad::array_ref<double> cladOut(out, " << _params.size() << ");\n"
171 " " << gradName << "(params, obs, cladOut);\n"
172 "}";
173 // clang-format on
174 gInterpreter->Declare(dWrapperStrm.str().c_str());
175 _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine((wrapperName + ";").c_str()));
176}
177
/// Evaluate the gradient at the parameter values cached in `_gradientVarBuffer`.
/// Writes one entry per parameter (in `_params` order) into `out`, which is zeroed first.
/// NOTE(review): the cached values are presumably refreshed right before the call below — confirm.
178void RooFuncWrapper::gradient(double *out) const
179{
   // Zero the output first; the wrapped gradient is given the full array of _params.size() entries.
181 std::fill(out, out + _params.size(), 0.0);
182
183 _grad(_gradientVarBuffer.data(), _observables.data(), out);
184}
185
187{
   // Copy the current value of every parameter into _gradientVarBuffer, in _params order.
   // The static_cast is safe: loadParamsAndData() only admits RooAbsReal parameters.
188 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(),
189 [](RooAbsArg *obj) { return static_cast<RooAbsReal *>(obj)->getVal(); });
190}
191
193{
195
   // Call the JIT-compiled function with the cached parameter values and the flattened observables.
   // NOTE(review): the cached values are presumably refreshed just before this call — confirm.
196 return _func(_gradientVarBuffer.data(), _observables.data());
197}
198
199void RooFuncWrapper::gradient(const double *x, double *g) const
200{
201 std::fill(g, g + _params.size(), 0.0);
202
203 _grad(const_cast<double *>(x), _observables.data(), g);
204}
205
/// Generate the C++ body of the wrapped function for the graph ending in `head`.
///
/// Each parameter is mapped to `params[i]`, where i is its position in `_params`.
/// Each scalar observable is mapped to `obs[start]`. Each vector observable is
/// mapped to `obs`, and its start index is registered with the context. The returned
/// code ends with the head node's result as the return expression.
/// Must run after loadParamsAndData(), which fills `_params` and `_obsInfos`.
/// NOTE(review): `ctx` is presumably a CodeSquashContext built from `_nodeOutputSizes` — confirm.
206std::string RooFuncWrapper::buildCode(RooAbsReal const &head)
207{
209
210 // First map the result of each parameter in the compute graph to params[<position>], following the order of _params.
211 int idx = 0;
212 for (RooAbsArg *param : _params) {
213 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
214 idx++;
215 }
216
217 for (auto const &item : _obsInfos) {
218 const char *name = item.first->GetName();
219 // If the observable is scalar, map its name directly to obs[start_idx]. Otherwise,
220 // map it to "obs" and record start_idx, so that it can later be accessed as obs[start_idx + curr_idx],
221 // where curr_idx is the index of the loop generated by the consuming (parent) node.
222 if (item.second.size == 1) {
223 ctx.addResult(name, "obs[" + std::to_string(item.second.idx) + "]");
224 } else {
225 ctx.addResult(name, "obs");
226 ctx.addVecObs(name, item.second.idx);
227 }
228 }
229
230 return ctx.assembleCode(ctx.getResult(head));
231}
232
233/// @brief Prints the squashed code body to console.
235{
   // fName is this object's name. Both constructors also pass it as the name of the
   // interpreter function, so this evaluates the generated function's name.
   // NOTE(review): the printout relies on how cling handles a bare function name — confirm.
236 gInterpreter->ProcessLine(fName);
237}
238
239/// @brief Prints the derivative code body to console.
241{
   // "<name>_grad_0" is the gradient name used by declareAndDiffFunction(). It is only
   // generated when the wrapper was created with createGradient = true.
242 gInterpreter->ProcessLine(fName + "_grad_0");
243}
#define g(i)
Definition RSha256.hxx:105
#define coutE(a)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition RooAbsArg.h:80
RooFit::OwningPtr< RooArgSet > getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
const_iterator end() const
Storage_t::size_type size() const
const_iterator begin() const
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition RooAbsReal.h:59
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
A class to maintain the context for squashing of RooFit models into code.
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
A wrapper class to store a C++ function of type 'double (*)(double*, double const*)'.
std::vector< double > _gradientVarBuffer
void dumpGradient()
Prints the derivative code body to console.
RooFuncWrapper(const char *name, const char *title, std::string const &funcBody, RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf, bool createGradient)
std::map< RooFit::Detail::DataKey, std::size_t > _nodeOutputSizes
std::map< RooFit::Detail::DataKey, ObsInfo > _obsInfos
std::string buildCode(RooAbsReal const &head)
void gradient(double *out) const override
double(*)(double *, double const *) Func
std::vector< double > _observables
void(*)(double *, double const *, double *) Grad
void declareAndDiffFunction(std::string funcName, std::string const &funcBody, bool createGradient)
void loadParamsAndData(std::string funcName, RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
RooListProxy _params
void updateGradientVarBuffer() const
void dumpCode()
Prints the squashed code body to console.
RooSimultaneous facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
TString fName
Definition TNamed.h:32
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
std::map< RooFit::Detail::DataKey, std::span< const double > > getDataSpans(RooAbsData const &data, std::string const &rangeName, RooSimultaneous const *simPdf, bool skipZeroWeights, bool takeGlobalObservablesFromData, std::stack< std::vector< double > > &buffers)
Extract all content from a RooFit datasets as a map of spans.
std::map< RooFit::Detail::DataKey, std::size_t > determineOutputSizes(RooAbsArg const &topNode, std::function< std::size_t(RooFit::Detail::DataKey)> const &inputSizeFunc)
Figure out the output size for each node in the computation graph that leads up to the top node,...
std::unique_ptr< T > compileForNormSet(T const &arg, RooArgSet const &normSet)