Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BatchModeDataHelpers.cxx
Go to the documentation of this file.
1/*
2 * Project: RooFit
3 * Authors:
4 * Jonas Rembser, CERN 2022
5 *
6 * Copyright (c) 2022, CERN
7 *
8 * Redistribution and use in source and binary forms,
9 * with or without modification, are permitted according to the terms
10 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11 */
12
14
15#include <RooAbsCategory.h>
16#include <RooAbsData.h>
17#include <RooNLLVarNew.h>
18
19#include <ROOT/StringUtils.hxx>
20
21#include <numeric>
22
23namespace {
24
25void splitByCategory(std::map<RooFit::Detail::DataKey, RooSpan<const double>> &dataSpans,
26 RooAbsCategory const &category, std::stack<std::vector<double>> &buffers)
27{
28 std::stack<std::vector<double>> oldBuffers;
29 std::swap(buffers, oldBuffers);
30
31 auto catVals = dataSpans.at(category.namePtr());
32
33 std::map<RooFit::Detail::DataKey, RooSpan<const double>> dataMapSplit;
34
35 for (auto const &dataMapItem : dataSpans) {
36
37 auto const &varNamePtr = dataMapItem.first;
38 auto const &xVals = dataMapItem.second;
39
40 if (varNamePtr == category.namePtr())
41 continue;
42
43 std::map<RooAbsCategory::value_type, std::vector<double>> valuesMap;
44
45 if (xVals.size() == 1) {
46 // If the span is of size one, we will replicate it for each category
47 // component instead of splitting is up by category value.
48 for (auto const &catItem : category) {
49 valuesMap[catItem.second].push_back(xVals[0]);
50 }
51 } else {
52 for (std::size_t i = 0; i < xVals.size(); ++i) {
53 valuesMap[catVals[i]].push_back(xVals[i]);
54 }
55 }
56
57 for (auto const &item : valuesMap) {
58 RooAbsCategory::value_type index = item.first;
59 auto variableName = std::string("_") + category.lookupName(index) + "_" + varNamePtr->GetName();
60 auto variableNamePtr = RooNameReg::instance().constPtr(variableName.c_str());
61
62 buffers.emplace(std::move(item.second));
63 auto const &values = buffers.top();
64 dataMapSplit[variableNamePtr] = RooSpan<const double>(values.data(), values.size());
65 }
66 }
67
68 dataSpans = std::move(dataMapSplit);
69}
70
71} // namespace
72
73////////////////////////////////////////////////////////////////////////////////
74/// Extract all content from a RooFit datasets as a map of spans.
75/// Spans with the weights and squared weights will be also stored in the map,
76/// keyed with the names `_weight` and the `_weight_sumW2`. If the dataset is
77/// unweighted, these weight spans will only contain the single value `1.0`.
78/// Entries with zero weight will be skipped.
79///
80/// \return A `std::map` with spans keyed to name pointers.
81/// \param[in] data The input dataset.
82/// \param[in] rangeName Select only entries from the data in a given range
83/// (empty string for no range).
84/// \param[in] indexCat If not `nullptr`, each span is spit up by this category,
85/// with the new names prefixed by the category component name
86/// surrounded by underscores. For example, if you have a category
87/// with `signal` and `control` samples, the span for a variable `x`
88/// will be split in two spans `_signal_x` and `_control_x`.
89/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
90/// be used as memory for the data if the memory in the dataset
91/// object can't be used directly (e.g. because you used the range
92/// selection or the splitting by categories).
93/// \param[in] skipZeroWeights Skip entries with zero weight when filling the
94/// data spans. Be very careful with enabling it, because the user
95/// might not expect that the batch results are not aligned with the
96/// original dataset anymore!
97std::map<RooFit::Detail::DataKey, RooSpan<const double>>
98RooFit::BatchModeDataHelpers::getDataSpans(RooAbsData const &data, std::string_view rangeName,
99 RooAbsCategory const *indexCat, std::stack<std::vector<double>> &buffers,
100 bool skipZeroWeights)
101{
102 std::map<RooFit::Detail::DataKey, RooSpan<const double>> dataSpans; // output variable
103
104 std::size_t nEvents = static_cast<size_t>(data.numEntries());
105
106 // We also want to support empty datasets: in this case the
107 // RooFitDriver::Dataset is not filled with anything.
108 if (nEvents == 0)
109 return dataSpans;
110
111 if (!buffers.empty()) {
112 throw std::invalid_argument("The buffers container must be empty when passed to getDataSpans()!");
113 }
114
115 auto &nameReg = RooNameReg::instance();
116
117 auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
118 auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);
119
120 std::vector<bool> hasZeroWeight;
121 hasZeroWeight.resize(nEvents);
122 std::size_t nNonZeroWeight = 0;
123
124 // Add weights to the datamap. They should have the names expected by the
125 // RooNLLVarNew. We also add the sumW2 weights here under a different name,
126 // so we can apply the sumW2 correction by easily swapping the spans.
127 {
128 buffers.emplace();
129 auto &buffer = buffers.top();
130 buffers.emplace();
131 auto &bufferSumW2 = buffers.top();
132 if (weight.empty()) {
133 // If the dataset has no weight, we fill the data spans with a scalar
134 // unity weight so we don't need to check for the existance of weights
135 // later in the likelihood.
136 buffer.push_back(1.0);
137 bufferSumW2.push_back(1.0);
138 weight = RooSpan<const double>(buffer.data(), 1);
139 weightSumW2 = RooSpan<const double>(bufferSumW2.data(), 1);
140 nNonZeroWeight = nEvents;
141 } else {
142 buffer.reserve(nEvents);
143 bufferSumW2.reserve(nEvents);
144 for (std::size_t i = 0; i < nEvents; ++i) {
145 if (!skipZeroWeights || weight[i] != 0) {
146 buffer.push_back(weight[i]);
147 bufferSumW2.push_back(weightSumW2[i]);
148 ++nNonZeroWeight;
149 } else {
150 hasZeroWeight[i] = true;
151 }
152 }
153 weight = RooSpan<const double>(buffer.data(), nNonZeroWeight);
154 weightSumW2 = RooSpan<const double>(bufferSumW2.data(), nNonZeroWeight);
155 }
156 using namespace ROOT::Experimental;
157 dataSpans[nameReg.constPtr(RooNLLVarNew::weightVarName)] = weight;
158 dataSpans[nameReg.constPtr(RooNLLVarNew::weightVarNameSumW2)] = weightSumW2;
159 }
160
161 // Get the real-valued batches and cast the also to double branches to put in
162 // the data map
163 for (auto const &item : data.getBatches(0, nEvents)) {
164
165 const TNamed *namePtr = nameReg.constPtr(item.first->GetName());
166 RooSpan<const double> span{item.second};
167
168 buffers.emplace();
169 auto &buffer = buffers.top();
170 buffer.reserve(nNonZeroWeight);
171
172 for (std::size_t i = 0; i < nEvents; ++i) {
173 if (!hasZeroWeight[i]) {
174 buffer.push_back(span[i]);
175 }
176 }
177 dataSpans[namePtr] = RooSpan<const double>(buffer.data(), buffer.size());
178 }
179
180 // Get the category batches and cast the also to double branches to put in
181 // the data map
182 for (auto const &item : data.getCategoryBatches(0, nEvents)) {
183
184 const TNamed *namePtr = nameReg.constPtr(item.first->GetName());
186
187 buffers.emplace();
188 auto &buffer = buffers.top();
189 buffer.reserve(nNonZeroWeight);
190
191 for (std::size_t i = 0; i < nEvents; ++i) {
192 if (!hasZeroWeight[i]) {
193 buffer.push_back(static_cast<double>(intSpan[i]));
194 }
195 }
196 dataSpans[namePtr] = RooSpan<const double>(buffer.data(), buffer.size());
197 }
198
199 nEvents = nNonZeroWeight;
200
201 // Now we have do do the range selection
202 if (!rangeName.empty()) {
203 // figure out which events are in the range
204 std::vector<bool> isInRange(nEvents, false);
205 for (auto const &range : ROOT::Split(rangeName, ",")) {
206 std::vector<bool> isInSubRange(nEvents, true);
207 for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
208 // If the observables is not real-valued, it will not be considered for the range selection
209 if (!observable)
210 continue;
211 observable->inRange({dataSpans.at(observable->namePtr()).data(), nEvents}, range, isInSubRange);
212 }
213 for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
214 isInRange[i] = isInRange[i] | isInSubRange[i];
215 }
216 }
217
218 // reset the number of events
219 nEvents = std::accumulate(isInRange.begin(), isInRange.end(), 0);
220
221 // do the data reduction in the data map
222 for (auto const &item : dataSpans) {
223 auto const &allValues = item.second;
224 if (allValues.size() == 1) {
225 continue;
226 }
227 buffers.emplace(nEvents);
228 double *buffer = buffers.top().data();
229 std::size_t j = 0;
230 for (std::size_t i = 0; i < isInRange.size(); ++i) {
231 if (isInRange[i]) {
232 buffer[j] = allValues[i];
233 ++j;
234 }
235 }
236 dataSpans[item.first] = RooSpan<const double>{buffer, nEvents};
237 }
238 }
239
240 if (indexCat) {
241 splitByCategory(dataSpans, *indexCat, buffers);
242 }
243
244 return dataSpans;
245}
const TNamed * namePtr() const
De-duplicated pointer to this object's name.
Definition RooAbsArg.h:577
RooAbsCategory is the base class for objects that represent a discrete value with a finite number of ...
const std::string & lookupName(value_type index) const
Get the name corresponding to the given index.
RooAbsData is the common abstract base class for binned and unbinned datasets.
Definition RooAbsData.h:82
virtual const RooArgSet * get() const
Definition RooAbsData.h:128
CategorySpans getCategoryBatches(std::size_t first=0, std::size_t len=std::numeric_limits< std::size_t >::max()) const
RealSpans getBatches(std::size_t first=0, std::size_t len=std::numeric_limits< std::size_t >::max()) const
Write information to retrieve data columns into evalData.spans.
virtual RooSpan< const double > getWeightBatch(std::size_t first, std::size_t len, bool sumW2=false) const =0
Return event weights of all events in range [first, first+len).
virtual Int_t numEntries() const
Return number of entries in dataset, i.e., count unweighted entries.
const TNamed * constPtr(const char *stringPtr)
Return a unique TNamed pointer for given C++ string.
static RooNameReg & instance()
Return reference to singleton instance.
A simple container to hold a batch of data values.
Definition RooSpan.h:34
The TNamed class is the base class for all named ROOT classes.
Definition TNamed.h:29
std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
Splits a string at each character in delims.
std::map< RooFit::Detail::DataKey, RooSpan< const double > > getDataSpans(RooAbsData const &data, std::string_view rangeName, RooAbsCategory const *indexCat, std::stack< std::vector< double > > &buffers, bool skipZeroWeights)
Extract all content from a RooFit dataset as a map of spans.