Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026
6// Author: Silia Taider, CERN 02/2026
7
8/*************************************************************************
9 * Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. *
10 * All rights reserved. *
11 * *
12 * For the licensing terms see $ROOTSYS/LICENSE. *
13 * For the list of contributors see $ROOTSYS/README/CREDITS. *
14 *************************************************************************/
15
16#ifndef ROOT_INTERNAL_ML_RBATCHGENERATOR
17#define ROOT_INTERNAL_ML_RBATCHGENERATOR
18
21#include "ROOT/ML/RSampler.hxx"
23
27#include "TROOT.h"
28
29#include <cmath>
#include <condition_variable>
#include <cstddef>
30#include <memory>
31#include <mutex>
32#include <random>
#include <string>
33#include <thread>
34#include <variant>
35#include <vector>
36
37// Empty namespace to create a hook for the Pythonization
39}
40
42/**
43 \class ROOT::Experimental::Internal::ML::RBatchGenerator
44\brief Generates training and validation batches from an RDataFrame-based dataset.
45
46In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chunks (see
47RBatchLoader) are combined, allowing batches from the training and validation sets to be loaded directly from a dataset
48in an RDataFrame.
49*/
50
51template <typename... Args>
// NOTE(review): the `class RBatchGenerator {` declaration line (Doxygen 52) is
// not visible in this extract; the members below are the private state of that
// class.
53private:
 // Column names and per-column vector sizes read from the input RDataFrame(s).
54 std::vector<std::string> fCols;
55 std::vector<std::size_t> fVecSizes;
 // Sizing/configuration values captured from the constructor arguments.
56 std::size_t fChunkSize;
57 std::size_t fMaxChunks;
58 std::size_t fBatchSize;
59 std::size_t fBlockSize;
 // Seed forwarded to the tensor operators (see the constructor).
60 std::size_t fSetSeed;
61
63
 // Loader helpers: the dataset loader serves eager mode, the chunk loader
 // serves lazy (chunked) mode; the batch loaders split loaded data into batches.
64 std::unique_ptr<RDatasetLoader<Args...>> fDatasetLoader;
65 std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
66 std::unique_ptr<RBatchLoader> fTrainingBatchLoader;
67 std::unique_ptr<RBatchLoader> fValidationBatchLoader;
 // Samplers are only created when a non-empty sample type is requested (see ctor).
68 std::unique_ptr<RSampler> fTrainingSampler;
69 std::unique_ptr<RSampler> fValidationSampler;
70
71 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators;
72
 // Input dataframes the batches are generated from.
73 std::vector<ROOT::RDF::RNode> fRdfs;
74
 // Producer thread plus the mutex/condition variable guarding the shared state.
75 std::unique_ptr<std::thread> fLoadingThread;
76 std::condition_variable fLoadingCondition;
77 std::mutex fLoadingMutex;
78
 // Index of the next training/validation chunk to claim (guarded by fLoadingMutex).
79 std::size_t fTrainingChunkNum{0};
80 std::size_t fValidationChunkNum{0};
81
 // Empty string means "no sampling" (see the constructor).
85 std::string fSampleType;
88
89 bool fIsActive{false}; // Whether the loading thread is active
91
 // NOTE(review): methods below reference fTrainingEpochActive and
 // fValidationEpochActive, which are not visible in this extract -- presumably
 // declared in the lines missing around here (Doxygen 93-97). Confirm upstream.
92 bool fEpochActive{false};
95

98
99 std::size_t fNumTrainingChunks;
101
102 // flattened buffers for chunks and temporary tensors (rows * cols)
103 std::vector<RFlat2DMatrix> fTrainingDatasets;
104 std::vector<RFlat2DMatrix> fValidationDatasets;
105

108

111

113

115
116public:
 /// \brief Construct the generator, wiring up either eager (whole-dataset)
 /// loading or lazy (chunked) loading, plus the training/validation batch
 /// loaders that share the generator's mutex and condition variable.
 ///
 /// NOTE(review): several initializer-list and argument lines are missing from
 /// this extract (Doxygen 126-127, 130-139, 145, 147, 163-164, 166, 175, 189,
 /// 191); members such as fShuffle, fLoadEager, fValidationSplit, fSampleType
 /// and fSampleRatio are presumably initialized there -- confirm upstream.
