31namespace Experimental {
35 :
RooAbsReal{
name, title}, _params{
"!params",
"List of parameters", this}
38 _absReal = std::make_unique<RooEvaluatorWrapper>(obj,
const_cast<RooAbsData *
>(
data),
false,
"", simPdf,
false);
48 if (!param->isConstant()) {
49 floatingParamSet.
add(*param);
65 _params(
"!params", this, other._params),
66 _funcName(other._funcName),
69 _hasGradient(other._hasGradient),
70 _gradientVarBuffer(other._gradientVarBuffer),
71 _observables(other._observables)
79 std::stack<std::vector<double>> vectorBuffers;
80 std::map<RooFit::Detail::DataKey, std::span<const double>> spans;
83 spans = RooFit::Detail::BatchModeDataHelpers::getDataSpans(*
data,
"", simPdf,
true,
false, vectorBuffers);
87 for (
auto const &item : spans) {
88 std::size_t
n = item.second.size();
91 for (std::size_t i = 0; i <
n; ++i) {
98 for (
auto *param : paramSet) {
100 std::stringstream errorMsg;
101 errorMsg <<
"In creation of function " << GetName()
102 <<
" wrapper: input param expected to be of type RooAbsReal.";
104 throw std::runtime_error(errorMsg.str().c_str());
106 if (spans.find(param) == spans.end()) {
110 _gradientVarBuffer.resize(_params.size());
113 _nodeOutputSizes = RooFit::Detail::BatchModeDataHelpers::determineOutputSizes(
115 auto found = spans.find(key);
116 return found != spans.end() ? found->second.size() : -1;
121std::string RooFuncWrapper::declareFunction(std::string
const &funcBody)
123 static int iFuncWrapper = 0;
124 auto funcName =
"roo_func_wrapper_" + std::to_string(iFuncWrapper++);
129 std::stringstream bodyWithSigStrm;
130 bodyWithSigStrm <<
"double " << funcName <<
"(double* params, double const* obs, double const* xlArr) {\n"
131 << funcBody <<
"\n}";
132 bool comp =
gInterpreter->Declare(bodyWithSigStrm.str().c_str());
134 std::stringstream errorMsg;
135 errorMsg <<
"Function " << funcName <<
" could not be compiled. See above for details.";
137 throw std::runtime_error(errorMsg.str().c_str());
142void RooFuncWrapper::createGradient()
144 std::string gradName = _funcName +
"_grad_0";
145 std::string requestName = _funcName +
"_req";
146 std::string wrapperName = _funcName +
"_derivativeWrapper";
149 gInterpreter->ProcessLine(
"#include <Math/CladDerivator.h>");
152 std::stringstream requestFuncStrm;
153 requestFuncStrm <<
"#pragma clad ON\n"
154 "void " << requestName <<
"() {\n"
155 " clad::gradient(" << _funcName <<
", \"params\");\n"
159 auto comp =
gInterpreter->Declare(requestFuncStrm.str().c_str());
161 std::stringstream errorMsg;
162 errorMsg <<
"Function " << GetName() <<
" could not be differentiated. See above for details.";
164 throw std::runtime_error(errorMsg.str().c_str());
170 std::stringstream dWrapperStrm;
171 dWrapperStrm <<
"void " << wrapperName <<
"(double* params, double const* obs, double const* xlArr, double* out) {\n"
172 " clad::array_ref<double> cladOut(out, " << _params.size() <<
");\n"
173 " " << gradName <<
"(params, obs, xlArr, cladOut);\n"
177 _grad =
reinterpret_cast<Grad>(
gInterpreter->ProcessLine((wrapperName +
";").c_str()));
181void RooFuncWrapper::gradient(
double *out)
const
183 updateGradientVarBuffer();
184 std::fill(out, out + _params.size(), 0.0);
186 _grad(_gradientVarBuffer.data(), _observables.data(), _xlArr.data(), out);
189void RooFuncWrapper::updateGradientVarBuffer()
const
191 std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(),
192 [](
RooAbsArg *obj) { return static_cast<RooAbsReal *>(obj)->getVal(); });
195double RooFuncWrapper::evaluate()
const
198 return _absReal->getVal();
199 updateGradientVarBuffer();
201 return _func(_gradientVarBuffer.data(), _observables.data(), _xlArr.data());
204void RooFuncWrapper::gradient(
const double *
x,
double *
g)
const
206 std::fill(
g,
g + _params.size(), 0.0);
208 _grad(
const_cast<double *
>(
x), _observables.data(), _xlArr.data(),
g);
218 ctx.
addResult(param,
"params[" + std::to_string(idx) +
"]");
222 for (
auto const &item : _obsInfos) {
223 const char *
name = item.first->GetName();
227 if (item.second.size == 1) {
228 ctx.
addResult(
name,
"obs[" + std::to_string(item.second.idx) +
"]");
239void RooFuncWrapper::dumpCode()
245void RooFuncWrapper::dumpGradient()
247 gInterpreter->ProcessLine((_funcName +
"_grad_0").c_str());
RooAbsReal * _func
Pointer to original input function.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Common abstract base class for objects that represent a value and a "shape" in RooFit.
RooFit::OwningPtr< RooArgSet > getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
Abstract base class for binned and unbinned datasets.
Abstract base class for objects that represent a real value and implements functionality common to al...
RooArgSet is a container object that can hold multiple RooAbsArg objects.
A class to maintain the context for squashing of RooFit models into code.
std::string assembleCode(std::string const &returnExpr)
Assemble and return the final code with the return expression and global statements.
void addResult(RooAbsArg const *key, std::string const &value)
A function to save an expression that includes/depends on the result of the input node.
void addVecObs(const char *key, int idx)
Since the squashed code represents all observables as a single flattened array, it is important to ke...
std::string const & getResult(RooAbsArg const &arg)
Gets the result for the given node using the node name.
A wrapper class to store a C++ function of type 'double (*)(double*, double const*, double const*)'.
double(*)(double *, double const *, double const *) Func
std::unique_ptr< RooAbsReal > _absReal
std::string buildCode(RooAbsReal const &head)
std::vector< double > _observables
void loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
std::map< RooFit::Detail::DataKey, ObsInfo > _obsInfos
void(*)(double *, double const *, double const *, double *) Grad
static std::string declareFunction(std::string const &funcBody)
RooFuncWrapper(const char *name, const char *title, RooAbsReal &obj, const RooAbsData *data=nullptr, RooSimultaneous const *simPdf=nullptr, bool useEvaluator=false)
Facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or othe...