Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5// Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026
6
7/*************************************************************************
8 * Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. *
9 * All rights reserved. *
10 * *
11 * For the licensing terms see $ROOTSYS/LICENSE. *
12 * For the list of contributors see $ROOTSYS/README/CREDITS. *
13 *************************************************************************/
14
15#ifndef TMVA_RBATCHGENERATOR
16#define TMVA_RBATCHGENERATOR
17
22
26#include "TROOT.h"
27
28#include <cmath>
29#include <memory>
30#include <mutex>
31#include <random>
32#include <thread>
33#include <variant>
34#include <vector>
35
36namespace TMVA {
37namespace Experimental {
38namespace Internal {
39
40// clang-format off
41/**
42\class TMVA::Experimental::Internal::RBatchGenerator
43\ingroup tmva
44\brief
45
46In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chunks (see RBatchLoader) are combined, allowing batches from the training and validation sets to be loaded directly from a dataset in an RDataFrame.
47*/
48
49template <typename... Args>
51private:
52 std::vector<std::string> fCols;
53 std::vector<std::size_t> fVecSizes;
54 // clang-format on
55 std::size_t fChunkSize;
56 std::size_t fMaxChunks;
57 std::size_t fBatchSize;
58 std::size_t fBlockSize;
59 std::size_t fSetSeed;
60
62
63 std::unique_ptr<RDatasetLoader<Args...>> fDatasetLoader;
64 std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
65 std::unique_ptr<RBatchLoader> fTrainingBatchLoader;
66 std::unique_ptr<RBatchLoader> fValidationBatchLoader;
67 std::unique_ptr<RSampler> fTrainingSampler;
68 std::unique_ptr<RSampler> fValidationSampler;
69
70 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators;
71
72 std::vector<ROOT::RDF::RNode> f_rdfs;
73
74 std::unique_ptr<std::thread> fLoadingThread;
75
76 std::size_t fTrainingChunkNum;
78
79 std::mutex fIsActiveMutex;
80
84 std::string fSampleType;
87
88 bool fIsActive{false}; // Whether the loading thread is active
90
91 bool fEpochActive{false};
94
97
98 std::size_t fNumTrainingChunks;
100
101 // flattened buffers for chunks and temporary tensors (rows * cols)
102 std::vector<RFlat2DMatrix> fTrainingDatasets;
103 std::vector<RFlat2DMatrix> fValidationDatasets;
104
107
110
113
116
117public:
118 RBatchGenerator(const std::vector<ROOT::RDF::RNode> &rdfs, const std::size_t chunkSize, const std::size_t blockSize,
119 const std::size_t batchSize, const std::vector<std::string> &cols,
120 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
121 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
122 bool dropRemainder = true, const std::size_t setSeed = 0, bool loadEager = false,
123 std::string sampleType = "", float sampleRatio = 1.0, bool replacement = false)
124
125 : f_rdfs(rdfs),
126 fCols(cols),
129 fBlockSize(blockSize),
130 fBatchSize(batchSize),
141 {
142 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle, fSetSeed);
143
144 if (fLoadEager) {
145 fDatasetLoader = std::make_unique<RDatasetLoader<Args...>>(f_rdfs, fValidationSplit, fCols, fVecSizes,
147 // split the datasets and extract the training and validation datasets
149
150 if (fSampleType == "") {
151 fDatasetLoader->ConcatenateDatasets();
152
153 fTrainingDataset = fDatasetLoader->GetTrainingDataset();
154 fValidationDataset = fDatasetLoader->GetValidationDataset();
155
156 fNumTrainingEntries = fDatasetLoader->GetNumTrainingEntries();
157 fNumValidationEntries = fDatasetLoader->GetNumValidationEntries();
158 }
159
160 else {
161 fTrainingDatasets = fDatasetLoader->GetTrainingDatasets();
162 fValidationDatasets = fDatasetLoader->GetValidationDatasets();
163
166 fValidationSampler = std::make_unique<RSampler>(fValidationDatasets, fSampleType, fSampleRatio,
168
169 fNumTrainingEntries = fTrainingSampler->GetNumEntries();
170 fNumValidationEntries = fValidationSampler->GetNumEntries();
171 }
172 }
173
174 else {
175 fChunkLoader = std::make_unique<RChunkLoader<Args...>>(f_rdfs[0], fChunkSize, fBlockSize, fValidationSplit,
177
178 // split the dataset into training and validation sets
179 fChunkLoader->SplitDataset();
180
181 fNumTrainingEntries = fChunkLoader->GetNumTrainingEntries();
182 fNumValidationEntries = fChunkLoader->GetNumValidationEntries();
183
184 // number of training and validation chunks, calculated in RChunkConstructor
185 fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
186 fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
187 }
188
190 std::make_unique<RBatchLoader>(fBatchSize, fCols, fVecSizes, fNumTrainingEntries, fDropRemainder);
192 std::make_unique<RBatchLoader>(fBatchSize, fCols, fVecSizes, fNumValidationEntries, fDropRemainder);
193 }
194
196
198 {
199 {
200 std::lock_guard<std::mutex> lock(fIsActiveMutex);
201 fIsActive = false;
202 }
203
204 fTrainingBatchLoader->DeActivate();
205 fValidationBatchLoader->DeActivate();
206
207 if (fLoadingThread) {
208 if (fLoadingThread->joinable()) {
209 fLoadingThread->join();
210 }
211 }
212 }
213
214 /// \brief Activate the loading process by starting the batchloader, and
215 /// spawning the loading thread.
216 void Activate()
217 {
218 if (fIsActive)
219 return;
220
221 {
222 std::lock_guard<std::mutex> lock(fIsActiveMutex);
223 fIsActive = true;
224 }
225
226 fTrainingBatchLoader->Activate();
227 fValidationBatchLoader->Activate();
228 // fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
229 }
230
231 void ActivateEpoch() { fEpochActive = true; }
232
233 void DeActivateEpoch() { fEpochActive = false; }
234
236
238
240
242
243 /// \brief Create training batches by first loading a chunk (see RChunkLoader) and splitting it into batches (see
244 /// RBatchLoader)
246 {
248 if (fLoadEager) {
249 if (fSampleType == "") {
251 }
252
253 else {
255 }
256
258 }
259
260 else {
261 fChunkLoader->CreateTrainingChunksIntervals();
266 }
267 }
268
269 /// \brief Creates validation batches by first loading a chunk (see RChunkLoader), and then splitting it into batches
270 /// (see RBatchLoader)
272 {
274 if (fLoadEager) {
275 if (fSampleType == "") {
277 }
278
279 else {
281 }
282
284 }
285
286 else {
287 fChunkLoader->CreateValidationChunksIntervals();
292 }
293 }
294
295 /// \brief Loads a training batch from the queue
297 {
298 if (!fLoadEager) {
299 auto batchQueue = fTrainingBatchLoader->GetNumBatchQueue();
300
301 // load the next chunk if the queue is empty
307 }
308 }
309 // Get next batch if available
310 return fTrainingBatchLoader->GetBatch();
311 }
312
313 /// \brief Loads a validation batch from the queue
315 {
316 if (!fLoadEager) {
317 auto batchQueue = fValidationBatchLoader->GetNumBatchQueue();
318
319 // load the next chunk if the queue is empty
325 }
326 }
327 // Get next batch if available
328 return fValidationBatchLoader->GetBatch();
329 }
330
331 std::size_t NumberOfTrainingBatches() { return fTrainingBatchLoader->GetNumBatches(); }
332 std::size_t NumberOfValidationBatches() { return fValidationBatchLoader->GetNumBatches(); }
333
334 std::size_t TrainRemainderRows() { return fTrainingBatchLoader->GetNumRemainderRows(); }
335 std::size_t ValidationRemainderRows() { return fValidationBatchLoader->GetNumRemainderRows(); }
336
337 bool IsActive() { return fIsActive; }
339 /// \brief Returns the next batch of validation data if available.
340 /// Returns empty RTensor otherwise.
341};
342
343} // namespace Internal
344} // namespace Experimental
345} // namespace TMVA
346
347#endif // TMVA_RBATCHGENERATOR
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Building and loading the chunks from the blocks and chunks constructed in RChunkConstructor.
std::unique_ptr< RFlat2DMatrixOperators > fTensorOperators
std::unique_ptr< std::thread > fLoadingThread
std::unique_ptr< RBatchLoader > fTrainingBatchLoader
void Activate()
Activate the loading process by starting the batchloader, and spawning the loading thread.
RBatchGenerator(const std::vector< ROOT::RDF::RNode > &rdfs, const std::size_t chunkSize, const std::size_t blockSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, bool shuffle=true, bool dropRemainder=true, const std::size_t setSeed=0, bool loadEager=false, std::string sampleType="", float sampleRatio=1.0, bool replacement=false)
void CreateValidationBatches()
Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batche...
void CreateTrainBatches()
Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see RB...
std::unique_ptr< RDatasetLoader< Args... > > fDatasetLoader
RFlat2DMatrix GetValidationBatch()
Loads a validation batch from the queue.
std::unique_ptr< RChunkLoader< Args... > > fChunkLoader
std::unique_ptr< RBatchLoader > fValidationBatchLoader
RFlat2DMatrix GetTrainBatch()
Loads a training batch from the queue.
void SplitDatasets()
Split the dataframes in a training and validation dataset.
create variable transformations
Wrapper around ROOT::RVec<float> representing a 2D matrix.