// ROOT Reference Guide -- source listing of BatchModeDataHelpers.cxx
// (Go to the documentation of this file.)
/*
 * Project: RooFit
 * Authors:
 *   Jonas Rembser, CERN 2022
 *
 * Copyright (c) 2022, CERN
 *
 * Redistribution and use in source and binary forms,
 * with or without modification, are permitted according to the terms
 * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
 */

#include <RooFit/BatchModeDataHelpers.h>

#include <RooAbsData.h>
#include <RooDataHist.h>
#include "RooNLLVarNew.h"

#include <ROOT/StringUtils.hxx>

#include <numeric>

22////////////////////////////////////////////////////////////////////////////////
23/// Extract all content from a RooFit datasets as a map of spans.
24/// Spans with the weights and squared weights will be also stored in the map,
25/// keyed with the names `_weight` and the `_weight_sumW2`. If the dataset is
26/// unweighted, these weight spans will only contain the single value `1.0`.
27/// Entries with zero weight will be skipped. If the input dataset is a
28/// RooDataHist, the output map will also contain an item for the key
29/// `_bin_volume` with the bin volumes.
30///
31/// \return A `std::map` with spans keyed to name pointers.
32/// \param[in] data The input dataset.
33/// \param[in] rangeName Select only entries from the data in a given range
34/// (empty string for no range).
35/// \param[in] prefix A string prefix to use for all key names for the data
36/// map.
37/// \param[in] buffers Pass here an empty stack of `double` vectors, which will
38/// be used as memory for the data if the memory in the dataset
39/// object can't be used directly (e.g. because you used the range
40/// selection or the splitting by categories).
41/// \param[in] skipZeroWeights Skip entries with zero weight when filling the
42/// data spans. Be very careful with enabling it, because the user
43/// might not expect that the batch results are not aligned with the
44/// original dataset anymore!
45std::map<RooFit::Detail::DataKey, RooSpan<const double>>
47 std::string const &prefix, std::stack<std::vector<double>> &buffers,
48 bool skipZeroWeights)
49{
50 std::map<RooFit::Detail::DataKey, RooSpan<const double>> dataSpans; // output variable
51
52 auto &nameReg = RooNameReg::instance();
53
54 auto insert = [&](const char *key, RooSpan<const double> span) {
55 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
56 dataSpans[namePtr] = span;
57 };
58
59 auto retrieve = [&](const char *key) {
60 const TNamed *namePtr = nameReg.constPtr((prefix + key).c_str());
61 return dataSpans.at(namePtr);
62 };
63
64 std::size_t nEvents = static_cast<size_t>(data.numEntries());
65
66 // We also want to support empty datasets: in this case the
67 // RooFitDriver::Dataset is not filled with anything.
68 if (nEvents == 0) {
69 return dataSpans;
70 }
71
72 auto weight = data.getWeightBatch(0, nEvents, /*sumW2=*/false);
73 auto weightSumW2 = data.getWeightBatch(0, nEvents, /*sumW2=*/true);
74
75 std::vector<bool> hasZeroWeight;
76 hasZeroWeight.resize(nEvents);
77 std::size_t nNonZeroWeight = 0;
78
79 // Add weights to the datamap. They should have the names expected by the
80 // RooNLLVarNew. We also add the sumW2 weights here under a different name,
81 // so we can apply the sumW2 correction by easily swapping the spans.
82 {
83 buffers.emplace();
84 auto &buffer = buffers.top();
85 buffers.emplace();
86 auto &bufferSumW2 = buffers.top();
87 if (weight.empty()) {
88 // If the dataset has no weight, we fill the data spans with a scalar
89 // unity weight so we don't need to check for the existance of weights
90 // later in the likelihood.
91 buffer.push_back(1.0);
92 bufferSumW2.push_back(1.0);
93 weight = RooSpan<const double>(buffer.data(), 1);
94 weightSumW2 = RooSpan<const double>(bufferSumW2.data(), 1);
95 nNonZeroWeight = nEvents;
96 } else {
97 buffer.reserve(nEvents);
98 bufferSumW2.reserve(nEvents);
99 for (std::size_t i = 0; i < nEvents; ++i) {
100 if (!skipZeroWeights || weight[i] != 0) {
101 buffer.push_back(weight[i]);
102 bufferSumW2.push_back(weightSumW2[i]);
103 ++nNonZeroWeight;
104 } else {
105 hasZeroWeight[i] = true;
106 }
107 }
108 weight = RooSpan<const double>(buffer.data(), nNonZeroWeight);
109 weightSumW2 = RooSpan<const double>(bufferSumW2.data(), nNonZeroWeight);
110 }
111 using namespace ROOT::Experimental;
112 insert(RooNLLVarNew::weightVarName, weight);
113 insert(RooNLLVarNew::weightVarNameSumW2, weightSumW2);
114 }
115
116 // Add also bin volume information if we are dealing with a RooDataHist
117 if (auto dataHist = dynamic_cast<RooDataHist const *>(&data)) {
118 buffers.emplace();
119 auto &buffer = buffers.top();
120 buffer.reserve(nNonZeroWeight);
121
122 for (std::size_t i = 0; i < nEvents; ++i) {
123 if (!hasZeroWeight[i]) {
124 buffer.push_back(dataHist->binVolume(i));
125 }
126 }
127
128 insert("_bin_volume", {buffer.data(), buffer.size()});
129 }
130
131 // Get the real-valued batches and cast the also to double branches to put in
132 // the data map
133 for (auto const &item : data.getBatches(0, nEvents)) {
134
135 RooSpan<const double> span{item.second};
136
137 buffers.emplace();
138 auto &buffer = buffers.top();
139 buffer.reserve(nNonZeroWeight);
140
141 for (std::size_t i = 0; i < nEvents; ++i) {
142 if (!hasZeroWeight[i]) {
143 buffer.push_back(span[i]);
144 }
145 }
146 insert(item.first->GetName(), {buffer.data(), buffer.size()});
147 }
148
149 // Get the category batches and cast the also to double branches to put in
150 // the data map
151 for (auto const &item : data.getCategoryBatches(0, nEvents)) {
152
154
155 buffers.emplace();
156 auto &buffer = buffers.top();
157 buffer.reserve(nNonZeroWeight);
158
159 for (std::size_t i = 0; i < nEvents; ++i) {
160 if (!hasZeroWeight[i]) {
161 buffer.push_back(static_cast<double>(intSpan[i]));
162 }
163 }
164 insert(item.first->GetName(), {buffer.data(), buffer.size()});
165 }
166
167 nEvents = nNonZeroWeight;
168
169 // Now we have do do the range selection
170 if (!rangeName.empty()) {
171 // figure out which events are in the range
172 std::vector<bool> isInRange(nEvents, false);
173 for (auto const &range : ROOT::Split(rangeName, ",")) {
174 std::vector<bool> isInSubRange(nEvents, true);
175 for (auto *observable : dynamic_range_cast<RooAbsRealLValue *>(*data.get())) {
176 // If the observables is not real-valued, it will not be considered for the range selection
177 if (observable) {
178 observable->inRange({retrieve(observable->GetName()).data(), nEvents}, range, isInSubRange);
179 }
180 }
181 for (std::size_t i = 0; i < isInSubRange.size(); ++i) {
182 isInRange[i] = isInRange[i] || isInSubRange[i];
183 }
184 }
185
186 // reset the number of events
187 nEvents = std::accumulate(isInRange.begin(), isInRange.end(), 0);
188
189 // do the data reduction in the data map
190 for (auto const &item : dataSpans) {
191 auto const &allValues = item.second;
192 if (allValues.size() == 1) {
193 continue;
194 }
195 buffers.emplace(nEvents);
196 double *buffer = buffers.top().data();
197 std::size_t j = 0;
198 for (std::size_t i = 0; i < isInRange.size(); ++i) {
199 if (isInRange[i]) {
200 buffer[j] = allValues[i];
201 ++j;
202 }
203 }
204 dataSpans[item.first] = RooSpan<const double>{buffer, nEvents};
205 }
206 }
207
208 return dataSpans;
209}
/*
 * Doxygen cross-reference tooltips from the generated listing:
 *
 * retrieve -- static void retrieve(const gsl_integration_workspace *workspace, double *a, double *b, double *r, double *e)
 * data -- (garbled tooltip) Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor
 *         GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
 * RooAbsData -- RooAbsData is the common abstract base class for binned and unbinned datasets.
 *         Definition: RooAbsData.h:61
 * RooDataHist -- The RooDataHist is a container class to hold N-dimensional binned data.
 *         Definition: RooDataHist.h:39
 * RooNameReg::instance -- static RooNameReg & instance()
 *         Return reference to singleton instance.
 *         Definition: RooNameReg.cxx:50
 * RooSpan -- A simple container to hold a batch of data values.
 *         Definition: RooSpan.h:34
 * TNamed -- The TNamed class is the base class for all named ROOT classes.
 *         Definition: TNamed.h:29
 * string_view -- basic_string_view< char > string_view
 * ROOT::Split -- std::vector< std::string > Split(std::string_view str, std::string_view delims, bool skipEmpty=false)
 *         Splits a string at each character in delims.
 *         Definition: StringUtils.cxx:23
 * getDataSpans -- std::map< RooFit::Detail::DataKey, RooSpan< const double > > getDataSpans(RooAbsData const &data,
 *         std::string_view rangeName, std::string const &prefix, std::stack< std::vector< double > > &buffers,
 *         bool skipZeroWeights)
 *         Extract all content from a RooFit dataset as a map of spans.
 */