117 RBatchGenerator(const std::vector<ROOT::RDF::RNode> &rdfs, const std::size_t chunkSize, const std::size_t blockSize,
118 const std::size_t batchSize, const std::vector<std::string> &cols,
119 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
120 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
121 bool dropRemainder = true, const std::size_t setSeed = 0, bool loadEager = false,
122 std::string sampleType = "", float sampleRatio = 1.0, bool replacement = false)
123
124 : fRdfs(rdfs),
125 fCols(cols),
128 fBlockSize(blockSize),
129 fBatchSize(batchSize),
140 {
 // Tensor operators handle shuffling with the configured seed.
141 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle, fSetSeed);
142
143 if (fLoadEager) {
 // Eager mode: materialize the dataset(s) up front via the dataset loader.
144 fDatasetLoader = std::make_unique<RDatasetLoader<Args...>>(fRdfs, fValidationSplit, fCols, fVecSizes,
146 // split the datasets and extract the training and validation datasets
148
149 if (fSampleType == "") {
 // No sampling requested: concatenate inputs into single train/validation sets.
150 fDatasetLoader->ConcatenateDatasets();
151
152 fTrainingDataset = fDatasetLoader->GetTrainingDataset();
153 fValidationDataset = fDatasetLoader->GetValidationDataset();
154
155 fNumTrainingEntries = fDatasetLoader->GetNumTrainingEntries();
156 fNumValidationEntries = fDatasetLoader->GetNumValidationEntries();
157 }
158
159 else {
 // Sampling requested: keep per-dataframe datasets and wrap them in samplers.
160 fTrainingDatasets = fDatasetLoader->GetTrainingDatasets();
161 fValidationDatasets = fDatasetLoader->GetValidationDatasets();
162
165 fValidationSampler = std::make_unique<RSampler>(fValidationDatasets, fSampleType, fSampleRatio,
167
168 fNumTrainingEntries = fTrainingSampler->GetNumEntries();
169 fNumValidationEntries = fValidationSampler->GetNumEntries();
170 }
171 }
172
173 else {
 // Lazy mode: the chunk loader reads data incrementally from the dataframe.
174 fChunkLoader = std::make_unique<RChunkLoader<Args...>>(fRdfs[0], fChunkSize, fBlockSize, fValidationSplit,
176
177 // split the dataset into training and validation sets
178 fChunkLoader->SplitDataset();
179
180 fNumTrainingEntries = fChunkLoader->GetNumTrainingEntries();
181 fNumValidationEntries = fChunkLoader->GetNumValidationEntries();
182
183 // number of training and validation chunks, calculated in RChunkConstructor
184 fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
185 fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
186 }
187
 // Both batch loaders share the generator's mutex/condition variable so the
 // producer loop and consumers coordinate on a single lock.
188 fTrainingBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fLoadingMutex, fLoadingCondition,
190 fValidationBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fLoadingMutex, fLoadingCondition,
192 }
193
195
 /// \brief Deactivate the generator: clear the active flag under the lock,
 /// wake the producer so it can exit its loop, and join/destroy the thread.
 /// NOTE(review): the signature line (Doxygen 196) is missing from this
 /// extract -- presumably `void DeActivate()`.
197 {
198 {
199 std::lock_guard<std::mutex> lock(fLoadingMutex);
 // Already inactive: nothing to tear down.
200 if (!fIsActive)
201 return;
202 fIsActive = false;
203 }
204
 // Wake the producer so it observes fIsActive == false.
205 fLoadingCondition.notify_all();
206
207 if (fLoadingThread) {
208 if (fLoadingThread->joinable()) {
209 fLoadingThread->join();
210 }
211 }
212
213 fLoadingThread.reset();
214 }
215
216 /// \brief Activate the loading process by spawning the loading thread.
217 void Activate()
218 {
219 {
220 std::lock_guard<std::mutex> lock(fLoadingMutex);
221 if (fIsActive)
222 return;
223
224 fIsActive = true;
225 }
226
227 if (fLoadEager) {
228 return;
229 }
230
231 fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
232 }
233
234 /// \brief Activate the training epoch by starting the batchloader.
 /// NOTE(review): the signature line (Doxygen 235, `void ActivateTrainingEpoch()`
 /// per the member index at the bottom of this page) and lines 239-240 are
 /// missing from this extract; the missing lines presumably set the
 /// training-epoch flag under the lock -- confirm upstream.
236 {
237 {
238 std::lock_guard<std::mutex> lock(fLoadingMutex);
241 }
242
243 fTrainingBatchLoader->Activate();
 // Wake the producer so it starts filling the training batch queue.
244 fLoadingCondition.notify_all();
245 }
246
 /// \brief End the training epoch: clear the epoch flag under the lock, reset
 /// and deactivate the training batch loader, then wake any waiters.
 /// NOTE(review): the signature line (Doxygen 247) is missing from this extract.
248 {
249 {
250 std::lock_guard<std::mutex> lock(fLoadingMutex);
251 fTrainingEpochActive = false;
252 }
253
254 fTrainingBatchLoader->Reset();
255 fTrainingBatchLoader->DeActivate();
256 fLoadingCondition.notify_all();
257 }
258
 /// \brief Activate the validation epoch by starting the batchloader.
 /// NOTE(review): the signature line (Doxygen 259) and lines 263-264 are
 /// missing from this extract; the missing lines presumably set the
 /// validation-epoch flag under the lock -- confirm upstream.
260 {
261 {
262 std::lock_guard<std::mutex> lock(fLoadingMutex);
265 }
266
267 fValidationBatchLoader->Activate();
 // Wake the producer so it starts filling the validation batch queue.
268 fLoadingCondition.notify_all();
269 }
270
 /// \brief End the validation epoch: reset and deactivate the validation batch
 /// loader, then wake any waiters.
 /// NOTE(review): the signature line (Doxygen 271) and line 275 (presumably
 /// `fValidationEpochActive = false;`, mirroring the training variant) are
 /// missing from this extract -- confirm upstream.
272 {
273 {
274 std::lock_guard<std::mutex> lock(fLoadingMutex);
276 }
277
278 fValidationBatchLoader->Reset();
279 fValidationBatchLoader->DeActivate();
280 fLoadingCondition.notify_all();
281 }
282
283 /// \brief Main loop for loading chunks and creating batches.
284 /// The producer (loading thread) will keep loading chunks and creating batches until the end of the epoch is
285 /// reached, or the generator is deactivated.
 /// NOTE(review): this extract is missing several lines (Doxygen 286 signature,
 /// 300-301 wait predicate, 309, 317, 320, 324, 355, 361, 364, 368, 384, 389),
 /// including the loop/epoch guard conditions and the CreateBatches calls after
 /// each LoadTrainingChunk/LoadValidationChunk -- confirm upstream before
 /// editing this function.
287 {
288 // Set minimum number of batches to keep in the queue before producer goes to work.
289 // This is to ensure that the producer will get a chance to work if the consumer is too fast and drains the queue
290 // quickly. With this, the maximum queue size will be approximately fChunkSize*1.5.
291 // TODO(staider): improve this heuristic by taking into consideration a "maximum number of batches in memory" set
292 // by the user.
293 const std::size_t kMinQueuedBatches = std::max<std::size_t>(1, (fChunkSize / fBatchSize) / 2);
294
 // unique_lock (not lock_guard) because it is released around the expensive
 // chunk-loading work and re-acquired for queue access, and because the
 // condition variable needs it.
295 std::unique_lock<std::mutex> lock(fLoadingMutex);
296
297 while (true) {
298 // Wait until we have work or shutdown
299 fLoadingCondition.wait(lock, [&] {
302 });
303
 // Shutdown requested via DeActivate(): leave the producer loop.
304 if (!fIsActive)
305 break;
306
307 // Helper: check if validation queue below watermark and needs the producer
308 auto validationEmpty = [&] {
310 return false;
311 if (fValidationBatchLoader->isProducerDone())
312 return false;
313 return fValidationBatchLoader->GetNumBatchQueue() < kMinQueuedBatches;
314 };
315
316 // -- TRAINING --
318 while (true) {
319 // Stop conditions (shutdown or epoch end)
321 break;
322
323 // No more chunks to load: signal consumers
325 fTrainingBatchLoader->MarkProducerDone();
326 break;
327 }
328
329 // In the case of training prefetching, we could start requesting data for the next training loop while
330 // validation is active and might need data. To avoid getting stuck in the training loop, we check if the
331 // validation queue is below watermark and if so, we break out of the training loop.
332 if (validationEmpty()) {
333 break;
334 }
335
336 // If queue is not empty, wait until it drains below watermark, or validation needs data, or we are
337 // deactivated.
338 if (fTrainingBatchLoader->GetNumBatchQueue() >= kMinQueuedBatches) {
339 fLoadingCondition.wait(lock, [&] {
340 return !fIsActive || !fTrainingEpochActive ||
341 fTrainingBatchLoader->GetNumBatchQueue() < kMinQueuedBatches || validationEmpty();
342 });
343 continue;
344 }
345
346 // Claim chunk under lock
347 const std::size_t chunkIdx = fTrainingChunkNum++;
348 const bool isLastTrainChunk = (chunkIdx == fNumTrainingChunks - 1);
349
350 // Release lock while reading and loading data to allow the consumer to access the queue freely in
351 // parallel. The loading thread re-acquires the lock in CreateBatches when it needs to push batches to
352 // the queue.
353 lock.unlock();
354 fChunkLoader->LoadTrainingChunk(fTrainChunkTensor, chunkIdx);
356 lock.lock();
357 }
358 }
359
360 // -- VALIDATION --
362 while (true) {
363 // Stop conditions (shutdown or epoch end)
365 break;
366
367 // No more chunks to load: signal consumers
369 fValidationBatchLoader->MarkProducerDone();
370 break;
371 }
372
373 // If queue is not hungry, wait until it drains below watermark, or we are deactivated
374 if (fValidationBatchLoader->GetNumBatchQueue() >= kMinQueuedBatches) {
375 fLoadingCondition.wait(lock, [&] {
376 return !fIsActive || !fValidationEpochActive ||
377 fValidationBatchLoader->GetNumBatchQueue() < kMinQueuedBatches;
378 });
379 continue;
380 }
381
382 // Claim chunk under lock
383 const std::size_t chunkIdx = fValidationChunkNum++;
385
386 // Release lock while working
387 lock.unlock();
388 fChunkLoader->LoadValidationChunk(fValidationChunkTensor, chunkIdx);
390 lock.lock();
391 }
392 }
393 }
394 }
395
396 /// \brief Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see
397 /// RBatchLoader)
 /// NOTE(review): the signature line (Doxygen 398, `void CreateTrainBatches()`
 /// per the member index at the bottom of this page) and lines 404/408 (the
 /// bodies of the two branches, presumably filling fSampledTrainingDataset
 /// with or without sampling) are missing from this extract -- confirm upstream.
