Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RooEvaluatorWrapper.cxx
Go to the documentation of this file.
/// \cond ROOFIT_INTERNAL

/*
 * Project: RooFit
 * Authors:
 *   Jonas Rembser, CERN 2023
 *
 * Copyright (c) 2023, CERN
 *
 * Redistribution and use in source and binary forms,
 * with or without modification, are permitted according to the terms
 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
 */

/**
\internal
\file RooEvaluatorWrapper.cxx
\class RooEvaluatorWrapper
\ingroup Roofitcore

Wraps a RooFit::Evaluator that evaluates a RooAbsReal back into a RooAbsReal.
**/
23
24#include "RooEvaluatorWrapper.h"
25
26#include <RooAbsData.h>
27#include <RooAbsPdf.h>
28#include <RooMsgService.h>
29#include <RooRealVar.h>
30#include <RooSimultaneous.h>
31
33
34#include <TInterpreter.h>
35
36#include <fstream>
37
38namespace RooFit::Experimental {
39
/// Builds the wrapper around a newly created RooFit::Evaluator for `topNode`.
///
/// If `data` is non-null it is attached via setData(). Afterwards the
/// evaluator's parameters are copied into `_paramSet`, and every entry whose
/// name matches a key of `_dataSpans` is removed from it again.
///
/// NOTE(review): original line 42 is absent from this listing, so the tail of
/// the parameter list is not visible here. The initializer of
/// `_takeGlobalObservablesFromData` suggests a trailing parameter named
/// `takeGlobalObservablesFromData` - confirm against the header.
/// NOTE(review): the result of `_paramSet.find(...)` is dereferenced without a
/// null check; presumably every data-span name is also an evaluator parameter
/// - verify.
40RooEvaluatorWrapper::RooEvaluatorWrapper(RooAbsReal &topNode, RooAbsData *data, bool useGPU,
41 std::string const &rangeName, RooAbsPdf const *pdf,
43 : RooAbsReal{"RooEvaluatorWrapper", "RooEvaluatorWrapper"},
44 _evaluator{std::make_unique<RooFit::Evaluator>(topNode, useGPU)},
45 _topNode("topNode", "top node", this, topNode, false, false),
46 _data{data},
47 _paramSet("paramSet", "Set of parameters", this),
48 _rangeName{rangeName},
49 _pdf{pdf},
50 _takeGlobalObservablesFromData{takeGlobalObservablesFromData}
51{
52 if (data) {
 // setData() runs while _paramSet is still empty, which it treats as the
 // initializing call.
53 setData(*data, false);
54 }
55 _paramSet.add(_evaluator->getParameters());
 // Variables that are provided by the dataset are not parameters.
56 for (auto const &item : _dataSpans) {
57 _paramSet.remove(*_paramSet.find(item.first->GetName()));
58 }
59}
60
/// Copy constructor taking the object to copy and a `name`.
///
/// Copies `_evaluator`, `_data`, `_rangeName`, `_pdf` and
/// `_takeGlobalObservablesFromData` from `other`, builds the `_topNode` proxy
/// from `other._topNode`, and fills a new `_paramSet` with the contents of
/// `other._paramSet`.
///
/// NOTE(review): original lines 62 and 70 are absent from this listing, so the
/// first and the last entry of the initializer list are not visible here.
61RooEvaluatorWrapper::RooEvaluatorWrapper(const RooEvaluatorWrapper &other, const char *name)
63 _evaluator{other._evaluator},
64 _topNode("topNode", this, other._topNode),
65 _data{other._data},
66 _paramSet("paramSet", "Set of parameters", this),
67 _rangeName{other._rangeName},
68 _pdf{other._pdf},
69 _takeGlobalObservablesFromData{other._takeGlobalObservablesFromData},
71{
72 _paramSet.add(other._paramSet);
73}
74
75RooEvaluatorWrapper::~RooEvaluatorWrapper() = default;
76
77bool RooEvaluatorWrapper::getParameters(const RooArgSet *observables, RooArgSet &outputSet,
78 bool stripDisconnected) const
79{
80 outputSet.add(_evaluator->getParameters());
81 if (observables) {
82 outputSet.remove(*observables, /*silent*/ false, /*matchByNameOnly*/ true);
83 }
84 // Exclude the data variables from the parameters which are not global observables
85 for (auto const &item : _dataSpans) {
86 if (_data->getGlobalObservables() && _data->getGlobalObservables()->find(item.first->GetName())) {
87 continue;
88 }
89 RooAbsArg *found = outputSet.find(item.first->GetName());
90 if (found) {
91 outputSet.remove(*found);
92 }
93 }
94 // If we take the global observables as data, we have to return these as
95 // parameters instead of the parameters in the model. Otherwise, the
96 // constant parameters in the fit result that are global observables will
97 // not have the right values.
98 if (_takeGlobalObservablesFromData && _data->getGlobalObservables()) {
99 outputSet.replace(*_data->getGlobalObservables());
100 }
101
102 // The disconnected parameters are stripped away in
103 // RooAbsArg::getParametersHook(), that is only called in the original
104 // RooAbsArg::getParameters() implementation. So he have to call it to
105 // identify disconnected parameters to remove.
106 if (stripDisconnected) {
108 _topNode->getParameters(observables, paramsStripped, true);
110 for (RooAbsArg *param : outputSet) {
111 if (!paramsStripped.find(param->GetName())) {
112 toRemove.add(*param);
113 }
114 }
115 outputSet.remove(toRemove, /*silent*/ false, /*matchByNameOnly*/ true);
116 }
117
118 return false;
119}
120
121bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
122{
123 // To make things easiear for RooFit, we only support resetting with
124 // datasets that have the same structure, e.g. the same columns and global
125 // observables. This is anyway the usecase: resetting same-structured data
126 // when iterating over toys.
127 constexpr auto errMsg = "Error in RooAbsReal::setData(): only resetting with same-structured data is supported.";
128
129 _data = &data;
130 bool isInitializing = _paramSet.empty();
131 const std::size_t oldSize = _dataSpans.size();
132
133 std::stack<std::vector<double>>{}.swap(_vectorBuffers);
134 bool skipZeroWeights = !_pdf || !_pdf->getAttribute("BinnedLikelihoodActive");
135 _dataSpans =
136 RooFit::BatchModeDataHelpers::getDataSpans(*_data, _rangeName, dynamic_cast<RooSimultaneous const *>(_pdf),
137 skipZeroWeights, _takeGlobalObservablesFromData, _vectorBuffers);
138 if (!isInitializing && _dataSpans.size() != oldSize) {
139 coutE(DataHandling) << errMsg << std::endl;
140 throw std::runtime_error(errMsg);
141 }
142 for (auto const &item : _dataSpans) {
143 const char *name = item.first->GetName();
144 _evaluator->setInput(name, item.second, false);
145 if (_paramSet.find(name)) {
146 coutE(DataHandling) << errMsg << std::endl;
147 throw std::runtime_error(errMsg);
148 }
149 }
150 return true;
151}
152
153/// @brief A wrapper class to store a C++ function of type 'double (*)(double*, double*)'.
154/// The parameters can be accessed as params[<relative position of param in paramSet>] in the function body.
155/// The observables can be accessed as obs[i + j], where i represents the observable position and j
156/// represents the data entry.
///
/// Note that the `Func` alias declared below takes three pointer arguments
/// (it is called with the parameter buffer, the observables and `_xlArr`),
/// not the two mentioned in the brief above.
///
/// NOTE(review): original lines 159, 164, 178, 186 and 188 are absent from
/// this listing, so some declarations and statements of this class are not
/// visible here (the constructor and `loadParamsAndData` are defined further
/// down in this file and are presumably declared on those lines).
157class RooFuncWrapper {
158public:
160
 /// Whether a gradient function is available (set by createGradient()).
161 bool hasGradient() const { return _hasGradient; }
 /// Zeroes `out[0 .. _params.size())` and then calls the gradient function
 /// with the parameter buffer, the observables, `_xlArr` and `out`.
 /// `_grad` is called without a null check.
162 void gradient(double *out) const
163 {
165 std::fill(out, out + _params.size(), 0.0);
166
167 _grad(_gradientVarBuffer.data(), _observables.data(), _xlArr.data(), out);
168 }
169
170 void createGradient();
171
172 void writeDebugMacro(std::string const &) const;
173
174 std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }
175
 /// Calls the generated function with the parameter buffer, the observables
 /// and `_xlArr`, and returns its result.
176 double evaluate() const
177 {
179 return _func(_gradientVarBuffer.data(), _observables.data(), _xlArr.data());
180 }
181
182private:
 /// Refreshes `_gradientVarBuffer` from the current values of `_params`.
183 void updateGradientVarBuffer() const;
184
185 std::map<RooFit::Detail::DataKey, std::span<const double>>
187
189
 // Signatures of the generated function and of its gradient.
190 using Func = double (*)(double *, double const *, double const *);
191 using Grad = void (*)(double *, double const *, double const *, double *);
192
 // Start index and number of entries of one observable inside `_observables`.
193 struct ObsInfo {
194 ObsInfo(std::size_t i, std::size_t n) : idx{i}, size{n} {}
195 std::size_t idx = 0;
196 std::size_t size = 0;
197 };
198
199 RooArgList _params;
200 std::string _funcName;
201 Func _func;
202 Grad _grad;
203 bool _hasGradient = false;
 // Mutable so that the const evaluation paths can refresh it.
204 mutable std::vector<double> _gradientVarBuffer;
205 std::vector<double> _observables;
206 std::map<RooFit::Detail::DataKey, ObsInfo> _obsInfos;
207 std::vector<double> _xlArr;
208 std::vector<std::string> _collectedFunctions;
209};
210
namespace {

/// Replaces every occurrence of `from` in `str` by `to`, in place.
///
/// The search resumes right after the inserted text, so occurrences of `from`
/// inside `to` are not replaced again (e.g. replacing 'x' with 'yx'
/// terminates). An empty `from` leaves `str` unchanged.
///
/// \param[in,out] str String to modify.
/// \param[in] from Substring to search for.
/// \param[in] to Replacement text.
void replaceAll(std::string &str, const std::string &from, const std::string &to)
{
   if (from.empty())
      return;

   for (std::size_t pos = str.find(from); pos != std::string::npos; pos = str.find(from, pos + to.length())) {
      str.replace(pos, from.length(), to);
   }
}

} // namespace
225
/// Generates and JIT-compiles the function code for `obj`.
///
/// Steps visible in this listing: the per-node output sizes are determined
/// from the data spans, each parameter is registered in the code generation
/// context as `params[<index>]`, each observable as `obs[<start index>]`
/// (scalar) or as the vector observable `obs` with its start index, the
/// function is built and its address is obtained from the interpreter, and
/// finally `_xlArr` and `_collectedFunctions` are copied from the context.
///
/// NOTE(review): original lines 230 and 239 are absent from this listing;
/// they presumably define `spans` (see loadParamsAndData()) and the code
/// generation context `ctx` that the code below uses - confirm against the
/// full source.
226RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf,
227 RooArgSet const &paramSet)
228{
229 // Load the parameters and observables.
231
232 // Set up the code generation context
233 std::map<RooFit::Detail::DataKey, std::size_t> nodeOutputSizes =
234 RooFit::BatchModeDataHelpers::determineOutputSizes(obj, [&spans](RooFit::Detail::DataKey key) -> int {
 // Nodes without a data span report -1.
235 auto found = spans.find(key);
236 return found != spans.end() ? found->second.size() : -1;
237 });
238
240
241 // First update the result variable of params in the compute graph to in[<position>].
 // (The generated expression is spelled "params[<position>]".)
242 int idx = 0;
243 for (RooAbsArg *param : _params) {
244 ctx.addResult(param, "params[" + std::to_string(idx) + "]");
245 idx++;
246 }
247
248 for (auto const &item : _obsInfos) {
249 const char *obsName = item.first->GetName();
250 // If the observable is scalar, set name to the start idx. else, store
251 // the start idx and later set the the name to obs[start_idx + curr_idx],
252 // here curr_idx is defined by a loop producing parent node.
253 if (item.second.size == 1) {
254 ctx.addResult(obsName, "obs[" + std::to_string(item.second.idx) + "]");
255 } else {
256 ctx.addResult(obsName, "obs");
257 ctx.addVecObs(obsName, item.second.idx);
258 }
259 }
260
261 gInterpreter->Declare("#pragma cling optimize(2)");
262
263 // Declare the function and create its derivative.
264 auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
265 ROOT::Math::Util::TimingScope timingScope(print, "Function JIT time:");
266 _funcName = ctx.buildFunction(obj, nodeOutputSizes);
 // Evaluating the function name in the interpreter yields its address.
 // NOTE(review): the result is not checked for null.
267 _func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));
268
269 _xlArr = ctx.xlArr();
270 _collectedFunctions = ctx.collectedFunctions();
271}
272
273std::map<RooFit::Detail::DataKey, std::span<const double>>
274RooFuncWrapper::loadParamsAndData(RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
275{
276 // Extract observables
277 std::stack<std::vector<double>> vectorBuffers; // for data loading
278 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
279
280 if (data) {
281 spans = RooFit::BatchModeDataHelpers::getDataSpans(*data, "", simPdf, true, false, vectorBuffers);
282 }
283
284 std::size_t idx = 0;
285 for (auto const &item : spans) {
286 std::size_t n = item.second.size();
287 _obsInfos.emplace(item.first, ObsInfo{idx, n});
288 _observables.reserve(_observables.size() + n);
289 for (std::size_t i = 0; i < n; ++i) {
290 _observables.push_back(item.second[i]);
291 }
292 idx += n;
293 }
294
295 for (auto *param : paramSet) {
296 if (spans.find(param) == spans.end()) {
297 _params.add(*param);
298 }
299 }
300 _gradientVarBuffer.resize(_params.size());
301
302 return spans;
303}
304
305void RooFuncWrapper::createGradient()
306{
307#ifdef ROOFIT_CLAD
308 std::string gradName = _funcName + "_grad_0";
309 std::string requestName = _funcName + "_req";
310
311 // Calculate gradient
312 gInterpreter->Declare("#include <Math/CladDerivator.h>\n");
313 // disable clang-format for making the following code unreadable.
314 // clang-format off
315 std::stringstream requestFuncStrm;
316 requestFuncStrm << "#pragma clad ON\n"
317 "void " << requestName << "() {\n"
318 " clad::gradient(" << _funcName << ", \"params\");\n"
319 "}\n"
320 "#pragma clad OFF";
321 // clang-format on
322 auto print = [](std::string const &msg) { oocoutI(nullptr, Fitting) << msg << std::endl; };
323
324 bool cladSuccess = false;
325 {
326 ROOT::Math::Util::TimingScope timingScope(print, "Gradient generation time:");
327 cladSuccess = !gInterpreter->Declare(requestFuncStrm.str().c_str());
328 }
329 if (cladSuccess) {
330 std::stringstream errorMsg;
331 errorMsg << "Function could not be differentiated. See above for details.";
332 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
333 throw std::runtime_error(errorMsg.str().c_str());
334 }
335
336 // Clad provides different overloads for the gradient, and we need to
337 // resolve to the one that we want. Without the static_cast, getting the
338 // function pointer would be ambiguous.
339 std::stringstream ss;
340 ROOT::Math::Util::TimingScope timingScope(print, "Gradient IR to machine code time:");
341 ss << "static_cast<void (*)(double *, double const *, double const *, double *)>(" << gradName << ");";
342 _grad = reinterpret_cast<Grad>(gInterpreter->ProcessLine(ss.str().c_str()));
343 _hasGradient = true;
344#else
345 _hasGradient = false;
346 std::stringstream errorMsg;
347 errorMsg << "Function could not be differentiated since ROOT was built without Clad support.";
348 oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
349 throw std::runtime_error(errorMsg.str().c_str());
350#endif
351}
352
353void RooFuncWrapper::updateGradientVarBuffer() const
354{
355 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(), [](RooAbsArg *obj) {
356 return obj->isCategory() ? static_cast<RooAbsCategory *>(obj)->getCurrentIndex()
357 : static_cast<RooAbsReal *>(obj)->getVal();
358 });
359}
360
361/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
///
/// The macro contains the source of every collected function (each one only
/// once), a Clad gradient request for the generated function, the current
/// parameter buffer, the observables and `_xlArr` as `std::vector<double>`
/// definitions, and an entry function that prints a numeric and the Clad
/// gradient for each parameter.
///
/// `filename` is used both for the output file name (with ".C" appended) and
/// as the name of the generated entry function, so it has to be usable as a
/// C++ identifier.
///
/// NOTE(review): in the generated `grad` lambda the gradient function is
/// called with `parametersVec.data()` rather than with the lambda's own
/// `params` argument - confirm that this is intended.
/// NOTE(review): the output stream is not checked after open().
362void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
363{
364 std::stringstream allCode;
365 std::set<std::string> seenFunctions;
366
367 // Remove duplicated declared functions
368 for (std::string const &name : _collectedFunctions) {
369 if (seenFunctions.count(name) > 0) {
370 continue;
371 }
372 seenFunctions.insert(name);
373 std::unique_ptr<TInterpreterValue> v = gInterpreter->MakeInterpreterValue();
374 gInterpreter->Evaluate(name.c_str(), *v);
375 std::string s = v->ToString();
 // Drop the first two lines of the interpreter's string representation.
376 for (int i = 0; i < 2; ++i) {
377 s = s.erase(0, s.find("\n") + 1);
378 }
379 allCode << s << std::endl;
380 }
381
382 std::ofstream outFile;
383 outFile.open(filename + ".C");
384 outFile << R"(//auto-generated test macro
385#include <RooFit/Detail/MathFuncs.h>
386#include <Math/CladDerivator.h>
387
388#pragma cling optimize(2)
389)" << allCode.str()
390 << R"(
391#pragma clad ON
392void gradient_request() {
393 clad::gradient()"
394 << _funcName << R"(, "params");
395}
396#pragma clad OFF
397)";
398
 // NOTE(review): original line 399 is absent from this listing.
