Logo ROOT  
Reference Guide
BatchModeHelpers.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2021
5 *
6 * Copyright (c) 2021, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14
15#include <RooAbsData.h>
16#include <RooAbsPdf.h>
17#include <RooAddition.h>
18#include <RooBatchCompute.h>
19#include <RooBinSamplingPdf.h>
20#include <RooConstraintSum.h>
21#include <RooDataSet.h>
22#include <RooFitDriver.h>
23#include <RooNLLVarNew.h>
24#include <RooRealVar.h>
25#include <RooSimultaneous.h>
26#include <RooFitDriver.h>
27
28#include <ROOT/StringUtils.hxx>
29
30#include <string>
31
34
35namespace {
36
37std::unique_ptr<RooAbsArg> prepareSimultaneousModelForBatchMode(RooSimultaneous &simPdf, RooArgSet &observables,
38 bool isExtended, std::string const &rangeName)
39{
40 // Prepare the NLLTerms for each component
41 RooArgList nllTerms;
42 RooArgSet newObservables;
43 for (auto const &catItem : simPdf.indexCat()) {
44 auto const &catName = catItem.first;
45 if (RooAbsPdf *pdf = simPdf.getPdf(catName.c_str())) {
46 auto nllName = std::string("nll_") + pdf->GetName();
47 auto *nll = new RooNLLVarNew(nllName.c_str(), nllName.c_str(), *pdf, observables, isExtended, rangeName);
48 // Rename the observables and weights
49 newObservables.add(nll->prefixObservableAndWeightNames(std::string("_") + catName + "_"));
50 nllTerms.add(*nll);
51 }
52 }
53
54 observables.clear();
55 observables.add(newObservables);
56
57 // Time to sum the NLLs
58 return std::make_unique<RooAddition>("mynll", "mynll", nllTerms, true);
59}
60
61} // namespace
62
63std::unique_ptr<RooAbsReal>
64RooFit::BatchModeHelpers::createNLL(RooAbsPdf &pdf, RooAbsData &data, std::unique_ptr<RooAbsReal> &&constraints,
65 std::string const &rangeName, std::string const &addCoefRangeName,
66 RooArgSet const &projDeps, bool isExtended, double integrateOverBinsPrecision,
68{
69 std::unique_ptr<RooFitDriver> driver;
70
71 RooArgSet observables;
72 pdf.getObservables(data.get(), observables);
73 observables.remove(projDeps, true, true);
74
75 oocxcoutI(&pdf, Fitting) << "RooAbsPdf::fitTo(" << pdf.GetName()
76 << ") fixing normalization set for coefficient determination to observables in data"
77 << "\n";
78 pdf.fixAddCoefNormalization(observables, false);
79 if (!addCoefRangeName.empty()) {
80 oocxcoutI(&pdf, Fitting) << "RooAbsPdf::fitTo(" << pdf.GetName()
81 << ") fixing interpretation of coefficients of any component to range "
82 << addCoefRangeName << "\n";
83 pdf.fixAddCoefRange(addCoefRangeName.c_str(), false);
84 }
85
86 // Deal with the IntegrateBins argument
87 RooArgList binSamplingPdfs;
88 std::unique_ptr<RooAbsPdf> wrappedPdf;
89 wrappedPdf = RooBinSamplingPdf::create(pdf, data, integrateOverBinsPrecision);
90 RooAbsPdf &finalPdf = wrappedPdf ? *wrappedPdf : pdf;
91 if (wrappedPdf) {
92 binSamplingPdfs.addOwned(std::move(wrappedPdf));
93 }
94 // Done dealing with the IntegrateBins option
95
96 RooArgList nllTerms;
97
98 if (auto simPdf = dynamic_cast<RooSimultaneous *>(&finalPdf)) {
99 auto *simPdfClone = static_cast<RooSimultaneous *>(simPdf->cloneTree());
100 simPdfClone->wrapPdfsInBinSamplingPdfs(data, integrateOverBinsPrecision);
101 // Warning! This mutates "observables"
102 nllTerms.addOwned(prepareSimultaneousModelForBatchMode(*simPdfClone, observables, isExtended, rangeName));
103 } else {
104 nllTerms.addOwned(
105 std::make_unique<RooNLLVarNew>("RooNLLVarNew", "RooNLLVarNew", finalPdf, observables, isExtended, rangeName));
106 }
107 if (constraints) {
108 nllTerms.addOwned(std::move(constraints));
109 }
110
111 std::string nllName = std::string("nll_") + pdf.GetName() + "_" + data.GetName();
112 auto nll = std::make_unique<RooAddition>(nllName.c_str(), nllName.c_str(), nllTerms);
113 nll->addOwnedComponents(std::move(binSamplingPdfs));
114 nll->addOwnedComponents(std::move(nllTerms));
115
116 if (auto simPdf = dynamic_cast<RooSimultaneous *>(&finalPdf)) {
117 RooArgSet parameters;
118 pdf.getParameters(data.get(), parameters);
119 nll->recursiveRedirectServers(parameters);
120 driver = std::make_unique<RooFitDriver>(*nll, observables, batchMode);
121 driver->setData(data, rangeName, &simPdf->indexCat());
122 } else {
123 driver = std::make_unique<RooFitDriver>(*nll, observables, batchMode);
124 driver->setData(data, rangeName);
125 }
126
127 // Set the fitrange attribute so that RooPlot can automatically plot the fitting range by default
128 if (!rangeName.empty()) {
129
130 std::string fitrangeValue;
131 auto subranges = ROOT::Split(rangeName, ",");
132 for (auto const &subrange : subranges) {
133 if (subrange.empty())
134 continue;
135 std::string fitrangeValueSubrange = std::string("fit_") + nll->GetName();
136 if (subranges.size() > 1) {
137 fitrangeValueSubrange += "_" + subrange;
138 }
139 fitrangeValue += fitrangeValueSubrange + ",";
140 for (auto *observable : static_range_cast<RooRealVar *>(observables)) {
141 observable->setRange(fitrangeValueSubrange.c_str(), observable->getMin(subrange.c_str()),
142 observable->getMax(subrange.c_str()));
143 }
144 }
145 fitrangeValue = fitrangeValue.substr(0, fitrangeValue.size() - 1);
146 pdf.setStringAttribute("fitrange", fitrangeValue.c_str());
147 }
148
149 auto driverWrapper = makeDriverAbsRealWrapper(std::move(driver), *data.get());
150 driverWrapper->addOwnedComponents(std::move(nll));
151
152 return driverWrapper;
153}
154
156{
157 // We have to exit early if the message stream is not active. Otherwise it's
158 // possible that this function skips logging because it thinks it has
159 // already logged, but actually it didn't.
160 if (!RooMsgService::instance().isActive(static_cast<RooAbsArg *>(nullptr), RooFit::Fitting, RooFit::INFO))
161 return;
162
163 // Don't repeat logging architecture info if the batchMode option didn't change
164 {
165 // Second element of pair tracks whether this function has already been called
166 static std::pair<RooFit::BatchModeOption, bool> lastBatchMode;
167 if (lastBatchMode.second && lastBatchMode.first == batchMode)
168 return;
169 lastBatchMode = {batchMode, true};
170 }
171
172 auto log = [](std::string_view message) {
173 oocxcoutI(static_cast<RooAbsArg *>(nullptr), Fitting) << message << std::endl;
174 };
175
177 throw std::runtime_error(std::string("In: ") + __func__ + "(), " + __FILE__ + ":" + __LINE__ +
178 ": Cuda implementation of the computing library is not available\n");
179 }
181 log("using generic CPU library compiled with no vectorizations");
182 } else {
183 log(std::string("using CPU computation library compiled with -m") +
184 RooBatchCompute::dispatchCPU->architectureName());
185 }
186 if (batchMode == RooFit::BatchModeOption::Cuda) {
187 log("using CUDA computation library");
188 }
189}
190
191namespace {
192
193class RooAbsRealWrapper final : public RooAbsReal {
194public:
195 RooAbsRealWrapper() {}
196 RooAbsRealWrapper(RooFitDriver &driver, RooArgSet const &observables, bool ownsDriver)
197 : RooAbsReal{"RooFitDriverWrapper", "RooFitDriverWrapper"}, _driver{&driver}, _ownsDriver{ownsDriver}
198 {
199 _driver->topNode().getParameters(&observables, _parameters, true);
200 }
201
202 RooAbsRealWrapper(const RooAbsRealWrapper &other, const char *name = 0)
203 : RooAbsReal{other, name}, _driver{other._driver}, _parameters{other._parameters}
204 {
205 }
206
207 ~RooAbsRealWrapper() override
208 {
209 if (_ownsDriver)
210 delete _driver;
211 }
212
213 TObject *clone(const char *newname) const override { return new RooAbsRealWrapper(*this, newname); }
214
215 double defaultErrorLevel() const override { return _driver->topNode().defaultErrorLevel(); }
216
217 bool getParameters(const RooArgSet * /*observables*/, RooArgSet &outputSet,
218 bool /*stripDisconnected=true*/) const override
219 {
220 outputSet.add(_parameters);
221 return false;
222 }
223
224 double getValV(const RooArgSet *) const override { return evaluate(); }
225
226 void applyWeightSquared(bool flag) override
227 {
228 const_cast<RooAbsReal &>(_driver->topNode()).applyWeightSquared(flag);
229 }
230
231protected:
232 double evaluate() const override { return _driver ? _driver->getVal() : 0.0; }
233
234private:
235 RooFitDriver *_driver = nullptr;
236 RooArgSet _parameters;
237 bool _ownsDriver;
238};
239
240} // namespace
241
242/// Static method to create a RooAbsRealWrapper that owns a given RooFitDriver
243/// passed by smart pointer.
244std::unique_ptr<RooAbsReal>
245RooFit::BatchModeHelpers::makeDriverAbsRealWrapper(std::unique_ptr<ROOT::Experimental::RooFitDriver> driver,
246 RooArgSet const &observables)
247{
248 return std::unique_ptr<RooAbsReal>{new RooAbsRealWrapper{*driver.release(), observables, true}};
249}
#define oocxcoutI(o, a)
Definition: RooMsgService.h:91
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
char name[80]
Definition: TGX11.cxx:110
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition: RooAbsArg.h:77
virtual RooAbsArg * cloneTree(const char *newname=0) const
Clone tree expression of objects.
Definition: RooAbsArg.cxx:2310
void setStringAttribute(const Text_t *key, const Text_t *value)
Associate string 'value' to this object under key 'key'.
Definition: RooAbsArg.cxx:314
RooArgSet * getObservables(const RooArgSet &set, bool valueOnly=true) const
Given a set of possible observables, return the observables that this PDF depends on.
Definition: RooAbsArg.h:317
RooArgSet * getParameters(const RooAbsData *data, bool stripDisconnected=true) const
Create a list of leaf nodes in the arg tree starting with ourself as top node that don't match any of...
Definition: RooAbsArg.cxx:569
virtual bool addOwned(RooAbsArg &var, bool silent=false)
Add an argument and transfer the ownership to the collection.
virtual bool remove(const RooAbsArg &var, bool silent=false, bool matchByNameOnly=false)
Remove the specified argument from our list.
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
RooAbsArg * first() const
void clear()
Clear contents. If the collection is owning, it will also delete the contents.
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition: RooAbsData.h:61
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition: RooAbsReal.h:64
virtual void fixAddCoefNormalization(const RooArgSet &addNormSet=RooArgSet(), bool force=true)
Fix the interpretation of the coefficient of any RooAddPdf component in the expression tree headed by...
virtual void fixAddCoefRange(const char *rangeName=0, bool force=true)
Fix the interpretation of the coefficient of any RooAddPdf component in the expression tree headed by...
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgSet.h:57
static std::unique_ptr< RooAbsPdf > create(RooAbsPdf &pdf, RooAbsData const &data, double precision)
Creates a wrapping RooBinSamplingPdf if appropriate.
static RooMsgService & instance()
Return reference to singleton instance.
RooSimultaneous facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
void wrapPdfsInBinSamplingPdfs(RooAbsData const &data, double precision)
Wraps the components of this RooSimultaneous in RooBinSamplingPdfs.
const RooAbsCategoryLValue & indexCat() const
RooAbsPdf * getPdf(const char *catName) const
Return the p.d.f associated with the given index category name.
const char * GetName() const override
Returns name of object.
Definition: TNamed.h:47
Mother of all ROOT objects.
Definition: TObject.h:37
RVec< PromoteType< T > > log(const RVec< T > &v)
Definition: RVec.hxx:1748
basic_string_view< char > string_view
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.
Definition: StringUtils.cxx:23
R__EXTERN RooBatchComputeInterface * dispatchCUDA
R__EXTERN RooBatchComputeInterface * dispatchCPU
This dispatch pointer points to an implementation of the compute library, provided one has been loade...
std::unique_ptr< RooAbsReal > createNLL(RooAbsPdf &pdf, RooAbsData &data, std::unique_ptr< RooAbsReal > &&constraints, std::string const &rangeName, std::string const &addCoefRangeName, RooArgSet const &projDeps, bool isExtended, double integrateOverBinsPrecision, RooFit::BatchModeOption batchMode)
void logArchitectureInfo(RooFit::BatchModeOption batchMode)
std::unique_ptr< RooAbsReal > makeDriverAbsRealWrapper(std::unique_ptr< ROOT::Experimental::RooFitDriver > driver, RooArgSet const &observables)
Static method to create a RooAbsRealWrapper that owns a given RooFitDriver passed by smart pointer.
BatchModeOption
For setting the batch mode flag with the BatchMode() command argument to RooAbsPdf::fitTo().
Definition: RooGlobalFunc.h:70
void evaluate(typename Architecture_t::Tensor_t &A, EActivationFunction f)
Apply the given activation function to each value in the given tensor A.
Definition: Functions.h:98