399 {
400 fTrainingBatchLoader->Activate();
401
402 if (fLoadEager) {
403 if (fSampleType == "") {
405 }
406
407 else {
409 }
410
 // Eager path: batch the in-memory training dataset and mark the
 // producer as done since no further data will arrive.
411 fTrainingBatchLoader->CreateBatches(fSampledTrainingDataset, true);
412 fTrainingBatchLoader->MarkProducerDone();
413 } else {
 // Lazy path: only prepare the chunk intervals here; the producer
 // thread does the actual loading.
414 fChunkLoader->CreateTrainingChunksIntervals();
415 }
416 }
417
418 /// \brief Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batches
419 /// (see RBatchLoader)
 /// NOTE(review): the signature line (Doxygen 420), lines 426/430 (branch
 /// bodies) and line 433 (presumably the CreateBatches call mirroring the
 /// training variant) are missing from this extract -- confirm upstream.
421 {
422 fValidationBatchLoader->Activate();
423
424 if (fLoadEager) {
425 if (fSampleType == "") {
427 }
428
429 else {
431 }
432
 // Eager path: no further data will arrive after this point.
434 fValidationBatchLoader->MarkProducerDone();
435 }
436
437 else {
 // Lazy path: only prepare the chunk intervals; the producer thread
 // does the actual loading.
438 fChunkLoader->CreateValidationChunksIntervals();
439 }
440 }
441
442 /// \brief Loads a training batch from the queue
 /// NOTE(review): the signature line (Doxygen 443) is missing from this
 /// extract; per the member index below it is `RFlat2DMatrix GetTrainBatch()`.
