Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026
6// Author: Silia Taider, CERN 02/2026
7
8/*************************************************************************
9 * Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. *
10 * All rights reserved. *
11 * *
12 * For the licensing terms see $ROOTSYS/LICENSE. *
13 * For the list of contributors see $ROOTSYS/README/CREDITS. *
14 *************************************************************************/
15
16#ifndef ROOT_INTERNAL_ML_RBATCHGENERATOR
17#define ROOT_INTERNAL_ML_RBATCHGENERATOR
18
19#include <condition_variable>
20#include <memory>
21#include <mutex>
22#include <string>
23#include <thread>
24#include <vector>
25
31#include "ROOT/ML/RSampler.hxx"
33
34// Empty namespace to create a hook for the Pythonization
36}
37
39/**
40 \class ROOT::Experimental::Internal::ML::RBatchGenerator
41 \brief Generates training and validation batches from a dataset in an RDataFrame.
42
43In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chunks (see
44RBatchLoader) are combined, allowing batches from the training and validation sets to be loaded directly from a dataset
45in an RDataFrame.
46*/
47
48template <typename... Args>
50private:
51 std::vector<std::string> fCols;
52 std::vector<std::size_t> fVecSizes;
53 std::size_t fChunkSize;
54 std::size_t fMaxChunks;
55 std::size_t fBatchSize;
56 std::size_t fBlockSize;
57 std::size_t fSetSeed;
58
60
61 std::unique_ptr<RDatasetLoader<Args...>> fDatasetLoader;
62 std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
63 std::unique_ptr<RBatchLoader> fTrainingBatchLoader;
64 std::unique_ptr<RBatchLoader> fValidationBatchLoader;
65 std::unique_ptr<RSampler> fTrainingSampler;
66 std::unique_ptr<RSampler> fValidationSampler;
67
68 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators;
69
70 std::vector<ROOT::RDF::RNode> fRdfs;
71
72 std::unique_ptr<std::thread> fLoadingThread;
73 std::condition_variable fLoadingCondition;
74 std::mutex fLoadingMutex;
75
76 std::size_t fTrainingChunkNum{0};
77 std::size_t fValidationChunkNum{0};
78
82 std::string fSampleType;
85
86 bool fIsActive{false}; // Whether the loading thread is active
88
89 bool fEpochActive{false};
92
95
96 std::size_t fNumTrainingChunks;
98
99 // flattened buffers for chunks and temporary tensors (rows * cols)
100 std::vector<RFlat2DMatrix> fTrainingDatasets;
101 std::vector<RFlat2DMatrix> fValidationDatasets;
102
105
108
110
112
113public:
114 RBatchGenerator(const std::vector<ROOT::RDF::RNode> &rdfs, const std::size_t chunkSize, const std::size_t blockSize,
115 const std::size_t batchSize, const std::vector<std::string> &cols,
116 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
117 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
118 bool dropRemainder = true, const std::size_t setSeed = 0, bool loadEager = false,
119 std::string sampleType = "", float sampleRatio = 1.0, bool replacement = false)
120
121 : fRdfs(rdfs),
122 fCols(cols),
125 fBlockSize(blockSize),
126 fBatchSize(batchSize),
137 {
138 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle, fSetSeed);
139
140 if (fLoadEager) {
141 fDatasetLoader = std::make_unique<RDatasetLoader<Args...>>(fRdfs, fValidationSplit, fCols, fVecSizes,
143 // split the datasets and extract the training and validation datasets
145
146 if (fSampleType == "") {
147 fDatasetLoader->ConcatenateDatasets();
148
149 fTrainingDataset = fDatasetLoader->GetTrainingDataset();
150 fValidationDataset = fDatasetLoader->GetValidationDataset();
151
152 fNumTrainingEntries = fDatasetLoader->GetNumTrainingEntries();
153 fNumValidationEntries = fDatasetLoader->GetNumValidationEntries();
154 }
155
156 else {
157 fTrainingDatasets = fDatasetLoader->GetTrainingDatasets();
158 fValidationDatasets = fDatasetLoader->GetValidationDatasets();
159
162 fValidationSampler = std::make_unique<RSampler>(fValidationDatasets, fSampleType, fSampleRatio,
164
165 fNumTrainingEntries = fTrainingSampler->GetNumEntries();
166 fNumValidationEntries = fValidationSampler->GetNumEntries();
167 }
168 }
169
170 else {
171 fChunkLoader = std::make_unique<RChunkLoader<Args...>>(fRdfs[0], fChunkSize, fBlockSize, fValidationSplit,
173
174 // split the dataset into training and validation sets
175 fChunkLoader->SplitDataset();
176
177 fNumTrainingEntries = fChunkLoader->GetNumTrainingEntries();
178 fNumValidationEntries = fChunkLoader->GetNumValidationEntries();
179
180 // number of training and validation chunks, calculated in RChunkConstructor
181 fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
182 fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
183 }
184
185 fTrainingBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fLoadingMutex, fLoadingCondition,
187 fValidationBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fCols, fLoadingMutex, fLoadingCondition,
189 }
190
192
194 {
195 {
196 std::lock_guard<std::mutex> lock(fLoadingMutex);
197 if (!fIsActive)
198 return;
199 fIsActive = false;
200 }
201
202 fLoadingCondition.notify_all();
203
204 if (fLoadingThread) {
205 if (fLoadingThread->joinable()) {
206 fLoadingThread->join();
207 }
208 }
209
210 fLoadingThread.reset();
211 }
212
213 /// \brief Activate the loading process by spawning the loading thread.
214 void Activate()
215 {
216 {
217 std::lock_guard<std::mutex> lock(fLoadingMutex);
218 if (fIsActive)
219 return;
220
221 fIsActive = true;
222 }
223
224 if (fLoadEager) {
225 return;
226 }
227
228 fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
229 }
230
231 /// \brief Activate the training epoch by starting the batchloader.
233 {
234 {
235 std::lock_guard<std::mutex> lock(fLoadingMutex);
238 }
239
240 fTrainingBatchLoader->Activate();
241 fLoadingCondition.notify_all();
242 }
243
245 {
246 {
247 std::lock_guard<std::mutex> lock(fLoadingMutex);
248 fTrainingEpochActive = false;
249 }
250
251 fTrainingBatchLoader->Reset();
252 fTrainingBatchLoader->DeActivate();
253 fLoadingCondition.notify_all();
254 }
255
257 {
258 {
259 std::lock_guard<std::mutex> lock(fLoadingMutex);
262 }
263
264 fValidationBatchLoader->Activate();
265 fLoadingCondition.notify_all();
266 }
267
269 {
270 {
271 std::lock_guard<std::mutex> lock(fLoadingMutex);
273 }
274
275 fValidationBatchLoader->Reset();
276 fValidationBatchLoader->DeActivate();
277 fLoadingCondition.notify_all();
278 }
279
280 /// \brief Main loop for loading chunks and creating batches.
281 /// The producer (loading thread) will keep loading chunks and creating batches until the end of the epoch is
282 /// reached, or the generator is deactivated.
284 {
285 // Set minimum number of batches to keep in the queue before producer goes to work.
286 // This is to ensure that the producer will get a chance to work if the consumer is too fast and drains the queue
287 // quickly. With this, the maximum queue size will be approximately fChunkSize*1.5.
288 // TODO(staider): improve this heuristic by taking into consideration a "maximum number of batches in memory" set
289 // by the user.
290 const std::size_t kMinQueuedBatches = std::max<std::size_t>(1, (fChunkSize / fBatchSize) / 2);
291
292 std::unique_lock<std::mutex> lock(fLoadingMutex);
293
294 while (true) {
295 // Wait until we have work or shutdown
296 fLoadingCondition.wait(lock, [&] {
299 });
300
301 if (!fIsActive)
302 break;
303
304 // Helper: check if validation queue below watermark and needs the producer
305 auto validationEmpty = [&] {
307 return false;
308 if (fValidationBatchLoader->isProducerDone())
309 return false;
310 return fValidationBatchLoader->GetNumBatchQueue() < kMinQueuedBatches;
311 };
312
313 // -- TRAINING --
315 while (true) {
316 // Stop conditions (shutdown or epoch end)
318 break;
319
320 // No more chunks to load: signal consumers
322 fTrainingBatchLoader->MarkProducerDone();
323 break;
324 }
325
326 // In the case of training prefetching, we could start requesting data for the next training loop while
327 // validation is active and might need data. To avoid getting stuck in the training loop, we check if the
328 // validation queue is below watermark and if so, we break out of the training loop.
329 if (validationEmpty()) {
330 break;
331 }
332
333 // If queue is not empty, wait until it drains below watermark, or validation needs data, or we are
334 // deactivated.
335 if (fTrainingBatchLoader->GetNumBatchQueue() >= kMinQueuedBatches) {
336 fLoadingCondition.wait(lock, [&] {
337 return !fIsActive || !fTrainingEpochActive ||
338 fTrainingBatchLoader->GetNumBatchQueue() < kMinQueuedBatches || validationEmpty();
339 });
340 continue;
341 }
342
343 // Claim chunk under lock
344 const std::size_t chunkIdx = fTrainingChunkNum++;
345 const bool isLastTrainChunk = (chunkIdx == fNumTrainingChunks - 1);
346
347 // Release lock while reading and loading data to allow the consumer to access the queue freely in
348 // parallel. The loading thread re-acquires the lock in CreateBatches when it needs to push batches to
349 // the queue.
350 lock.unlock();
351 fChunkLoader->LoadTrainingChunk(fTrainChunkTensor, chunkIdx);
353 lock.lock();
354 }
355 }
356
357 // -- VALIDATION --
359 while (true) {
360 // Stop conditions (shutdown or epoch end)
362 break;
363
364 // No more chunks to load: signal consumers
366 fValidationBatchLoader->MarkProducerDone();
367 break;
368 }
369
370 // If queue is not hungry, wait until it drains below watermark, or we are deactivated
371 if (fValidationBatchLoader->GetNumBatchQueue() >= kMinQueuedBatches) {
372 fLoadingCondition.wait(lock, [&] {
373 return !fIsActive || !fValidationEpochActive ||
374 fValidationBatchLoader->GetNumBatchQueue() < kMinQueuedBatches;
375 });
376 continue;
377 }
378
379 // Claim chunk under lock
380 const std::size_t chunkIdx = fValidationChunkNum++;
382
383 // Release lock while working
384 lock.unlock();
385 fChunkLoader->LoadValidationChunk(fValidationChunkTensor, chunkIdx);
387 lock.lock();
388 }
389 }
390 }
391 }
392
393 /// \brief Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see
394 /// RBatchLoader)
396 {
397 fTrainingBatchLoader->Activate();
398
399 if (fLoadEager) {
400 if (fSampleType == "") {
402 }
403
404 else {
406 }
407
408 fTrainingBatchLoader->CreateBatches(fSampledTrainingDataset, true);
409 fTrainingBatchLoader->MarkProducerDone();
410 } else {
411 fChunkLoader->CreateTrainingChunksIntervals();
412 }
413 }
414
415 /// \brief Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batches
416 /// (see RBatchLoader)
418 {
419 fValidationBatchLoader->Activate();
420
421 if (fLoadEager) {
422 if (fSampleType == "") {
424 }
425
426 else {
428 }
429
431 fValidationBatchLoader->MarkProducerDone();
432 }
433
434 else {
435 fChunkLoader->CreateValidationChunksIntervals();
436 }
437 }
438
439 /// \brief Loads a training batch from the queue
441 {
442 // Get next batch if available
443 return fTrainingBatchLoader->GetBatch();
444 }
445
446 /// \brief Loads a validation batch from the queue
448 {
449 // Get next batch if available
450 return fValidationBatchLoader->GetBatch();
451 }
452
453 std::size_t NumberOfTrainingBatches() { return fTrainingBatchLoader->GetNumBatches(); }
454 std::size_t NumberOfValidationBatches() { return fValidationBatchLoader->GetNumBatches(); }
455
456 std::size_t TrainRemainderRows() { return fTrainingBatchLoader->GetNumRemainderRows(); }
457 std::size_t ValidationRemainderRows() { return fValidationBatchLoader->GetNumRemainderRows(); }
458
459 bool IsActive()
460 {
461 std::lock_guard<std::mutex> lock(fLoadingMutex);
462 return fIsActive;
463 }
464
466 {
467 std::lock_guard<std::mutex> lock(fLoadingMutex);
469 }
470
472 {
473 std::lock_guard<std::mutex> lock(fLoadingMutex);
475 }
476};
477
478} // namespace ROOT::Experimental::Internal::ML
479
480#endif // ROOT_INTERNAL_ML_RBATCHGENERATOR
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chu...
void ActivateTrainingEpoch()
Activate the training epoch by starting the batchloader.
void CreateValidationBatches()
Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batche...
std::unique_ptr< RBatchLoader > fTrainingBatchLoader
RFlat2DMatrix GetTrainBatch()
Loads a training batch from the queue.
void CreateTrainBatches()
Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see RB...
std::unique_ptr< RChunkLoader< Args... > > fChunkLoader
std::unique_ptr< RDatasetLoader< Args... > > fDatasetLoader
std::unique_ptr< RFlat2DMatrixOperators > fTensorOperators
RBatchGenerator(const std::vector< ROOT::RDF::RNode > &rdfs, const std::size_t chunkSize, const std::size_t blockSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, bool shuffle=true, bool dropRemainder=true, const std::size_t setSeed=0, bool loadEager=false, std::string sampleType="", float sampleRatio=1.0, bool replacement=false)
void Activate()
Activate the loading process by spawning the loading thread.
std::unique_ptr< RBatchLoader > fValidationBatchLoader
void LoadChunks()
Main loop for loading chunks and creating batches.
RFlat2DMatrix GetValidationBatch()
Loads a validation batch from the queue.
Building and loading the chunks from the blocks and chunks constructed in RChunkConstructor.
void SplitDatasets()
Split the dataframes in a training and validation dataset.
Wrapper around ROOT::RVec<float> representing a 2D matrix.