Logo ROOT   6.14/05
Reference Guide
RArrowDS.cxx
Go to the documentation of this file.
1 // Author: Giulio Eulisse CERN 2/2018
2 
3 /*************************************************************************
4  * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. *
5  * All rights reserved. *
6  * *
7  * For the licensing terms see $ROOTSYS/LICENSE. *
8  * For the list of contributors see $ROOTSYS/README/CREDITS. *
9  *************************************************************************/
10 
11 // clang-format off
12 /** \class ROOT::RDF::RArrowDS
13  \ingroup dataframe
14  \brief RDataFrame data source class to interface with Apache Arrow.
15 
16 The RArrowDS implements a proxy RDataSource to be able to use Apache Arrow
17 tables with RDataFrame.
18 
An RDataFrame that adapts an arrow::Table can be constructed using the factory method
ROOT::RDF::MakeArrowDataFrame, which accepts two parameters:
1. An arrow::Table smart pointer.
2. The names of the columns to expose; if this list is empty, all the columns of the table are used.
22 
23 The types of the columns are derived from the types in the associated
24 arrow::Schema.
25 
26 */
27 // clang-format on
28 
29 #include <ROOT/RDFUtils.hxx>
30 #include <ROOT/TSeq.hxx>
31 #include <ROOT/RArrowDS.hxx>
32 #include <ROOT/RMakeUnique.hxx>
33 
34 #include <algorithm>
35 #include <sstream>
36 #include <string>
37 
38 #if defined(__GNUC__)
39 #pragma GCC diagnostic push
40 #pragma GCC diagnostic ignored "-Wshadow"
41 #endif
42 #include <arrow/table.h>
43 #if defined(__GNUC__)
44 #pragma GCC diagnostic pop
45 #endif
46 
47 
48 namespace ROOT {
49 namespace Internal {
50 namespace RDF {
51 // Per slot visitor of an Array.
52 class ArrayPtrVisitor : public ::arrow::ArrayVisitor {
53 private:
54  /// The pointer to update.
55  void **fResult;
56  bool fCachedBool{false}; // Booleans need to be unpacked, so we use a cached entry.
57  std::string fCachedString;
58  /// The entry in the array which should be looked up.
59  ULong64_t fCurrentEntry;
60 
61 public:
62  ArrayPtrVisitor(void **result) : fResult{result}, fCurrentEntry{0} {}
63 
64  void SetEntry(ULong64_t entry) { fCurrentEntry = entry; }
65 
66  /// Check if we are asking the same entry as before.
67  virtual arrow::Status Visit(arrow::Int32Array const &array) final
68  {
69  *fResult = (void *)(array.raw_values() + fCurrentEntry);
70  return arrow::Status::OK();
71  }
72 
73  virtual arrow::Status Visit(arrow::Int64Array const &array) final
74  {
75  *fResult = (void *)(array.raw_values() + fCurrentEntry);
76  return arrow::Status::OK();
77  }
78 
79  /// Check if we are asking the same entry as before.
80  virtual arrow::Status Visit(arrow::UInt32Array const &array) final
81  {
82  *fResult = (void *)(array.raw_values() + fCurrentEntry);
83  return arrow::Status::OK();
84  }
85 
86  virtual arrow::Status Visit(arrow::UInt64Array const &array) final
87  {
88  *fResult = (void *)(array.raw_values() + fCurrentEntry);
89  return arrow::Status::OK();
90  }
91 
92  virtual arrow::Status Visit(arrow::FloatArray const &array) final
93  {
94  *fResult = (void *)(array.raw_values() + fCurrentEntry);
95  return arrow::Status::OK();
96  }
97 
98  virtual arrow::Status Visit(arrow::DoubleArray const &array) final
99  {
100  *fResult = (void *)(array.raw_values() + fCurrentEntry);
101  return arrow::Status::OK();
102  }
103 
104  virtual arrow::Status Visit(arrow::BooleanArray const &array) final
105  {
106  fCachedBool = array.Value(fCurrentEntry);
107  *fResult = reinterpret_cast<void *>(&fCachedBool);
108  return arrow::Status::OK();
109  }
110 
111  virtual arrow::Status Visit(arrow::StringArray const &array) final
112  {
113  fCachedString = array.GetString(fCurrentEntry);
114  *fResult = reinterpret_cast<void *>(&fCachedString);
115  return arrow::Status::OK();
116  }
117 
118  using ::arrow::ArrayVisitor::Visit;
119 };
120 
121 /// Helper class which keeps track for each slot where to get the entry.
122 class TValueGetter {
123 private:
124  std::vector<void *> fValuesPtrPerSlot;
125  std::vector<ULong64_t> fLastEntryPerSlot;
126  std::vector<ULong64_t> fLastChunkPerSlot;
127  std::vector<ULong64_t> fFirstEntryPerChunk;
128  std::vector<ArrayPtrVisitor> fArrayVisitorPerSlot;
129  /// Since data can be chunked in different arrays we need to construct an
130  /// index which contains the first element of each chunk, so that we can
131  /// quickly move to the correct chunk.
132  std::vector<ULong64_t> fChunkIndex;
133  arrow::ArrayVector fChunks;
134 
135 public:
136  TValueGetter(size_t slots, arrow::ArrayVector chunks)
137  : fValuesPtrPerSlot(slots, nullptr), fLastEntryPerSlot(slots, 0), fLastChunkPerSlot(slots, 0), fChunks{chunks}
138  {
139  fChunkIndex.reserve(fChunks.size());
140  size_t next = 0;
141  for (auto &chunk : chunks) {
142  fFirstEntryPerChunk.push_back(next);
143  next += chunk->length();
144  fChunkIndex.push_back(next);
145  }
146  for (size_t si = 0, se = fValuesPtrPerSlot.size(); si != se; ++si) {
147  fArrayVisitorPerSlot.push_back(ArrayPtrVisitor{fValuesPtrPerSlot.data() + si});
148  }
149  }
150 
151  /// This returns the ptr to the ptr to actual data.
152  std::vector<void *> SlotPtrs()
153  {
154  std::vector<void *> result;
155  for (size_t i = 0; i < fValuesPtrPerSlot.size(); ++i) {
156  result.push_back(fValuesPtrPerSlot.data() + i);
157  }
158  return result;
159  }
160 
161  // Convenience method to avoid code duplication between
162  // SetEntry and InitSlot
163  void UncachedSlotLookup(unsigned int slot, ULong64_t entry)
164  {
165  // If entry is greater than the previous one,
166  // we can skip all the chunks before the last one we
167  // queried.
168  size_t ci = 0;
169  assert(slot < fLastChunkPerSlot.size());
170  if (fLastEntryPerSlot[slot] < entry) {
171  ci = fLastChunkPerSlot.at(slot);
172  }
173 
174  for (size_t ce = fChunkIndex.size(); ci != ce; ++ci) {
175  if (entry < fChunkIndex[ci]) {
176  assert(slot < fLastChunkPerSlot.size());
177  fLastChunkPerSlot[slot] = ci;
178  break;
179  }
180  }
181 
182  // Update the pointer to the requested entry.
183  // Notice that we need to find the entry
184  auto chunk = fChunks.at(fLastChunkPerSlot[slot]);
185  assert(slot < fArrayVisitorPerSlot.size());
186  fArrayVisitorPerSlot[slot].SetEntry(entry - fFirstEntryPerChunk[fLastChunkPerSlot[slot]]);
187  auto status = chunk->Accept(fArrayVisitorPerSlot.data() + slot);
188  if (!status.ok()) {
189  std::string msg = "Could not get pointer for slot ";
190  msg += std::to_string(slot) + " looking at entry " + std::to_string(entry);
191  throw std::runtime_error(msg);
192  }
193  }
194 
195  /// Set the current entry to be retrieved
196  void SetEntry(unsigned int slot, ULong64_t entry)
197  {
198  // Same entry as before
199  if (fLastEntryPerSlot[slot] == entry) {
200  return;
201  }
202  UncachedSlotLookup(slot, entry);
203  }
204 };
205 
206 } // namespace RDF
207 } // namespace Internal
208 
209 
210 namespace RDF {
211 
212 /// Helper to get the contents of a given column
213 
214 /// Helper to get the human readable name of type
215 class RDFTypeNameGetter : public ::arrow::TypeVisitor {
216 private:
217  std::string fTypeName;
218 
219 public:
220  arrow::Status Visit(const arrow::Int64Type &) override
221  {
222  fTypeName = "Long64_t";
223  return arrow::Status::OK();
224  }
225  arrow::Status Visit(const arrow::Int32Type &) override
226  {
227  fTypeName = "Long_t";
228  return arrow::Status::OK();
229  }
230  arrow::Status Visit(const arrow::UInt64Type &) override
231  {
232  fTypeName = "ULong64_t";
233  return arrow::Status::OK();
234  }
235  arrow::Status Visit(const arrow::UInt32Type &) override
236  {
237  fTypeName = "ULong_t";
238  return arrow::Status::OK();
239  }
240  arrow::Status Visit(const arrow::FloatType &) override
241  {
242  fTypeName = "float";
243  return arrow::Status::OK();
244  }
245  arrow::Status Visit(const arrow::DoubleType &) override
246  {
247  fTypeName = "double";
248  return arrow::Status::OK();
249  }
250  arrow::Status Visit(const arrow::StringType &) override
251  {
252  fTypeName = "string";
253  return arrow::Status::OK();
254  }
255  arrow::Status Visit(const arrow::BooleanType &) override
256  {
257  fTypeName = "bool";
258  return arrow::Status::OK();
259  }
260  std::string result() { return fTypeName; }
261 
262  using ::arrow::TypeVisitor::Visit;
263 };
264 
265 /// Helper to determine if a given Column is a supported type.
266 class VerifyValidColumnType : public ::arrow::TypeVisitor {
267 private:
268 public:
269  virtual arrow::Status Visit(const arrow::Int64Type &) override { return arrow::Status::OK(); }
270  virtual arrow::Status Visit(const arrow::UInt64Type &) override { return arrow::Status::OK(); }
271  virtual arrow::Status Visit(const arrow::Int32Type &) override { return arrow::Status::OK(); }
272  virtual arrow::Status Visit(const arrow::UInt32Type &) override { return arrow::Status::OK(); }
273  virtual arrow::Status Visit(const arrow::FloatType &) override { return arrow::Status::OK(); }
274  virtual arrow::Status Visit(const arrow::DoubleType &) override { return arrow::Status::OK(); }
275  virtual arrow::Status Visit(const arrow::StringType &) override { return arrow::Status::OK(); }
276  virtual arrow::Status Visit(const arrow::BooleanType &) override { return arrow::Status::OK(); }
277 
278  using ::arrow::TypeVisitor::Visit;
279 };
280 
281 
282 
283 
284 ////////////////////////////////////////////////////////////////////////
285 /// Constructor to create an Arrow RDataSource for RDataFrame.
286 /// \param[in] table the arrow Table to observe.
287 /// \param[in] columns the name of the columns to use
288 /// In case columns is empty, we use all the columns found in the table
289 RArrowDS::RArrowDS(std::shared_ptr<arrow::Table> inTable, std::vector<std::string> const &inColumns)
290  : fTable{inTable}, fColumnNames{inColumns}
291 {
292  auto &columnNames = fColumnNames;
293  auto &table = fTable;
294  auto &index = fGetterIndex;
295  // We want to allow people to specify which columns they
296  // need so that we can think of upfront IO optimizations.
297  auto filterWantedColumns = [&columnNames, &table]()
298  {
299  if (columnNames.empty()) {
300  for (auto &field : table->schema()->fields()) {
301  columnNames.push_back(field->name());
302  }
303  }
304  };
305 
306  auto getRecordsFirstColumn = [&columnNames, &table]()
307  {
308  if (columnNames.empty()) {
309  throw std::runtime_error("At least one column required");
310  }
311  const auto name = columnNames.front();
312  const auto columnIdx = table->schema()->GetFieldIndex(name);
313  return table->column(columnIdx)->length();
314  };
315 
316  // All columns are supposed to have the same number of entries.
317  auto verifyColumnSize = [](std::shared_ptr<arrow::Column> column, int nRecords)
318  {
319  if (column->length() != nRecords) {
320  std::string msg = "Column ";
321  msg += column->name() + " has a different number of entries.";
322  throw std::runtime_error(msg);
323  }
324  };
325 
326  /// For the moment we support only a few native types.
327  auto verifyColumnType = [](std::shared_ptr<arrow::Column> column) {
328  auto verifyType = std::make_unique<VerifyValidColumnType>();
329  auto result = column->type()->Accept(verifyType.get());
330  if (result.ok() == false) {
331  std::string msg = "Column ";
332  msg += column->name() + " contains an unsupported type.";
333  throw std::runtime_error(msg);
334  }
335  };
336 
337  /// This is used to create an index between the columnId
338  /// and the associated getter.
339  auto addColumnToGetterIndex = [&index](int columnId)
340  {
341  index.push_back(std::make_pair(columnId, index.size()));
342  };
343 
344  /// Assuming we can get called more than once, we need to
345  /// reset the getter index each time.
346  auto resetGetterIndex = [&index]() { index.clear(); };
347 
348  /// This is what initialization actually does
349  filterWantedColumns();
350  resetGetterIndex();
351  auto nRecords = getRecordsFirstColumn();
352  for (auto &columnName : fColumnNames) {
353  auto columnIdx = fTable->schema()->GetFieldIndex(columnName);
354  addColumnToGetterIndex(columnIdx);
355 
356  auto column = fTable->column(columnIdx);
357  verifyColumnSize(column, nRecords);
358  verifyColumnType(column);
359  }
360 }
361 
362 ////////////////////////////////////////////////////////////////////////
363 /// Destructor.
365 {
366 }
367 
368 const std::vector<std::string> &RArrowDS::GetColumnNames() const
369 {
370  return fColumnNames;
371 }
372 
373 std::vector<std::pair<ULong64_t, ULong64_t>> RArrowDS::GetEntryRanges()
374 {
375  auto entryRanges(std::move(fEntryRanges)); // empty fEntryRanges
376  return entryRanges;
377 }
378 
379 std::string RArrowDS::GetTypeName(std::string_view colName) const
380 {
381  auto field = fTable->schema()->GetFieldByName(std::string(colName));
382  if (!field) {
383  std::string msg = "The dataset does not have column ";
384  msg += colName;
385  throw std::runtime_error(msg);
386  }
387  RDFTypeNameGetter typeGetter;
388  auto status = field->type()->Accept(&typeGetter);
389  if (status.ok() == false) {
390  std::string msg = "RArrowDS does not support a column of type ";
391  msg += field->type()->name();
392  throw std::runtime_error(msg);
393  }
394  return typeGetter.result();
395 }
396 
398 {
399  auto field = fTable->schema()->GetFieldByName(std::string(colName));
400  if (!field) {
401  return false;
402  }
403  return true;
404 }
405 
406 bool RArrowDS::SetEntry(unsigned int slot, ULong64_t entry)
407 {
408  for (auto link : fGetterIndex) {
409  auto column = fTable->column(link.first);
410  auto &getter = fValueGetters[link.second];
411  getter->SetEntry(slot, entry);
412  }
413  return true;
414 }
415 
416 void RArrowDS::InitSlot(unsigned int slot, ULong64_t entry)
417 {
418  for (auto link : fGetterIndex) {
419  auto column = fTable->column(link.first);
420  auto &getter = fValueGetters[link.second];
421  getter->UncachedSlotLookup(slot, entry);
422  }
423 }
424 
425 void RArrowDS::SetNSlots(unsigned int nSlots)
426 {
427  assert(0U == fNSlots && "Setting the number of slots even if the number of slots is different from zero.");
428 
429  // We dump all the previous getters structures and we rebuild it.
430  auto nColumns = fGetterIndex.size();
431  auto &outNSlots = fNSlots;
432  auto &ranges = fEntryRanges;
433  auto &table = fTable;
434  auto &columnNames = fColumnNames;
435 
436  fValueGetters.clear();
437  for (size_t ci = 0; ci != nColumns; ++ci) {
438  auto chunkedArray = fTable->column(fGetterIndex[ci].first)->data();
439  fValueGetters.emplace_back(std::make_unique<ROOT::Internal::RDF::TValueGetter>(nSlots, chunkedArray->chunks()));
440  }
441 
442  // We use the same logic as the ROOTDS.
443  auto splitInEqualRanges = [&outNSlots, &ranges](int nRecords, unsigned int newNSlots)
444  {
445  ranges.clear();
446  outNSlots = newNSlots;
447  const auto chunkSize = nRecords / outNSlots;
448  const auto remainder = 1U == outNSlots ? 0 : nRecords % outNSlots;
449  auto start = 0UL;
450  auto end = 0UL;
451  for (auto i : ROOT::TSeqU(outNSlots)) {
452  start = end;
453  end += chunkSize;
454  ranges.emplace_back(start, end);
455  (void)i;
456  }
457  ranges.back().second += remainder;
458  };
459 
460  auto getNRecords = [&table, &columnNames]()->int
461  {
462  auto index = table->schema()->GetFieldIndex(columnNames.front());
463  return table->column(index)->length();
464  };
465 
466  auto nRecords = getNRecords();
467  splitInEqualRanges(nRecords, nSlots);
468 }
469 
470 /// This needs to return a pointer to the pointer each value getter
471 /// will point to.
472 std::vector<void *> RArrowDS::GetColumnReadersImpl(std::string_view colName, const std::type_info &)
473 {
474  auto &index = fGetterIndex;
475  auto findGetterIndex = [&index](unsigned int column)
476  {
477  for (auto &entry : index) {
478  if (entry.first == column) {
479  return entry.second;
480  }
481  }
482  throw std::runtime_error("No column found at index " + std::to_string(column));
483  };
484 
485  const int columnIdx = fTable->schema()->GetFieldIndex(std::string(colName));
486  const int getterIdx = findGetterIndex(columnIdx);
487  assert(getterIdx != -1);
488  assert((unsigned int)getterIdx < fValueGetters.size());
489  return fValueGetters[getterIdx]->SlotPtrs();
490 }
491 
493 {
494 }
495 
496 /// Creates a RDataFrame using an arrow::Table as input.
497 /// \param[in] table the arrow Table to observe.
498 /// \param[in] columnNames the name of the columns to use
499 /// In case columnNames is empty, we use all the columns found in the table
500 RDataFrame MakeArrowDataFrame(std::shared_ptr<arrow::Table> table, std::vector<std::string> const &columnNames)
501 {
502  ROOT::RDataFrame tdf(std::make_unique<RArrowDS>(table, columnNames));
503  return tdf;
504 }
505 
506 } // namespace RDF
507 
508 } // namespace ROOT
Namespace for new ROOT classes and functions.
Definition: StringConv.hxx:21
bool HasColumn(std::string_view colName) const override
Checks if the dataset has a certain column.
Definition: RArrowDS.cxx:397
array (ordered collection of values)
RDataFrame MakeArrowDataFrame(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Factory method to create an Apache Arrow RDataFrame.
Definition: RArrowDS.cxx:500
RArrowDS(std::shared_ptr< arrow::Table > table, std::vector< std::string > const &columns)
Constructor to create an Arrow RDataSource for RDataFrame.
Definition: RArrowDS.cxx:289
std::string GetTypeName(std::string_view colName) const override
Type of a column as a string, e.g.
Definition: RArrowDS.cxx:379
bool SetEntry(unsigned int slot, ULong64_t entry) override
Advance the "cursors" returned by GetColumnReaders to the selected entry for a particular slot...
Definition: RArrowDS.cxx:406
std::vector< void * > GetColumnReadersImpl(std::string_view name, const std::type_info &type) override
This needs to return a pointer to the pointer each value getter will point to.
Definition: RArrowDS.cxx:472
std::vector< std::pair< ULong64_t, ULong64_t > > fEntryRanges
Definition: RArrowDS.hxx:26
const std::vector< std::string > & GetColumnNames() const override
Returns a reference to the collection of the dataset's column names.
Definition: RArrowDS.cxx:368
std::vector< std::unique_ptr< ROOT::Internal::RDF::TValueGetter > > fValueGetters
Definition: RArrowDS.hxx:31
void InitSlot(unsigned int slot, ULong64_t firstEntry) override
Convenience method called at the start of the data processing associated to a slot.
Definition: RArrowDS.cxx:416
void SetNSlots(unsigned int nSlots) override
Inform RDataSource of the number of processing slots (i.e.
Definition: RArrowDS.cxx:425
ROOT's RDataFrame offers a high level interface for analyses of data stored in TTrees, CSV files and other data formats.
Definition: RDataFrame.hxx:42
~RArrowDS()
Destructor.
Definition: RArrowDS.cxx:364
A pseudo container class which is a generator of indices.
Definition: TSeq.hxx:66
unsigned long long ULong64_t
Definition: RtypesCore.h:70
basic_string_view< char > string_view
Definition: RStringView.hxx:35
std::shared_ptr< arrow::Table > fTable
Definition: RArrowDS.hxx:25
typedef void((*Func_t)())
void Initialise() override
Convenience method called before starting an event-loop.
Definition: RArrowDS.cxx:492
Definition: first.py:1
char name[80]
Definition: TGX11.cxx:109
std::vector< std::string > fColumnNames
Definition: RArrowDS.hxx:27
std::vector< std::pair< ULong64_t, ULong64_t > > GetEntryRanges() override
Return ranges of entries to distribute to tasks.
Definition: RArrowDS.cxx:373
std::vector< std::pair< size_t, size_t > > fGetterIndex
Definition: RArrowDS.hxx:30