444 {
445 // Get next batch if available
446 return fTrainingBatchLoader->GetBatch();
447 }
448
449 /// \brief Loads a validation batch from the queue
 /// NOTE(review): the signature line (Doxygen 450) is missing from this
 /// extract; per the member index below it is `RFlat2DMatrix GetValidationBatch()`.
451 {
452 // Get next batch if available
453 return fValidationBatchLoader->GetBatch();
454 }
455
 /// Number of training batches, as reported by the training batch loader.
456 std::size_t NumberOfTrainingBatches() { return fTrainingBatchLoader->GetNumBatches(); }
 /// Number of validation batches, as reported by the validation batch loader.
457 std::size_t NumberOfValidationBatches() { return fValidationBatchLoader->GetNumBatches(); }
458
 /// Rows left over after full training batches are formed.
459 std::size_t TrainRemainderRows() { return fTrainingBatchLoader->GetNumRemainderRows(); }
 /// Rows left over after full validation batches are formed.
460 std::size_t ValidationRemainderRows() { return fValidationBatchLoader->GetNumRemainderRows(); }
461
462 bool IsActive()
463 {
464 std::lock_guard<std::mutex> lock(fLoadingMutex);
465 return fIsActive;
466 }
467
 /// NOTE(review): two locked getters follow; their signature and return lines
 /// (Doxygen 468, 471-472's return, 474, 477) are missing from this extract --
 /// presumably GetNumTrainingEntries()/GetNumValidationEntries() returning the
 /// corresponding counters under fLoadingMutex. Confirm upstream.
