Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BatchModeHelpers.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2021
5 *
6 * Copyright (c) 2021, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14#include <RooAbsData.h>
15#include <RooAbsPdf.h>
16#include <RooAddition.h>
17#include <RooBatchCompute.h>
18#include <RooBinSamplingPdf.h>
19#include <RooCategory.h>
20#include <RooConstraintSum.h>
21#include <RooDataSet.h>
22#include "RooFitDriver.h"
23#include "RooNLLVarNew.h"
24#include <RooRealVar.h>
25#include <RooSimultaneous.h>
26
27#include <string>
28
31
32namespace {
33
34std::unique_ptr<RooAbsArg> createSimultaneousNLL(RooSimultaneous const &simPdf, bool isExtended,
35 std::string const &rangeName, RooFit::OffsetMode offset)
36{
37 RooAbsCategoryLValue const &simCat = simPdf.indexCat();
38
39 // Prepare the NLL terms for each component
40 RooArgList nllTerms;
41 for (auto const &catState : simCat) {
42 std::string const &catName = catState.first;
43 RooAbsCategory::value_type catIndex = catState.second;
44
45 // If the channel is not in the selected range of the category variable, we
46 // won't create an NLL this channel.
47 if (!rangeName.empty()) {
48 // Only the RooCategory supports ranges, not the other
49 // RooAbsCategoryLValue-derived classes.
50 auto simCatAsRooCategory = dynamic_cast<RooCategory const *>(&simCat);
51 if (simCatAsRooCategory && !simCatAsRooCategory->isStateInRange(rangeName.c_str(), catIndex)) {
52 continue;
53 }
54 }
55
56 if (RooAbsPdf *pdf = simPdf.getPdf(catName.c_str())) {
57 auto name = std::string("nll_") + pdf->GetName();
58 std::unique_ptr<RooArgSet> observables(
59 static_cast<RooArgSet *>(std::unique_ptr<RooArgSet>(pdf->getVariables())->selectByAttrib("__obs__", true)));
60 auto nll = std::make_unique<RooNLLVarNew>(name.c_str(), name.c_str(), *pdf, *observables, isExtended, offset);
61 // Rename the special variables
62 nll->setPrefix(std::string("_") + catName + "_");
63 nllTerms.addOwned(std::move(nll));
64 }
65 }
66
67 for (auto *nll : static_range_cast<RooNLLVarNew *>(nllTerms)) {
68 nll->setSimCount(nllTerms.size());
69 }
70
71 // Time to sum the NLLs
72 auto nll = std::make_unique<RooAddition>("mynll", "mynll", nllTerms);
73 nll->addOwnedComponents(std::move(nllTerms));
74 return nll;
75}
76
77class RooAbsRealWrapper final : public RooAbsReal {
78public:
79 RooAbsRealWrapper(std::unique_ptr<RooFitDriver> driver, std::string const &rangeName, RooSimultaneous const *simPdf,
80 bool takeGlobalObservablesFromData)
81 : RooAbsReal{"RooFitDriverWrapper", "RooFitDriverWrapper"}, _driver{std::move(driver)},
82 _topNode("topNode", "top node", this, _driver->topNode()), _rangeName{rangeName}, _simPdf{simPdf},
83 _takeGlobalObservablesFromData{takeGlobalObservablesFromData}
84 {
85 }
86
87 RooAbsRealWrapper(const RooAbsRealWrapper &other, const char *name = nullptr)
88 : RooAbsReal{other, name}, _driver{other._driver},
89 _topNode("topNode", this, other._topNode), _data{other._data}, _parameters{other._parameters},
90 _rangeName{other._rangeName}, _simPdf{other._simPdf}, _takeGlobalObservablesFromData{
91 other._takeGlobalObservablesFromData}
92 {
93 }
94
95 TObject *clone(const char *newname) const override { return new RooAbsRealWrapper(*this, newname); }
96
97 double defaultErrorLevel() const override { return _driver->topNode().defaultErrorLevel(); }
98
99 bool getParameters(const RooArgSet *observables, RooArgSet &outputSet, bool /*stripDisconnected*/) const override
100 {
101 outputSet.add(_parameters);
102 if (observables) {
103 outputSet.remove(*observables);
104 }
105 // If we take the global observables as data, we have to return these as
106 // parameters instead of the parameters in the model. Otherwise, the
107 // constant parameters in the fit result that are global observables will
108 // not have the right values.
109 if (_takeGlobalObservablesFromData && _data->getGlobalObservables()) {
110 outputSet.replace(*_data->getGlobalObservables());
111 }
112 return false;
113 }
114
115 bool setData(RooAbsData &data, bool /*cloneData*/) override
116 {
117 _data = &data;
118
119 // Figure out what are the parameters for the current dataset
120 _parameters.clear();
121 RooArgSet params;
122 _driver->topNode().getParameters(_data->get(), params, true);
123 for (RooAbsArg *param : params) {
124 if (!param->getAttribute("__obs__")) {
125 _parameters.add(*param);
126 }
127 }
128
129 _driver->setData(*_data, _rangeName, _simPdf, /*skipZeroWeights=*/true, _takeGlobalObservablesFromData);
130 return true;
131 }
132
133 double getValV(const RooArgSet *) const override { return evaluate(); }
134
135 void applyWeightSquared(bool flag) override
136 {
137 const_cast<RooAbsReal &>(_driver->topNode()).applyWeightSquared(flag);
138 }
139
140 void printMultiline(std::ostream &os, Int_t /*contents*/, bool /*verbose*/ = false,
141 TString /*indent*/ = "") const override
142 {
143 _driver->print(os);
144 }
145
146protected:
147 double evaluate() const override { return _driver ? _driver->getVal() : 0.0; }
148
149private:
150 std::shared_ptr<RooFitDriver> _driver;
151 RooRealProxy _topNode;
152 RooAbsData *_data = nullptr;
153 RooArgSet _parameters;
154 std::string _rangeName;
155 RooSimultaneous const *_simPdf = nullptr;
156 const bool _takeGlobalObservablesFromData;
157};
158
159} // namespace
160
161std::unique_ptr<RooAbsReal>
162RooFit::BatchModeHelpers::createNLL(std::unique_ptr<RooAbsPdf> &&pdf, RooAbsData &data,
163 std::unique_ptr<RooAbsReal> &&constraints, std::string const &rangeName,
164 RooArgSet const &projDeps, bool isExtended, double integrateOverBinsPrecision,
166 bool takeGlobalObservablesFromData)
167{
168 if (constraints) {
169 // Redirect the global observables to the ones from the dataset if applicable.
170 constraints->setData(data, false);
171
172 // The computation graph for the constraints is very small, no need to do
173 // the tracking of clean and dirty nodes here.
174 constraints->setOperMode(RooAbsArg::ADirty);
175 }
176
177 RooArgSet observables;
178 pdf->getObservables(data.get(), observables);
179 observables.remove(projDeps, true, true);
180
181 oocxcoutI(pdf.get(), Fitting) << "RooAbsPdf::fitTo(" << pdf->GetName()
182 << ") fixing normalization set for coefficient determination to observables in data"
183 << "\n";
184 pdf->fixAddCoefNormalization(observables, false);
185
186 // Deal with the IntegrateBins argument
187 RooArgList binSamplingPdfs;
188 std::unique_ptr<RooAbsPdf> wrappedPdf = RooBinSamplingPdf::create(*pdf, data, integrateOverBinsPrecision);
189 RooAbsPdf &finalPdf = wrappedPdf ? *wrappedPdf : *pdf;
190 if (wrappedPdf) {
191 binSamplingPdfs.addOwned(std::move(wrappedPdf));
192 }
193 // Done dealing with the IntegrateBins option
194
195 RooArgList nllTerms;
196
197 auto simPdf = dynamic_cast<RooSimultaneous *>(&finalPdf);
198 if (simPdf) {
199 simPdf->wrapPdfsInBinSamplingPdfs(data, integrateOverBinsPrecision);
200 nllTerms.addOwned(createSimultaneousNLL(*simPdf, isExtended, rangeName, offset));
201 } else {
202 nllTerms.addOwned(
203 std::make_unique<RooNLLVarNew>("RooNLLVarNew", "RooNLLVarNew", finalPdf, observables, isExtended, offset));
204 }
205 if (constraints) {
206 nllTerms.addOwned(std::move(constraints));
207 }
208
209 std::string nllName = std::string("nll_") + pdf->GetName() + "_" + data.GetName();
210 auto nll = std::make_unique<RooAddition>(nllName.c_str(), nllName.c_str(), nllTerms);
211 nll->addOwnedComponents(std::move(binSamplingPdfs));
212 nll->addOwnedComponents(std::move(nllTerms));
213
214 auto driver = std::make_unique<RooFitDriver>(*nll, batchMode);
215
216 auto driverWrapper =
217 std::make_unique<RooAbsRealWrapper>(std::move(driver), rangeName, simPdf, takeGlobalObservablesFromData);
218 driverWrapper->setData(data, false);
219 driverWrapper->addOwnedComponents(std::move(nll));
220 driverWrapper->addOwnedComponents(std::move(pdf));
221
222 return driverWrapper;
223}
224
226{
227 // We have to exit early if the message stream is not active. Otherwise it's
228 // possible that this function skips logging because it thinks it has
229 // already logged, but actually it didn't.
230 if (!RooMsgService::instance().isActive(static_cast<RooAbsArg *>(nullptr), RooFit::Fitting, RooFit::INFO)) {
231 return;
232 }
233
234 // Don't repeat logging architecture info if the batchMode option didn't change
235 {
236 // Second element of pair tracks whether this function has already been called
237 static std::pair<RooFit::BatchModeOption, bool> lastBatchMode;
238 if (lastBatchMode.second && lastBatchMode.first == batchMode)
239 return;
240 lastBatchMode = {batchMode, true};
241 }
242
243 auto log = [](std::string_view message) {
244 oocxcoutI(static_cast<RooAbsArg *>(nullptr), Fitting) << message << std::endl;
245 };
246
248 throw std::runtime_error(std::string("In: ") + __func__ + "(), " + __FILE__ + ":" + __LINE__ +
249 ": Cuda implementation of the computing library is not available\n");
250 }
252 log("using generic CPU library compiled with no vectorizations");
253 } else {
254 log(std::string("using CPU computation library compiled with -m") +
255 RooBatchCompute::dispatchCPU->architectureName());
256 }
257 if (batchMode == RooFit::BatchModeOption::Cuda) {
258 log("using CUDA computation library");
259 }
260}
ROOT::RRangeCast< T, false, Range_t > static_range_cast(Range_t &&coll)
#define oocxcoutI(o, a)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h offset
char name[80]
Definition TGX11.cxx:110
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition RooAbsArg.h:78
RooAbsCategoryLValue is the common abstract base class for objects that represent a discrete value th...
virtual bool remove(const RooAbsArg &var, bool silent=false, bool matchByNameOnly=false)
Remove the specified argument from our list.
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
Storage_t::size_type size() const
RooAbsArg * first() const
virtual bool replace(const RooAbsArg &var1, const RooAbsArg &var2)
Replace var1 with var2 and return true for success.
virtual bool addOwned(RooAbsArg &var, bool silent=false)
Add an argument and transfer the ownership to the collection.
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:58
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition RooAbsReal.h:61
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:55
static std::unique_ptr< RooAbsPdf > create(RooAbsPdf &pdf, RooAbsData const &data, double precision)
Creates a wrapping RooBinSamplingPdf if appropriate.
RooCategory is an object to represent discrete states.
Definition RooCategory.h:28
static RooMsgService & instance()
Return reference to singleton instance.
RooSimultaneous facilitates simultaneous fitting of multiple PDFs to subsets of a given dataset.
RooAbsPdf * getPdf(RooStringView catName) const
Return the p.d.f associated with the given index category name.
void wrapPdfsInBinSamplingPdfs(RooAbsData const &data, double precision)
Wraps the components of this RooSimultaneous in RooBinSamplingPdfs.
const RooAbsCategoryLValue & indexCat() const
Mother of all ROOT objects.
Definition TObject.h:41
Basic string class.
Definition TString.h:139
R__EXTERN RooBatchComputeInterface * dispatchCUDA
R__EXTERN RooBatchComputeInterface * dispatchCPU
This dispatch pointer points to an implementation of the compute library, provided one has been loade...
std::unique_ptr< RooAbsReal > createNLL(std::unique_ptr< RooAbsPdf > &&pdf, RooAbsData &data, std::unique_ptr< RooAbsReal > &&constraints, std::string const &rangeName, RooArgSet const &projDeps, bool isExtended, double integrateOverBinsPrecision, RooFit::BatchModeOption batchMode, RooFit::OffsetMode offset, bool takeGlobalObservablesFromData)
void logArchitectureInfo(RooFit::BatchModeOption batchMode)
OffsetMode
For setting the offset mode with the Offset() command argument to RooAbsPdf::fitTo()
BatchModeOption
For setting the batch mode flag with the BatchMode() command argument to RooAbsPdf::fitTo()
void evaluate(typename Architecture_t::Tensor_t &A, EActivationFunction f)
Apply the given activation function to each value in the given tensor A.
Definition Functions.h:98