Logo ROOT  
Reference Guide
RooNLLVarNew.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2021
5 * Emmanouil Michalainas, CERN 2021
6 *
7 * Copyright (c) 2021, CERN
8 *
9 * Redistribution and use in source and binary forms,
10 * with or without modification, are permitted according to the terms
11 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
12 */
13
/**
\file RooNLLVarNew.cxx
\class RooNLLVarNew
\ingroup Roofitcore

This is a simple class designed to produce the nll values needed by the fitter.
In contrast to the `RooNLLVar` class, any logic except the bare minimum has been
transferred away to other classes, like the `RooFitDriver`. This class also calls
functions from the `RooBatchCompute` library to provide faster computation times.
**/
24
25#include <RooNLLVarNew.h>
26
27#include <RooAddition.h>
28#include <RooFormulaVar.h>
29#include <RooNaNPacker.h>
31
32#include <ROOT/StringUtils.hxx>
33
34#include <Math/Util.h>
35
36#include <numeric>
37#include <stdexcept>
38#include <vector>
39
40using namespace ROOT::Experimental;
41
42// Declare constexpr static members to make them available if odr-used in C++14.
43constexpr const char *RooNLLVarNew::weightVarName;
44constexpr const char *RooNLLVarNew::weightVarNameSumW2;
45
46namespace {
47
48std::unique_ptr<RooAbsReal> createRangeNormTerm(RooAbsPdf const &pdf, RooArgSet const &observables,
49 std::string const &baseName, std::string const &rangeNames)
50{
51
52 RooArgSet observablesInPdf;
53 pdf.getObservables(&observables, observablesInPdf);
54
55 std::unique_ptr<RooAbsReal> integral{pdf.createIntegral(observablesInPdf,
56 &observablesInPdf,
57 pdf.getIntegratorConfig(), rangeNames.c_str())};
58 auto out = std::make_unique<RooFormulaVar>((baseName + "_correctionTerm").c_str(), "(log(x[0]))", RooArgList(*integral));
59 out->addOwnedComponents(std::move(integral));
60 return out;
61}
62
63template <class Input>
64double kahanSum(Input const &input)
65{
66 return ROOT::Math::KahanSum<double, 4u>::Accumulate(input.begin(), input.end()).Sum();
67}
68
69} // namespace
70
71/** Construct a RooNLLVarNew
72\param name the name
73\param title the title
74\param pdf The pdf for which the nll is computed for
75\param observables The observabes of the pdf
76\param isExtended Set to true if this is an extended fit
77\param rangeName the range name
78**/
79RooNLLVarNew::RooNLLVarNew(const char *name, const char *title, RooAbsPdf &pdf, RooArgSet const &observables,
80 bool isExtended, std::string const &rangeName)
81 : RooAbsReal(name, title), _pdf{"pdf", "pdf", this, pdf}, _observables{observables}, _isExtended{isExtended}
82{
83 if (!rangeName.empty()) {
84 auto term = createRangeNormTerm(pdf, observables, pdf.GetName(), rangeName);
85 _rangeNormTerm = std::make_unique<RooTemplateProxy<RooAbsReal>>("_rangeNormTerm", "_rangeNormTerm", this, *term);
86 this->addOwnedComponents(std::move(term));
87 }
88}
89
91 : RooAbsReal(other, name), _pdf{"pdf", this, other._pdf}, _observables{other._observables}
92{
93 if (other._rangeNormTerm)
94 _rangeNormTerm = std::make_unique<RooTemplateProxy<RooAbsReal>>("_rangeNormTerm", this, *other._rangeNormTerm);
95}
96
97/** Compute multiple negative logs of propabilities
98
99\param output An array of doubles where the computation results will be stored
100\param nOut not used
101\note nEvents is the number of events to be processed (the dataMap size)
102\param dataMap A map containing spans with the input data for the computation
103**/
104void RooNLLVarNew::computeBatch(cudaStream_t * /*stream*/, double *output, size_t /*nOut*/,
105 RooFit::Detail::DataMap const& dataMap) const
106{
107 std::size_t nEvents = dataMap.at(_pdf).size();
108 auto probas = dataMap.at(_pdf);
109
110 auto logProbasBuffer = ROOT::Experimental::Detail::makeCpuBuffer(nEvents);
111 RooSpan<double> logProbas{logProbasBuffer->cpuWritePtr(), nEvents};
112 (*_pdf).getLogProbabilities(probas, logProbas.data());
113
114 auto &nameReg = RooNameReg::instance();
115 auto weights = dataMap.at(nameReg.constPtr((_prefix + weightVarName).c_str()));
116 auto weightsSumW2 = dataMap.at(nameReg.constPtr((_prefix + weightVarNameSumW2).c_str()));
117 auto weightSpan = _weightSquared ? weightsSumW2 : weights;
118
119 if ((_isExtended || _rangeNormTerm) && _sumWeight == 0.0) {
120 _sumWeight = weights.size() == 1 ? weights[0] * nEvents : kahanSum(weights);
121 }
123 _sumWeight2 = weights.size() == 1 ? weightsSumW2[0] * nEvents : kahanSum(weightsSumW2);
124 }
125 if (_rangeNormTerm) {
126 auto rangeNormTermSpan = dataMap.at(*_rangeNormTerm);
127 if (rangeNormTermSpan.size() == 1) {
128 _sumCorrectionTerm = (_weightSquared ? _sumWeight2 : _sumWeight) * rangeNormTermSpan[0];
129 } else {
130 if (weightSpan.size() == 1) {
131 _sumCorrectionTerm = weightSpan[0] * kahanSum(rangeNormTermSpan);
132 } else {
133 // We don't need to use the library for now because the weights and
134 // correction term integrals are always in the CPU map.
135 _sumCorrectionTerm = 0.0;
136 for (std::size_t i = 0; i < nEvents; ++i) {
137 _sumCorrectionTerm += weightSpan[i] * rangeNormTermSpan[i];
138 }
139 }
140 }
141 }
142
143 std::vector<double> nlls(nEvents);
144 nlls.reserve(nEvents);
145 double nll = 0.0;
146
147 if (weightSpan.size() > 1) {
148 for (std::size_t i = 0; i < nEvents; ++i) {
149 // Explicitely add zero if zero weight to get rid of eventual NaNs in
150 // logProbas that have no weight anyway.
151 nlls.push_back(weightSpan[i] == 0.0 ? 0.0 : -logProbas[i] * weightSpan[i]);
152 }
153 nll = kahanSum(nlls);
154 } else {
155 for (auto const &p : logProbas) {
156 nlls.push_back(-p);
157 }
158 nll = weightSpan[0] * kahanSum(nlls);
159 }
160
161 if (std::isnan(nll)) {
162 // Special handling of evaluation errors.
163 // We can recover if the bin/event that results in NaN has a weight of zero:
164 RooNaNPacker nanPacker;
165 for (std::size_t i = 0; i < probas.size(); ++i) {
166 if (weightSpan.size() > 1) {
167 if (std::isnan(logProbas[i]) && weightSpan[i] != 0.0) {
168 nanPacker.accumulate(logProbas[i]);
169 }
170 }
171 if (std::isnan(logProbas[i])) {
172 nanPacker.accumulate(logProbas[i]);
173 }
174 }
175
176 // Some events with evaluation errors. Return "badness" of errors.
177 if (nanPacker.getPayload() > 0.) {
178 nll = nanPacker.getNaNWithPayload();
179 }
180 }
181
182 if (_isExtended) {
183 assert(_sumWeight != 0.0);
185 }
186 if (_rangeNormTerm) {
187 nll += _sumCorrectionTerm;
188 }
189 output[0] = nll;
190
191 // Since the output of this node is always of size one, it is possible that it is
192 // evaluated in scalar mode. We need to set the cached value and clear
193 // the dirty flag.
194 const_cast<RooNLLVarNew *>(this)->setCachedValue(nll);
195 const_cast<RooNLLVarNew *>(this)->clearValueDirty();
196}
197
199{
200 return _value;
201}
202
203void RooNLLVarNew::getParametersHook(const RooArgSet * /*nset*/, RooArgSet *params, bool /*stripDisconnected*/) const
204{
205 // strip away the observables and weights
206 params->remove(_observables, true, true);
207}
208
209////////////////////////////////////////////////////////////////////////////////
210/// Replaces all observables and the weight variable of this NLL with clones
211/// that only differ by a prefix added to the names. Used for simultaneous fits.
212/// \return A RooArgSet with the new observable args.
213/// \param[in] prefix The prefix to add to the observables and weight names.
215{
216 _prefix = prefix;
217
218 RooArgSet obsSet{_observables};
219 RooArgSet obsClones;
220 obsSet.snapshot(obsClones);
221 for (RooAbsArg *arg : obsClones) {
222 arg->setAttribute((std::string("ORIGNAME:") + arg->GetName()).c_str());
223 arg->SetName((prefix + arg->GetName()).c_str());
224 }
225 recursiveRedirectServers(obsClones, false, true);
226
227 RooArgSet newObservables{obsClones};
228
229 setObservables(obsClones);
230 addOwnedComponents(std::move(obsClones));
231
232 return newObservables;
233}
234
235////////////////////////////////////////////////////////////////////////////////
236/// Toggles the weight square correction.
238{
239 _weightSquared = flag;
240}
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
char name[80]
Definition: TGX11.cxx:110
std::unique_ptr< RooTemplateProxy< RooAbsReal > > _rangeNormTerm
Definition: RooNLLVarNew.h:67
void setObservables(RooArgSet const &observables)
Definition: RooNLLVarNew.h:53
void computeBatch(cudaStream_t *, double *output, size_t nOut, RooFit::Detail::DataMap const &) const override
Compute multiple negative logs of propabilities.
void applyWeightSquared(bool flag) override
Toggles the weight square correction.
void getParametersHook(const RooArgSet *nset, RooArgSet *list, bool stripDisconnected) const override
RooTemplateProxy< RooAbsPdf > _pdf
Definition: RooNLLVarNew.h:59
static constexpr const char * weightVarName
Definition: RooNLLVarNew.h:30
RooArgSet prefixObservableAndWeightNames(std::string const &prefix)
Replaces all observables and the weight variable of this NLL with clones that only differ by a prefix...
static constexpr const char * weightVarNameSumW2
Definition: RooNLLVarNew.h:31
double evaluate() const override
Evaluate this PDF / function / constant. Needs to be overridden by all derived classes.
static KahanSum< T, N > Accumulate(Iterator begin, Iterator end, T initialValue=T{})
Iterate over a range and return an instance of a KahanSum.
Definition: Util.h:211
RooAbsArg is the common abstract base class for objects that represent a value and a "shape" in RooFi...
Definition: RooAbsArg.h:77
bool recursiveRedirectServers(const RooAbsCollection &newServerList, bool mustReplaceAll=false, bool nameChange=false, bool recurseInNewSet=true)
Recursively replace all servers with the new servers in newSet.
Definition: RooAbsArg.cxx:1197
RooArgSet * getObservables(const RooArgSet &set, bool valueOnly=true) const
Given a set of possible observables, return the observables that this PDF depends on.
Definition: RooAbsArg.h:317
bool addOwnedComponents(const RooAbsCollection &comps)
Take ownership of the contents of 'comps'.
Definition: RooAbsArg.cxx:2275
void clearValueDirty() const
Definition: RooAbsArg.h:622
virtual bool remove(const RooAbsArg &var, bool silent=false, bool matchByNameOnly=false)
Remove the specified argument from our list.
double extendedTerm(double sumEntries, const RooArgSet *nset, double sumEntriesW2=0.0) const
Return the extended likelihood term ( ) of this PDF for the given number of observed events.
Definition: RooAbsPdf.cxx:802
RooAbsReal is the common abstract base class for objects that represent a real value and implements f...
Definition: RooAbsReal.h:64
RooAbsReal * createIntegral(const RooArgSet &iset, const RooCmdArg &arg1, const RooCmdArg &arg2=RooCmdArg::none(), const RooCmdArg &arg3=RooCmdArg::none(), const RooCmdArg &arg4=RooCmdArg::none(), const RooCmdArg &arg5=RooCmdArg::none(), const RooCmdArg &arg6=RooCmdArg::none(), const RooCmdArg &arg7=RooCmdArg::none(), const RooCmdArg &arg8=RooCmdArg::none()) const
Create an object that represents the integral of the function over one or more observables listed in ...
Definition: RooAbsReal.cxx:553
double _value
Cache for current value of object.
Definition: RooAbsReal.h:484
const RooNumIntConfig * getIntegratorConfig() const
Return the numeric integration configuration used for this object.
void setCachedValue(double value, bool notifyClients=true) final
Overwrite the value stored in this object's cache.
Definition: RooAbsReal.h:608
RooArgList is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgList.h:22
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Definition: RooArgSet.h:57
RooArgSet * snapshot(bool deepCopy=true) const
Use RooAbsCollection::snapshot(), but return as RooArgSet.
Definition: RooArgSet.h:180
static RooNameReg & instance()
Return reference to singleton instance.
Definition: RooNameReg.cxx:50
A simple container to hold a batch of data values.
Definition: RooSpan.h:34
const char * GetName() const override
Returns name of object.
Definition: TNamed.h:47
std::unique_ptr< AbsBuffer > makeCpuBuffer(std::size_t size)
Definition: Buffers.cxx:220
std::map< DataKey, RooSpan< const double > > DataMap
Definition: DataMap.h:59
void probas(TString dataset, TString fin="TMVA.root", Bool_t useTMVAStyle=kTRUE)
Little struct that can pack a float into the unused bits of the mantissa of a NaN double.
Definition: RooNaNPacker.h:28
float getPayload() const
Retrieve packed float.
Definition: RooNaNPacker.h:85
double getNaNWithPayload() const
Retrieve a NaN with the current float payload packed into the mantissa.
Definition: RooNaNPacker.h:90
void accumulate(double val)
Accumulate a packed float from another NaN into this.
Definition: RooNaNPacker.h:57
static void output()