Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooFuncWrapper.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Garima Singh, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
13#include <RooFuncWrapper.h>
14
15#include <RooAbsData.h>
18#include <RooFit/Evaluator.h>
19#include <RooGlobalFunc.h>
20#include <RooHelpers.h>
21#include <RooMsgService.h>
22#include <RooRealVar.h>
23#include <RooSimultaneous.h>
24#include "RooEvaluatorWrapper.h"
25
26#include <TROOT.h>
27#include <TSystem.h>
28
29#include <fstream>
30
31namespace RooFit {
32
33namespace Experimental {
34
35RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data,
36 RooSimultaneous const *simPdf, bool useEvaluator)
37 : RooAbsReal{name, title}, _params{"!params", "List of parameters", this}, _useEvaluator{useEvaluator}
38{
39 if (_useEvaluator) {
40 _absReal = std::make_unique<RooEvaluatorWrapper>(obj, const_cast<RooAbsData *>(data), false, "", simPdf, false);
41 }
42
43 std::string func;
44
45 // Get the parameters.
46 RooArgSet paramSet;
47 obj.getParameters(data ? data->get() : nullptr, paramSet);
48 RooArgSet floatingParamSet;
49 for (RooAbsArg *param : paramSet) {
50 if (!param->isConstant()) {
51 floatingParamSet.add(*param);
52 }
53 }
54
55 // Load the parameters and observables.
56 loadParamsAndData(&obj, floatingParamSet, data, simPdf);
57
58 func = buildCode(obj);
59
60 declareToInterpreter("#pragma cling optimize(2)");
61
62 // Declare the function and create its derivative.
64 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
65}
66
68 : RooAbsReal(other, name),
69 _params("!params", this, other._params),
70 _funcName(other._funcName),
71 _func(other._func),
72 _grad(other._grad),
73 _hasGradient(other._hasGradient),
74 _gradientVarBuffer(other._gradientVarBuffer),
75 _observables(other._observables)
76{
77}
78
79void RooFuncWrapper::loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data,
80 RooSimultaneous const *simPdf)
81{
82 // Extract observables
83 std::stack<std::vector<double>> vectorBuffers; // for data loading
84 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
85
86 if (data) {
87 spans = RooFit::Detail::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
88 }
89
90 std::size_t idx = 0;
91 for (auto const &item : spans) {
92 std::size_t n = item.second.size();
93 _obsInfos.emplace(item.first, ObsInfo{idx, n});
94 _observables.reserve(_observables.size() + n);
95 for (std::size_t i = 0; i < n; ++i) {
96 _observables.push_back(item.second[i]);
97 }
98 idx += n;
99 }
100
101 // Extract parameters
102 for (auto *param : paramSet) {
103 if (!dynamic_cast<RooAbsReal *>(param)) {
104 std::stringstream errorMsg;
105 errorMsg << "In creation of function " << GetName()
106 << " wrapper: input param expected to be of type RooAbsReal.";
107 coutE(InputArguments) << errorMsg.str() << std::endl;
108 throw std::runtime_error(errorMsg.str().c_str());
109 }
110 if (spans.find(param) == spans.end()) {
111 _params.add(*param);
112 }
113 }
114 _gradientVarBuffer.resize(_params.size());
115
116 if (head) {
117 _nodeOutputSizes = RooFit::Detail::BatchModeDataHelpers::determineOutputSizes(
118 *head, [&spans](RooFit::Detail::DataKey key) -> int {
119 auto found = spans.find(key);
120 return found != spans.end() ? found->second.size() : -1;
121 });
122 }
123}
124
125std::string RooFuncWrapper::declareFunction(std::string const &funcBody)
126{
127 static int iFuncWrapper = 0;
128 auto funcName = "roo_func_wrapper_" + std::to_string(iFuncWrapper++);
129
130 // Declare the function
131 std::stringstream bodyWithSigStrm;
132 bodyWithSigStrm << "double " << funcName << "(double* params, double const* obs, double const* xlArr) {\n"
133 << funcBody << "\n}";
134 if (!declareToInterpreter(bodyWithSigStrm.str())) {
135 std::stringstream errorMsg;
136 errorMsg << "Function " << funcName << " could not be compiled. See above for details.";
137 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
138 throw std::runtime_error(errorMsg.str().c_str());
139 }
140 return funcName;
141}
142
143void RooFuncWrapper::createGradient()
144{
145 std::string gradName = _funcName + "_grad_0";
146 std::string requestName = _funcName + "_req";
147
148 // Calculate gradient
149 declareToInterpreter("#include <Math/CladDerivator.h>\n");
150 // disable clang-format for making the following code unreadable.
151 // clang-format off
152 std::stringstream requestFuncStrm;
153 requestFuncStrm << "#pragma clad ON\n"
154 "void " << requestName << "() {\n"
155 " clad::gradient(" << _funcName << ", \"params\");\n"
156 "}\n"
157 "#pragma clad OFF";
158 // clang-format on
159 if (!declareToInterpreter(requestFuncStrm.str())) {
160 std::stringstream errorMsg;
161 errorMsg << "Function " << GetName() << " could not be differentiated. See above for details.";
162 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
163 throw std::runtime_error(errorMsg.str().c_str());
164 }
165
166 _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine((gradName + ";").c_str()));
167 _hasGradient = true;
168}
169
170void RooFuncWrapper::gradient(double *out) const
171{
172 updateGradientVarBuffer();
173 std::fill(out, out + _params.size(), 0.0);
174
175 _grad(_gradientVarBuffer.data(), _observables.data(), _xlArr.data(), out);
176}
177
178void RooFuncWrapper::updateGradientVarBuffer() const
179{
180 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(),
181 [](RooAbsArg *obj) { return static_cast<RooAbsReal *>(obj)->getVal(); });
182}
183
184double RooFuncWrapper::evaluate() const
185{
186 if (_useEvaluator)
187 return _absReal->getVal();
188 updateGradientVarBuffer();
189
190 return _func(_gradientVarBuffer.data(), _observables.data(), _xlArr.data());
191}
192
193void RooFuncWrapper::gradient(const double *x, double *g) const
194{
195 std::fill(g, g + _params.size(), 0.0);
196
197 _grad(const_cast<double *>(x), _observables.data(), _xlArr.data(), g);
198}
199
200std::string RooFuncWrapper::buildCode(RooAbsReal const &head)
201{
202 RooFit::Detail::CodeSquashContext ctx(_nodeOutputSizes, _xlArr, *this);
203
204 // First update the result variable of params in the compute graph to in[<position>].
205 int idx = 0;
206 for (RooAbsArg *param : _params) {
207 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
208 idx++;
209 }
210
211 for (auto const &item : _obsInfos) {
212 const char *name = item.first->GetName();
213 // If the observable is scalar, set name to the start idx. else, store
214 // the start idx and later set the the name to obs[start_idx + curr_idx],
215 // here curr_idx is defined by a loop producing parent node.
216 if (item.second.size == 1) {
217 ctx.addResult(name, "obs[" + std::to_string(item.second.idx) + "]");
218 } else {
219 ctx.addResult(name, "obs");
220 ctx.addVecObs(name, item.second.idx);
221 }
222 }
223
224 return ctx.assembleCode(ctx.getResult(head));
225}
226
227/// @brief Declare code to the interpreter and keep track of all declared code in this RooFuncWrapper.
228bool RooFuncWrapper::declareToInterpreter(std::string const &code)
229{
230 _allCode << code << std::endl;
231 return gInterpreter->Declare(code.c_str());
232}
233
234/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
235void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
236{
237 std::ofstream outFile;
238 outFile.open(filename + ".C");
239 outFile << "#include <RooFit/Detail/MathFuncs.h>" << std::endl;
240 outFile << std::endl;
241 outFile << _allCode.str();
242 outFile << std::endl;
243
244 updateGradientVarBuffer();
245
246 auto writeVector = [&](std::string const &name, std::span<const double> vec) {
247 outFile << "std::vector<double> " << name << " = {";
248 for (std::size_t i = 0; i < vec.size(); ++i) {
249 if (i % 10 == 0)
250 outFile << "\n ";
251 outFile << vec[i];
252 if (i < vec.size() - 1)
253 outFile << ", ";
254 }
255 outFile << "\n};\n";
256 };
257
258 outFile << "// clang-format off\n" << std::endl;
259 writeVector("parametersVec", _gradientVarBuffer);
260 outFile << std::endl;
261 writeVector("observablesVec", _observables);
262 outFile << std::endl;
263 writeVector("auxConstantsVec", _xlArr);
264 outFile << std::endl;
265 outFile << "// clang-format on\n" << std::endl;
266
267 outFile << R"(
268// To run as a ROOT macro
269void )" << filename
270 << R"(()
271{
272 std::vector<double> gradientVec(parametersVec.size());
273
274 )" << _funcName
275 << R"((parametersVec.data(), observablesVec.data(), auxConstantsVec.data());
276 )" << _funcName
277 << R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(), gradientVec.data());
278}
279)";
280}
281
282} // namespace Experimental
283
284} // namespace RooFit
#define g(i)
Definition RSha256.hxx:105
RooAbsReal * _func
Pointer to original input function.
#define oocoutE(o, a)
#define coutE(a)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:77
RooFit::OwningPtr< RooArgSet > getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:59
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
A class to maintain the context for squashing of RooFit models into code.
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
A wrapper class to store a C++ function of type 'double (*)(double*, double const*, double const*)'.
double(*)(double *, double const *, double const *) Func
std::unique_ptr< RooAbsReal > _absReal
std::string buildCode(RooAbsReal const &head)
void loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
std::map< RooFit::Detail::DataKey, ObsInfo > _obsInfos
void(*)(double *, double const *, double const *, double *) Grad
std::string declareFunction(std::string const &funcBody)
RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data=nullptr, RooSimultaneous const *simPdf=nullptr, bool useEvaluator=false)
bool declareToInterpreter(std::string const &code)
Declare code to the interpreter and keep track of all declared code in this RooFuncWrapper.
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition JSONIO.h:26
@ InputArguments