400
 // Writes `vec` as the definition of a std::vector<double> called `name`,
 // ten values per line.
401 auto writeVector = [&](std::string const &name, std::span<const double> vec) {
402 std::stringstream decl;
403 decl << "std::vector<double> " << name << " = {";
404 for (std::size_t i = 0; i < vec.size(); ++i) {
405 if (i % 10 == 0)
406 decl << "\n ";
407 decl << vec[i];
408 if (i < vec.size() - 1)
409 decl << ", ";
410 }
411 decl << "\n};\n";
412
413 std::string declStr = decl.str();
414
 // Presumably non-finite values are streamed as "inf"/"nan", which are
 // not valid C++ expressions, hence the textual replacement.
415 replaceAll(declStr, "inf", "std::numeric_limits<double>::infinity()");
416 replaceAll(declStr, "nan", "NAN");
417
418 outFile << declStr;
419 };
420
421 outFile << "// clang-format off\n" << std::endl;
422 writeVector("parametersVec", _gradientVarBuffer);
423 outFile << std::endl;
424 writeVector("observablesVec", _observables);
425 outFile << std::endl;
426 writeVector("auxConstantsVec", _xlArr);
427 outFile << std::endl;
428 outFile << "// clang-format on\n" << std::endl;
429
430 outFile << R"(
431// To run as a ROOT macro
432void )" << filename
433 << R"(()
434{
435 std::vector<double> gradientVec(parametersVec.size());
436
437 auto func = [&](std::span<double> params) {
438 return )"
439 << _funcName << R"((params.data(), observablesVec.data(), auxConstantsVec.data());
440 };
441 auto grad = [&](std::span<double> params, std::span<double> out) {
442 return )"
443 << _funcName << R"(_grad_0(parametersVec.data(), observablesVec.data(), auxConstantsVec.data(),
444 out.data());
445 };
446
447 grad(parametersVec, gradientVec);
448
449 auto numDiff = [&](int i) {
450 const double eps = 1e-6;
451 std::vector<double> p{parametersVec};
452 p[i] = parametersVec[i] - eps;
453 double funcValDown = func(p);
454 p[i] = parametersVec[i] + eps;
455 double funcValUp = func(p);
456 return (funcValUp - funcValDown) / (2 * eps);
457 };
458
459 for (std::size_t i = 0; i < parametersVec.size(); ++i) {
460 std::cout << i << ":" << std::endl;
461 std::cout << " numr : " << numDiff(i) << std::endl;
462 std::cout << " clad : " << gradientVec[i] << std::endl;
463 }
464}
465)";
466}
467
/// Evaluates the wrapped function and returns its value.
///
/// Two paths are visible: the generated function code (`_funcWrapper`) and
/// the RooFit::Evaluator. On the evaluator path the offset mode is chosen
/// from hideOffset() and the first element of the evaluator output is
/// returned; without an evaluator the result is 0.0.
///
/// NOTE(review): original line 470 is absent from this listing. As shown, the
/// first `return` is unconditional, which would make the evaluator path
/// unreachable; presumably the missing line is a condition that guards this
/// return - confirm against the full source.
468double RooEvaluatorWrapper::evaluate() const
469{
471 return _funcWrapper->evaluate();
472
473 if (!_evaluator)
474 return 0.0;
475
476 _evaluator->setOffsetMode(hideOffset() ? RooFit::EvalContext::OffsetMode::WithoutOffset
477 : RooFit::EvalContext::OffsetMode::WithOffset);
478
479 return _evaluator->run()[0];
480}
481
482void RooEvaluatorWrapper::createFuncWrapper()
483{
484 // Get the parameters.
486 this->getParameters(_data ? _data->get() : nullptr, paramSet, /*sripDisconnectedParams=*/false);
487
489 std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast<RooSimultaneous const *>(_pdf), paramSet);
490}
491
492void RooEvaluatorWrapper::generateGradient()
493{
494 if (!_funcWrapper)
496 _funcWrapper->createGradient();
497}
498
/// Setter taking a boolean `flag`.
///
/// NOTE(review): original lines 501-503 (the statements of the body) are
/// absent from this listing, so what the function does with `flag` cannot be
/// documented from here - consult the full source.
499void RooEvaluatorWrapper::setUseGeneratedFunctionCode(bool flag)
500{
504}
505
506void RooEvaluatorWrapper::gradient(double *out) const
507{
508 _funcWrapper->gradient(out);
509}
510
511bool RooEvaluatorWrapper::hasGradient() const
512{
513 if (!_funcWrapper)
514 return false;
515 return _funcWrapper->hasGradient();
516}
517
518void RooEvaluatorWrapper::writeDebugMacro(std::string const &filename) const
519{
520 if (_funcWrapper)
521 return _funcWrapper->writeDebugMacro(filename);
522}
523
524} // namespace RooFit::Experimental
525
526/// \endcond
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
#define oocoutE(o, a)
#define oocoutI(o, a)
#define coutE(a)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
char name[80]
Definition TGX11.cxx:110
#define gInterpreter
const_iterator begin() const
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
Definition RooAbsArg.h:76
Abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:57
Abstract interface for all probability density functions.
Definition RooAbsPdf.h:32
Abstract base class for objects that represent a real value and implements functionality common to al...
Definition RooAbsReal.h:63
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:24
A class to maintain the context for squashing of RooFit models into code.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
const Int_t n
Definition legend1.C:16
void replaceAll(std::string &inOut, std::string_view what, std::string_view with)
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...
Definition CodegenImpl.h:67
void evaluate(typename Architecture_t::Tensor_t &A, EActivationFunction f)
Apply the given activation function to each value in the given tensor A.
Definition Functions.h:98