Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ReadSpeed.cxx
Go to the documentation of this file.
1// Author: Enrico Guiraud, David Poulton 2022
2
3/*************************************************************************
4 * Copyright (C) 1995-2022, Rene Brun and Fons Rademakers. *
5 * All rights reserved. *
6 * *
7 * For the licensing terms see $ROOTSYS/LICENSE. *
8 * For the list of contributors see $ROOTSYS/README/CREDITS. *
9 *************************************************************************/
10
11#include "ReadSpeed.hxx"
12
13#include <ROOT/TSeq.hxx>
14
15#ifdef R__USE_IMT
17#include <ROOT/TTreeProcessorMT.hxx> // for TTreeProcessorMT::GetTasksPerWorkerHint
18#include <ROOT/RSlotStack.hxx>
19#endif
20
21#include <ROOT/InternalTreeUtils.hxx> // for ROOT::Internal::TreeUtils::GetTopLevelBranchNames
22#include <TBranch.h>
23#include <TStopwatch.h>
24#include <TTree.h>
25
26#include <algorithm>
27#include <cassert>
28#include <cmath> // std::ceil
29#include <memory>
30#include <numeric> // std::accumulate
31#include <stdexcept>
32#include <set>
33#include <iostream>
34
35using namespace ReadSpeed;
36
37std::vector<std::string> ReadSpeed::GetMatchingBranchNames(const std::string &fileName, const std::string &treeName,
38 const std::vector<ReadSpeedRegex> &regexes)
39{
40 const auto f = std::unique_ptr<TFile>(TFile::Open(fileName.c_str(), "READ_WITHOUT_GLOBALREGISTRATION"));
41 if (f == nullptr || f->IsZombie())
42 throw std::runtime_error("Could not open file '" + fileName + '\'');
43 std::unique_ptr<TTree> t(f->Get<TTree>(treeName.c_str()));
44 if (t == nullptr)
45 throw std::runtime_error("Could not retrieve tree '" + treeName + "' from file '" + fileName + '\'');
46
47 const auto unfilteredBranchNames = ROOT::Internal::TreeUtils::GetTopLevelBranchNames(*t);
48 std::set<ReadSpeedRegex> usedRegexes;
49 std::vector<std::string> branchNames;
50
51 auto filterBranchName = [regexes, &usedRegexes](const std::string &bName) {
52 if (regexes.size() == 1 && regexes[0].text == ".*") {
53 usedRegexes.insert(regexes[0]);
54 return true;
55 }
56
57 const auto matchBranch = [&usedRegexes, bName](const ReadSpeedRegex &regex) {
58 bool match = std::regex_match(bName, regex.regex);
59
60 if (match)
61 usedRegexes.insert(regex);
62
63 return match;
64 };
65
66 const auto iterator = std::find_if(regexes.begin(), regexes.end(), matchBranch);
67 return iterator != regexes.end();
68 };
69 std::copy_if(unfilteredBranchNames.begin(), unfilteredBranchNames.end(), std::back_inserter(branchNames),
70 filterBranchName);
71
72 if (branchNames.empty()) {
73 std::cerr << "Provided branch regexes didn't match any branches in tree '" + treeName + "' from file '" +
74 fileName + ".\n";
75 std::terminate();
76 }
77 if (usedRegexes.size() != regexes.size()) {
78 std::string errString = "The following regexes didn't match any branches in tree '" + treeName + "' from file '" +
79 fileName + "', this is probably unintended:\n";
80 for (const auto &regex : regexes) {
81 if (usedRegexes.find(regex) == usedRegexes.end())
82 errString += '\t' + regex.text + '\n';
83 }
84 std::cerr << errString;
85 std::terminate();
86 }
87
88 return branchNames;
89}
90
91std::vector<std::vector<std::string>> GetPerFileBranchNames(const Data &d)
92{
93 auto treeIdx = 0;
94 std::vector<std::vector<std::string>> fileBranchNames;
95
96 std::vector<ReadSpeedRegex> regexes;
97 if (d.fUseRegex)
98 std::transform(d.fBranchNames.begin(), d.fBranchNames.end(), std::back_inserter(regexes), [](std::string text) {
99 return ReadSpeedRegex{text, std::regex(text)};
100 });
101
102 for (const auto &fName : d.fFileNames) {
103 std::vector<std::string> branchNames;
104 if (d.fUseRegex)
105 branchNames = GetMatchingBranchNames(fName, d.fTreeNames[treeIdx], regexes);
106 else
107 branchNames = d.fBranchNames;
108
109 fileBranchNames.push_back(branchNames);
110
111 if (d.fTreeNames.size() > 1)
112 ++treeIdx;
113 }
114
115 return fileBranchNames;
116}
117
118ByteData SumBytes(const std::vector<ByteData> &bytesData) {
119 const auto uncompressedBytes =
120 std::accumulate(bytesData.begin(), bytesData.end(), 0ull,
121 [](ULong64_t sum, const ByteData &o) { return sum + o.fUncompressedBytesRead; });
122 const auto compressedBytes =
123 std::accumulate(bytesData.begin(), bytesData.end(), 0ull,
124 [](ULong64_t sum, const ByteData &o) { return sum + o.fCompressedBytesRead; });
125
126 return {uncompressedBytes, compressedBytes};
127};
128
129// Read branches listed in branchNames in tree treeName in file fileName, return number of uncompressed bytes read.
130ByteData ReadSpeed::ReadTree(TFile *f, const std::string &treeName, const std::vector<std::string> &branchNames,
131 EntryRange range)
132{
133 std::unique_ptr<TTree> t(f->Get<TTree>(treeName.c_str()));
134 if (t == nullptr)
135 throw std::runtime_error("Could not retrieve tree '" + treeName + "' from file '" + f->GetName() + '\'');
136
137 t->SetBranchStatus("*", 0);
138
139 std::vector<TBranch *> branches;
140 for (const auto &bName : branchNames) {
141 auto *b = t->GetBranch(bName.c_str());
142 if (b == nullptr)
143 throw std::runtime_error("Could not retrieve branch '" + bName + "' from tree '" + t->GetName() +
144 "' in file '" + t->GetCurrentFile()->GetName() + '\'');
145
146 b->SetStatus(1);
147 branches.push_back(b);
148 }
149
150 const auto nEntries = t->GetEntries();
151 if (range.fStart == -1ll)
152 range = EntryRange{0ll, nEntries};
153 else if (range.fEnd > nEntries)
154 throw std::runtime_error("Range end (" + std::to_string(range.fEnd) + ") is beyond the end of tree '" +
155 t->GetName() + "' in file '" + t->GetCurrentFile()->GetName() + "' with " +
156 std::to_string(nEntries) + " entries.");
157
158 ULong64_t bytesRead = 0;
159 const ULong64_t fileStartBytes = f->GetBytesRead();
160 for (auto e = range.fStart; e < range.fEnd; ++e)
161 for (auto *b : branches)
162 bytesRead += b->GetEntry(e);
163
164 const ULong64_t fileBytesRead = f->GetBytesRead() - fileStartBytes;
165 return {bytesRead, fileBytesRead};
166}
167
169{
170 auto treeIdx = 0;
171 auto fileIdx = 0;
172 ULong64_t uncompressedBytesRead = 0;
173 ULong64_t compressedBytesRead = 0;
174
175 TStopwatch sw;
176 const auto fileBranchNames = GetPerFileBranchNames(d);
177
178 for (const auto &fileName : d.fFileNames) {
179 auto f = std::unique_ptr<TFile>(TFile::Open(fileName.c_str(), "READ_WITHOUT_GLOBALREGISTRATION"));
180 if (f == nullptr || f->IsZombie())
181 throw std::runtime_error("Could not open file '" + fileName + '\'');
182
183 sw.Start(false);
184
185 const auto byteData = ReadTree(f.get(), d.fTreeNames[treeIdx], fileBranchNames[fileIdx]);
186 uncompressedBytesRead += byteData.fUncompressedBytesRead;
187 compressedBytesRead += byteData.fCompressedBytesRead;
188
189 if (d.fTreeNames.size() > 1)
190 ++treeIdx;
191 ++fileIdx;
192
193 sw.Stop();
194 }
195
196 return {sw.RealTime(), sw.CpuTime(), 0., 0., uncompressedBytesRead, compressedBytesRead, 0};
197}
198
199// Return a vector of EntryRanges per file, i.e. a vector of vectors of EntryRanges with outer size equal to
200// d.fFileNames.
201std::vector<std::vector<EntryRange>> ReadSpeed::GetClusters(const Data &d)
202{
203 const auto nFiles = d.fFileNames.size();
204 std::vector<std::vector<EntryRange>> ranges(nFiles);
205 for (auto fileIdx = 0u; fileIdx < nFiles; ++fileIdx) {
206 const auto &fileName = d.fFileNames[fileIdx];
207 std::unique_ptr<TFile> f(TFile::Open(fileName.c_str(), "READ_WITHOUT_GLOBALREGISTRATION"));
208 if (f == nullptr || f->IsZombie())
209 throw std::runtime_error("There was a problem opening file '" + fileName + '\'');
210 const auto &treeName = d.fTreeNames.size() > 1 ? d.fTreeNames[fileIdx] : d.fTreeNames[0];
211 auto *t = f->Get<TTree>(treeName.c_str()); // TFile owns this TTree
212 if (t == nullptr)
213 throw std::runtime_error("There was a problem retrieving TTree '" + treeName + "' from file '" + fileName +
214 '\'');
215
216 const auto nEntries = t->GetEntries();
217 auto it = t->GetClusterIterator(0);
218 Long64_t start = 0;
219 std::vector<EntryRange> rangesInFile;
220 while ((start = it.Next()) < nEntries)
221 rangesInFile.emplace_back(EntryRange{start, it.GetNextEntry()});
222 ranges[fileIdx] = std::move(rangesInFile);
223 }
224 return ranges;
225}
226
227// Mimic the logic of TTreeProcessorMT::MakeClusters: merge entry ranges together such that we
228// run around TTreeProcessorMT::GetTasksPerWorkerHint tasks per worker thread.
229// TODO it would be better to expose TTreeProcessorMT's actual logic and call the exact same method from here
230std::vector<std::vector<EntryRange>>
231ReadSpeed::MergeClusters(std::vector<std::vector<EntryRange>> &&clusters, unsigned int maxTasksPerFile)
232{
233 std::vector<std::vector<EntryRange>> mergedClusters(clusters.size());
234
235 auto clustersIt = clusters.begin();
236 auto mergedClustersIt = mergedClusters.begin();
237 for (; clustersIt != clusters.end(); clustersIt++, mergedClustersIt++) {
238 const auto nClustersInThisFile = clustersIt->size();
239 const auto nFolds = nClustersInThisFile / maxTasksPerFile;
240 // If the number of clusters is less than maxTasksPerFile
241 // we take the clusters as they are
242 if (nFolds == 0) {
243 *mergedClustersIt = *clustersIt;
244 continue;
245 }
246 // Otherwise, we have to merge clusters, distributing the reminder evenly
247 // between the first clusters
248 auto nReminderClusters = nClustersInThisFile % maxTasksPerFile;
249 const auto &clustersInThisFile = *clustersIt;
250 for (auto i = 0ULL; i < nClustersInThisFile; ++i) {
251 const auto start = clustersInThisFile[i].fStart;
252 // We lump together at least nFolds clusters, therefore
253 // we need to jump ahead of nFolds-1.
254 i += (nFolds - 1);
255 // We now add a cluster if we have some reminder left
256 if (nReminderClusters > 0) {
257 i += 1U;
258 nReminderClusters--;
259 }
260 const auto end = clustersInThisFile[i].fEnd;
261 mergedClustersIt->emplace_back(EntryRange({start, end}));
262 }
263 assert(nReminderClusters == 0 && "This should never happen, cluster-merging logic is broken.");
264 }
265
266 return mergedClusters;
267}
268
269Result ReadSpeed::EvalThroughputMT(const Data &d, unsigned nThreads)
270{
271#ifdef R__USE_IMT
272 ROOT::TThreadExecutor pool(nThreads);
273 const auto actualThreads = ROOT::GetThreadPoolSize();
274 if (actualThreads != nThreads)
275 std::cerr << "Running with " << actualThreads << " threads even though " << nThreads << " were requested.\n";
276
277 TStopwatch clsw;
278 clsw.Start();
279 const unsigned int maxTasksPerFile =
280 std::ceil(float(ROOT::TTreeProcessorMT::GetTasksPerWorkerHint() * actualThreads) / float(d.fFileNames.size()));
281
282 const auto rangesPerFile = MergeClusters(GetClusters(d), maxTasksPerFile);
283 clsw.Stop();
284
285 const size_t nranges =
286 std::accumulate(rangesPerFile.begin(), rangesPerFile.end(), 0u, [](size_t s, auto &r) { return s + r.size(); });
287 std::cout << "Total number of tasks: " << nranges << '\n';
288
289 const auto fileBranchNames = GetPerFileBranchNames(d);
290
291 ROOT::Internal::RSlotStack slotStack(actualThreads);
292 std::vector<int> lastFileIdxs(actualThreads, -1);
293 std::vector<std::unique_ptr<TFile>> lastTFiles(actualThreads);
294
295 auto processFile = [&](int fileIdx) {
296 const auto &fileName = d.fFileNames[fileIdx];
297 const auto &treeName = d.fTreeNames.size() > 1 ? d.fTreeNames[fileIdx] : d.fTreeNames[0];
298 const auto &branchNames = fileBranchNames[fileIdx];
299
300 auto readRange = [&](const EntryRange &range) -> ByteData {
301 ROOT::Internal::RSlotStackRAII slotRAII(slotStack);
302 auto slotIndex = slotRAII.fSlot;
303 auto &file = lastTFiles[slotIndex];
304 auto &lastIndex = lastFileIdxs[slotIndex];
305
306 if (lastIndex != fileIdx) {
307 file.reset(TFile::Open(fileName.c_str(), "READ_WITHOUT_GLOBALREGISTRATION"));
308 lastIndex = fileIdx;
309 }
310
311 if (file == nullptr || file->IsZombie())
312 throw std::runtime_error("Could not open file '" + fileName + '\'');
313
314 auto result = ReadTree(file.get(), treeName, branchNames, range);
315
316 return result;
317 };
318
319 const auto byteData = pool.MapReduce(readRange, rangesPerFile[fileIdx], SumBytes);
320
321 return byteData;
322 };
323
324 TStopwatch sw;
325 sw.Start();
326 const auto totalByteData = pool.MapReduce(processFile, ROOT::TSeqUL(d.fFileNames.size()), SumBytes);
327 sw.Stop();
328
329 return {sw.RealTime(),
330 sw.CpuTime(),
331 clsw.RealTime(),
332 clsw.CpuTime(),
333 totalByteData.fUncompressedBytesRead,
334 totalByteData.fCompressedBytesRead,
335 actualThreads};
336#else
337 (void)d;
338 (void)nThreads;
339 return {};
340#endif // R__USE_IMT
341}
342
343Result ReadSpeed::EvalThroughput(const Data &d, unsigned nThreads)
344{
345 if (d.fTreeNames.empty()) {
346 std::cerr << "Please provide at least one tree name\n";
347 std::terminate();
348 }
349 if (d.fFileNames.empty()) {
350 std::cerr << "Please provide at least one file name\n";
351 std::terminate();
352 }
353 if (d.fBranchNames.empty()) {
354 std::cerr << "Please provide at least one branch name\n";
355 std::terminate();
356 }
357 if (d.fTreeNames.size() != 1 && d.fTreeNames.size() != d.fFileNames.size()) {
358 std::cerr << "Please provide either one tree name or as many as the file names\n";
359 std::terminate();
360 }
361
362#ifdef R__USE_IMT
363 return nThreads > 0 ? EvalThroughputMT(d, nThreads) : EvalThroughputST(d);
364#else
365 if (nThreads > 0) {
366 std::cerr << nThreads
367 << " threads were requested, but ROOT was built without implicit multi-threading (IMT) support.\n";
368 std::terminate();
369 }
370 return EvalThroughputST(d);
371#endif
372}
#define d(i)
Definition RSha256.hxx:102
#define b(i)
Definition RSha256.hxx:100
#define f(i)
Definition RSha256.hxx:104
#define e(i)
Definition RSha256.hxx:103
std::vector< std::vector< std::string > > GetPerFileBranchNames(const Data &d)
Definition ReadSpeed.cxx:91
ByteData SumBytes(const std::vector< ByteData > &bytesData)
long long Long64_t
Definition RtypesCore.h:80
unsigned long long ULong64_t
Definition RtypesCore.h:81
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char text
A thread-safe stack of N indexes (0 to size - 1).
A pseudo container class which is a generator of indices.
Definition TSeq.hxx:67
This class provides a simple interface to execute the same task multiple times in parallel threads,...
auto MapReduce(F func, unsigned nTimes, R redfunc) -> InvokeResult_t< F >
Execute a function nTimes in parallel (Map) and accumulate the results into a single value (Reduce).
static unsigned int GetTasksPerWorkerHint()
Retrieve the current value for the desired number of tasks per worker.
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4082
Stopwatch class.
Definition TStopwatch.h:28
Double_t RealTime()
Stop the stopwatch (if it is running) and return the realtime (in seconds) passed between the start a...
void Start(Bool_t reset=kTRUE)
Start the stopwatch.
Double_t CpuTime()
Stop the stopwatch (if it is running) and return the cputime (in seconds) passed between the start an...
void Stop()
Stop the stopwatch.
A TTree represents a columnar dataset.
Definition TTree.h:79
std::vector< std::string > GetTopLevelBranchNames(TTree &t)
Get all the top-level branches names, including the ones of the friend trees.
UInt_t GetThreadPoolSize()
Returns the size of ROOT's thread pool.
Definition TROOT.cxx:575
Result EvalThroughputST(const Data &d)
std::vector< std::string > GetMatchingBranchNames(const std::string &fileName, const std::string &treeName, const std::vector< ReadSpeedRegex > &regexes)
Definition ReadSpeed.cxx:37
std::vector< std::vector< EntryRange > > GetClusters(const Data &d)
Result EvalThroughputMT(const Data &d, unsigned nThreads)
Result EvalThroughput(const Data &d, unsigned nThreads)
std::vector< std::vector< EntryRange > > MergeClusters(std::vector< std::vector< EntryRange > > &&clusters, unsigned int maxTasksPerFile)
ByteData ReadTree(TFile *file, const std::string &treeName, const std::vector< std::string > &branchNames, EntryRange range={-1, -1})
A RAII object to pop and push slot numbers from a RSlotStack object.
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2345