Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
NormalizationHelpers.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14
15#include <TClass.h>
16#include <RooAbsCachedPdf.h>
17#include <RooAbsPdf.h>
18#include <RooAbsReal.h>
19#include <RooAddition.h>
20#include <RooConstraintSum.h>
21#include <RooProdPdf.h>
22
23#include "RooNormalizedPdf.h"
24
25#include <TClass.h>
26
27namespace {
28
30using ServerLists = std::map<DataKey, std::vector<DataKey>>;
31
32class GraphChecker {
33public:
34 GraphChecker(RooAbsArg const &topNode)
35 {
36
37 // To track the RooProdPdfs to figure out which ones are responsible for constraints.
38 std::vector<RooAbsArg *> prodPdfs;
39
40 // Get the list of servers for each node by data key.
41 {
42 RooArgList nodes;
43 topNode.treeNodeServerList(&nodes, nullptr, true, true, false, true);
44 RooArgSet nodesSet{nodes};
45 for (RooAbsArg *node : nodesSet) {
46 if (dynamic_cast<RooProdPdf *>(node)) {
47 prodPdfs.push_back(node);
48 }
49 _serverLists[node];
50 bool isConstraintSum = dynamic_cast<RooConstraintSum const *>(node);
51 for (RooAbsArg *server : node->servers()) {
52 _serverLists[node].push_back(server);
53 if (isConstraintSum)
54 _constraints.insert(server);
55 }
56 }
57 }
58 for (auto &item : _serverLists) {
59 auto &l = item.second;
60 std::sort(l.begin(), l.end());
61 l.erase(std::unique(l.begin(), l.end()), l.end());
62 }
63
64 // Loop over the RooProdPdfs to figure out which ones are responsible for constraints.
65 for (auto *prodPdf : static_range_cast<RooProdPdf *>(prodPdfs)) {
66 std::size_t actualPdfIdx = 0;
67 std::size_t nNonConstraint = 0;
68 for (std::size_t i = 0; i < prodPdf->pdfList().size(); ++i) {
69 RooAbsArg &pdf = prodPdf->pdfList()[i];
70
71 // Heuristic for HistFactory models to find also the constraints
72 // that were not extracted for the RooConstraint sum, e.g. because
73 // they were constant. TODO: fix RooProdPdf such that is also
74 // extracts constraints for which the parameters is set constant.
75 bool isProbablyConstraint = std::string(pdf.GetName()).find("onstrain") != std::string::npos;
76
77 if (_constraints.find(&pdf) == _constraints.end() && !isProbablyConstraint) {
78 actualPdfIdx = i;
79 ++nNonConstraint;
80 }
81 }
82 if (nNonConstraint != prodPdf->pdfList().size()) {
83 if (nNonConstraint != 1) {
84 throw std::runtime_error("A RooProdPdf that multiplies a pdf with constraints should contain only one "
85 "pdf that is not a constraint!");
86 }
87 _prodPdfsWithConstraints[prodPdf] = actualPdfIdx;
88 }
89 }
90 }
91
92 bool dependsOn(DataKey arg, DataKey testArg)
93 {
94
95 std::pair<DataKey, DataKey> p{arg, testArg};
96
97 auto found = _results.find(p);
98 if (found != _results.end())
99 return found->second;
100
101 if (arg == testArg)
102 return true;
103
104 auto const &serverList = _serverLists.at(arg);
105
106 // Next test direct dependence
107 auto foundServer = std::find(serverList.begin(), serverList.end(), testArg);
108 if (foundServer != serverList.end()) {
109 _results.emplace(p, true);
110 return true;
111 }
112
113 // If not, recurse
114 for (auto const &server : serverList) {
115 bool t = dependsOn(server, testArg);
116 _results.emplace(std::pair<DataKey, DataKey>{server, testArg}, t);
117 if (t) {
118 return true;
119 }
120 }
121
122 _results.emplace(p, false);
123 return false;
124 }
125
126 bool isConstraint(DataKey key) const
127 {
128 auto found = _constraints.find(key);
129 return found != _constraints.end();
130 }
131
132 std::unordered_map<RooAbsArg *, std::size_t> const &prodPdfsWithConstraints() const
133 {
134 return _prodPdfsWithConstraints;
135 }
136
137private:
138 std::unordered_set<DataKey> _constraints;
139 std::unordered_map<RooAbsArg *, std::size_t> _prodPdfsWithConstraints;
140 ServerLists _serverLists;
141 std::map<std::pair<DataKey, DataKey>, bool> _results;
142};
143
144void treeNodeServerListAndNormSets(const RooAbsArg &arg, RooAbsCollection &list, RooArgSet const &normSet,
145 std::unordered_map<DataKey, RooArgSet *> &normSets, GraphChecker const &checker)
146{
147 if (normSets.find(&arg) != normSets.end())
148 return;
149
150 list.add(arg, true);
151
152 // normalization sets only need to be added for pdfs
153 if (dynamic_cast<RooAbsPdf const *>(&arg)) {
154 normSets.insert({&arg, new RooArgSet{normSet}});
155 }
156
157 // Recurse if current node is derived
158 if (arg.isDerived() && !arg.isFundamental()) {
159 for (const auto server : arg.servers()) {
160
161 if (!server->isValueServer(arg)) {
162 continue;
163 }
164
165 // If this is a server that is also serving a RooConstraintSum, it
166 // should be skipped because it is not evaluated by this client (e.g.
167 // a RooProdPdf). It was only part of the servers to be extracted for
168 // the constraint sum.
169 if (!dynamic_cast<RooConstraintSum const *>(&arg) && checker.isConstraint(server)) {
170 continue;
171 }
172
173 auto differentSet = arg.fillNormSetForServer(normSet, *server);
174 if (differentSet) {
175 differentSet->sort();
176 }
177
178 auto &serverNormSet = differentSet ? *differentSet : normSet;
179
180 // Make sure that the server is not already part of the computation
181 // graph with a different normalization set.
182 auto found = normSets.find(server);
183 if (found != normSets.end()) {
184 if (found->second->size() != serverNormSet.size() || !serverNormSet.hasSameLayout(*found->second)) {
185 std::stringstream ss;
186 ss << server->IsA()->GetName() << "::" << server->GetName()
187 << " is requested to be evaluated with two different normalization sets in the same model!";
188 ss << " This is not supported yet. The conflicting norm sets are:\n RooArgSet";
189 serverNormSet.printValue(ss);
190 ss << " requested by " << arg.IsA()->GetName() << "::" << arg.GetName() << "\n RooArgSet";
191 found->second->printValue(ss);
192 ss << " first requested by other client";
193 auto errMsg = ss.str();
194 oocoutE(server, Minimization) << errMsg << std::endl;
195 throw std::runtime_error(errMsg);
196 }
197 continue;
198 }
199
200 treeNodeServerListAndNormSets(*server, list, serverNormSet, normSets, checker);
201 }
202 }
203}
204
205std::vector<std::unique_ptr<RooAbsArg>> unfoldIntegrals(RooAbsArg const &topNode, RooArgSet const &normSet,
206 std::unordered_map<DataKey, RooArgSet *> &normSets,
207 RooArgSet &replacedArgs)
208{
209 std::vector<std::unique_ptr<RooAbsArg>> newNodes;
210
211 // No normalization set: we don't need to create any integrals
212 if (normSet.empty())
213 return newNodes;
214
215 GraphChecker checker{topNode};
216
217 RooArgSet nodes;
218 // The norm sets are sorted to compare them for equality more easliy
219 RooArgSet normSetSorted{normSet};
220 normSetSorted.sort();
221 treeNodeServerListAndNormSets(topNode, nodes, normSetSorted, normSets, checker);
222
223 // Clean normsets of the variables that the arg does not depend on
224 for (auto &item : normSets) {
225 if (!item.second || item.second->empty())
226 continue;
227 auto actualNormSet = new RooArgSet{};
228 for (auto *narg : *item.second) {
229 if (checker.dependsOn(item.first, narg))
230 // Add the arg from the actual node list in the computation graph.
231 // Like this, we don't accidentally add internal variable clones
232 // that the client args returned. Looking this up is fast because
233 // of the name pointer hash map optimization.
234 actualNormSet->add(*nodes.find(*narg));
235 }
236 delete item.second;
237 item.second = actualNormSet;
238 }
239
240 // Function to `oldArg` with `newArg` in the computation graph.
241 auto replaceArg = [&](RooAbsArg &newArg, RooAbsArg const &oldArg) {
242 const std::string attrib = std::string("ORIGNAME:") + oldArg.GetName();
243
244 newArg.setAttribute(attrib.c_str());
245 newArg.setStringAttribute("_replaced_arg", oldArg.GetName());
246
247 RooArgList newServerList{newArg};
248
249 RooArgList originalClients;
250 for (auto *client : oldArg.clients()) {
251 originalClients.add(*client);
252 }
253 for (auto *client : originalClients) {
254 if (!nodes.containsInstance(*client))
255 continue;
256 if (dynamic_cast<RooAbsCachedPdf *>(client))
257 continue;
258 client->redirectServers(newServerList, false, true);
259 }
260 replacedArgs.add(oldArg);
261
262 newArg.setAttribute(attrib.c_str(), false);
263 };
264
265 // Replaces the RooProdPdfs that were used to wrap constraints with the actual pdf.
266 for (RooAbsArg *node : nodes) {
267 if (auto prodPdf = dynamic_cast<RooProdPdf *>(node)) {
268 auto found = checker.prodPdfsWithConstraints().find(prodPdf);
269 if (found != checker.prodPdfsWithConstraints().end()) {
270 replaceArg(prodPdf->pdfList()[found->second], *prodPdf);
271 }
272 }
273 }
274
275 // Replace all pdfs that need to be normalized with a pdf wrapper that
276 // applies the right normalization.
277 for (RooAbsArg *node : nodes) {
278 if (auto pdf = dynamic_cast<RooAbsPdf *>(node)) {
279 RooArgSet const &currNormSet = *normSets.at(pdf);
280
281 if (currNormSet.empty())
282 continue;
283
284 // The call to getVal() sets up cached states for this normalization
285 // set, which is important in case this pdf is also used by clients
286 // using the getVal() interface (without this, test 28 in stressRooFit
287 // is failing for example).
288 pdf->getVal(currNormSet);
289
290 if (pdf->selfNormalized() && !dynamic_cast<RooAbsCachedPdf *>(pdf))
291 continue;
292
293 auto normalizedPdf = std::make_unique<RooNormalizedPdf>(*pdf, currNormSet);
294
295 replaceArg(*normalizedPdf, *pdf);
296
297 newNodes.emplace_back(std::move(normalizedPdf));
298 }
299 }
300
301 return newNodes;
302}
303
304void foldIntegrals(RooAbsArg const &topNode, RooArgSet &replacedArgs)
305{
306 RooArgSet nodes;
307 topNode.treeNodeServerList(&nodes);
308
309 for (RooAbsArg *normalizedPdf : nodes) {
310
311 if (auto const &replacedArgName = normalizedPdf->getStringAttribute("_replaced_arg")) {
312
313 auto pdf = &replacedArgs[replacedArgName];
314
315 pdf->setAttribute((std::string("ORIGNAME:") + normalizedPdf->GetName()).c_str());
316
317 RooArgList newServerList{*pdf};
318 for (auto *client : normalizedPdf->clients()) {
319 client->redirectServers(newServerList, false, true);
320 }
321
322 pdf->setAttribute((std::string("ORIGNAME:") + normalizedPdf->GetName()).c_str(), false);
323
324 normalizedPdf->setStringAttribute("_replaced_arg", nullptr);
325 }
326 }
327}
328
329} // namespace
330
331/// \class NormalizationIntegralUnfolder
332/// \ingroup Roofitcore
333///
334/// A NormalizationIntegralUnfolder takes the top node of a computation graph
335/// and a normalization set for its constructor. The normalization integrals
336/// for the PDFs in that graph will be created, and placed into the computation
337/// graph itself, rewiring the existing RooAbsArgs. When the unfolder goes out
338/// of scope, all changes to the computation graph will be reverted.
339///
340/// It also performs some other optimizations of the computation graph that are
341/// reverted when the object goes out of scope:
342///
343/// 1. Replacing RooProdPdfs that were used to bring constraints into the
344/// likelihood with the actual pdf that is not a constraint.
345///
346/// Note that for evaluation, the original topNode should not be used anymore,
347/// because if it is a pdf there is now a new normalized pdf wrapping it,
348/// serving as the new top node. This normalized top node can be retrieved by
349/// NormalizationIntegralUnfolder::arg().
350
352 : _topNodeWrapper{std::make_unique<RooAddition>("_dummy", "_dummy", RooArgList{topNode})}, _normSetWasEmpty{
353 normSet.empty()}
354{
355 auto ownedArgs = unfoldIntegrals(*_topNodeWrapper, normSet, _normSets, _replacedArgs);
356 for (std::unique_ptr<RooAbsArg> &arg : ownedArgs) {
357 _topNodeWrapper->addOwnedComponents(std::move(arg));
358 }
359 _arg = &static_cast<RooAddition &>(*_topNodeWrapper).list()[0];
360}
361
363{
364 // If there was no normalization set to compile the computation graph for,
365 // we also don't need to fold the integrals back in.
366 if (_normSetWasEmpty)
367 return;
368
369 foldIntegrals(*_topNodeWrapper, _replacedArgs);
370
371 for (auto &item : _normSets) {
372 delete item.second;
373 }
374}
ROOT::RRangeCast< T, false, Range_t > static_range_cast(Range_t &&coll)
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
#define oocoutE(o, a)
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition RooAbsArg.h:69
Bool_t redirectServers(const RooAbsCollection &newServerList, Bool_t mustReplaceAll=kFALSE, Bool_t nameChange=kFALSE, Bool_t isRecursionStep=kFALSE)
Replace all direct servers of this object with the new servers in newServerList.
void setStringAttribute(const Text_t *key, const Text_t *value)
Associate string 'value' to this object under key 'key'.
virtual Bool_t isFundamental() const
Is this object a fundamental type that can be added to a dataset? Fundamental-type subclasses overrid...
Definition RooAbsArg.h:239
void treeNodeServerList(RooAbsCollection *list, const RooAbsArg *arg=0, Bool_t doBranch=kTRUE, Bool_t doLeaf=kTRUE, Bool_t valueOnly=kFALSE, Bool_t recurseNonDerived=kFALSE) const
Fill supplied list with nodes of the arg tree, following all server links, starting with ourself as t...
void setAttribute(const Text_t *name, Bool_t value=kTRUE)
Set (default) or clear a named boolean attribute of this object.
virtual Bool_t isDerived() const
Does value or shape of this arg depend on any other arg?
Definition RooAbsArg.h:89
virtual std::unique_ptr< RooArgSet > fillNormSetForServer(RooArgSet const &normSet, RooAbsArg const &server) const
Fills a RooArgSet to be used as the normalization set for a server, given a normalization set for thi...
RooAbsCachedPdf is the abstract base class for p.d.f.s that need or want to cache their evaluate() ou...
RooAbsCollection is an abstract container object that can hold multiple RooAbsArg objects.
void sort(Bool_t reverse=false)
Sort collection using std::sort and name comparison.
virtual Bool_t add(const RooAbsArg &var, Bool_t silent=kFALSE)
Add the specified argument to list.
RooAbsArg * find(const char *name) const
Find object with given name in list.
RooAddition calculates the sum of a set of RooAbsReal terms, or when constructed with two sets,...
Definition RooAddition.h:27
const RooArgList & list() const
Definition RooAddition.h:43
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition RooArgSet.h:35
Bool_t containsInstance(const RooAbsArg &var) const override
Check if this exact instance is in this collection.
Definition RooArgSet.h:149
RooConstraintSum calculates the sum of the -(log) likelihoods of a set of RooAbsPdfs that represent co...
To use as a key type for RooFit data maps and containers.
std::unique_ptr< RooAbsArg > _topNodeWrapper
NormalizationIntegralUnfolder(RooAbsArg const &topNode, RooArgSet const &normSet)
std::unordered_map< RooFit::Detail::DataKey, RooArgSet * > _normSets
RooProdPdf is an efficient implementation of a product of PDFs of the form.
Definition RooProdPdf.h:33
virtual const char * GetName() const
Returns name of object.
Definition TNamed.h:47
auto * l
Definition textangle.C:4