469 {
470 std::lock_guard<std::mutex> lock(fLoadingMutex);
472 }
473
475 {
476 std::lock_guard<std::mutex> lock(fLoadingMutex);
478 }
479};
480
481} // namespace ROOT::Experimental::Internal::ML
482
483#endif // ROOT_INTERNAL_ML_RBATCHGENERATOR
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chu...
void ActivateTrainingEpoch()
Activate the training epoch by starting the batchloader.
void CreateValidationBatches()
Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batche...
std::unique_ptr< RBatchLoader > fTrainingBatchLoader
RFlat2DMatrix GetTrainBatch()
Loads a training batch from the queue.
void CreateTrainBatches()
Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see RB...
std::unique_ptr< RChunkLoader< Args... > > fChunkLoader
std::unique_ptr< RDatasetLoader< Args... > > fDatasetLoader
std::unique_ptr< RFlat2DMatrixOperators > fTensorOperators
RBatchGenerator(const std::vector< ROOT::RDF::RNode > &rdfs, const std::size_t chunkSize, const std::size_t blockSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, bool shuffle=true, bool dropRemainder=true, const std::size_t setSeed=0, bool loadEager=false, std::string sampleType="", float sampleRatio=1.0, bool replacement=false)
void Activate()
Activate the loading process by spawning the loading thread.
std::unique_ptr< RBatchLoader > fValidationBatchLoader
void LoadChunks()
Main loop for loading chunks and creating batches.
RFlat2DMatrix GetValidationBatch()
Loads a validation batch from the queue.
Building and loading the chunks from the blocks and chunks constructed in RChunkConstructor.
void SplitDatasets()
Split the dataframes in a training and validation dataset.
Wrapper around ROOT::RVec<float> representing a 2D matrix.