Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator.hxx
Go to the documentation of this file.
1// Author: Dante Niewenhuis, VU Amsterdam 07/2023
2// Author: Kristupas Pranckietis, Vilnius University 05/2024
3// Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
4// Author: Vincenzo Eduardo Padulano, CERN 10/2024
5
6/*************************************************************************
7 * Copyright (C) 1995-2024, Rene Brun and Fons Rademakers. *
8 * All rights reserved. *
9 * *
10 * For the licensing terms see $ROOTSYS/LICENSE. *
11 * For the list of contributors see $ROOTSYS/README/CREDITS. *
12 *************************************************************************/
13
14#ifndef TMVA_RBATCHGENERATOR
15#define TMVA_RBATCHGENERATOR
16
#include "TMVA/RTensor.hxx"

#include "TROOT.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <mutex>
#include <numeric>
#include <random>
#include <thread>
#include <variant>
#include <vector>
30
31namespace TMVA {
32namespace Experimental {
33namespace Internal {
34
35template <typename... Args>
37private:
38 std::mt19937 fRng;
39 std::mt19937 fFixedRng;
40 std::random_device::result_type fFixedSeed;
41
42 std::size_t fChunkSize;
43 std::size_t fMaxChunks;
44 std::size_t fBatchSize;
45 std::size_t fNumEntries;
46
48
49 std::variant<std::shared_ptr<RChunkLoader<Args...>>, std::shared_ptr<RChunkLoaderFilters<Args...>>> fChunkLoader;
50
51 std::unique_ptr<RBatchLoader> fBatchLoader;
52
53 std::unique_ptr<std::thread> fLoadingThread;
54
55 std::unique_ptr<TMVA::Experimental::RTensor<float>> fChunkTensor;
56
58
59 std::mutex fIsActiveMutex;
60
63 bool fIsActive{false}; // Whether the loading thread is active
66
67public:
68 RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t batchSize,
69 const std::vector<std::string> &cols, const std::size_t numColumns,
70 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
71 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
72 bool dropRemainder = true)
73 : fRng(std::random_device{}()),
74 fFixedSeed(std::uniform_int_distribution<std::random_device::result_type>{}(fRng)),
75 f_rdf(rdf),
76 fChunkSize(chunkSize),
77 fBatchSize(batchSize),
78 fValidationSplit(validationSplit),
79 fMaxChunks(maxChunks),
80 fDropRemainder(dropRemainder),
81 fShuffle(shuffle),
83 fUseWholeFile(maxChunks == 0)
84 {
85
86 // Create tensor to load the chunk into
88 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fChunkSize, numColumns});
89
90 if (fNotFiltered) {
91 fNumEntries = f_rdf.Count().GetValue();
92
94 f_rdf, *fChunkTensor, fChunkSize, cols, vecSizes, vecPadding);
95 } else {
96 auto report = f_rdf.Report();
97 fNumEntries = f_rdf.Count().GetValue();
98 std::size_t numAllEntries = report.begin()->GetAll();
99
101 f_rdf, *fChunkTensor, fChunkSize, cols, fNumEntries, numAllEntries, vecSizes, vecPadding);
102 }
103
104 std::size_t maxBatches = ceil((fChunkSize / fBatchSize) * (1 - fValidationSplit));
105
106 // limits the number of batches that can be contained in the batchqueue based on the chunksize
107 fBatchLoader = std::make_unique<TMVA::Experimental::Internal::RBatchLoader>(*fChunkTensor, fBatchSize, numColumns,
108 maxBatches);
109 }
110
112
113 /// \brief De-activate the loading process by deactivating the batchgenerator
114 /// and joining the loading thread
116 {
117 {
// Clear the active flag under the mutex; the LoadChunks* loops check it
// at the top of every iteration and return once it is false
118 std::lock_guard<std::mutex> lock(fIsActiveMutex);
119 fIsActive = false;
120 }
121
// Deactivate the batch loader as well so batch production stops
122 fBatchLoader->DeActivate();
123
// Join the loading thread (spawned by Activate) before returning, so no
// background work outlives this call
124 if (fLoadingThread) {
125 if (fLoadingThread->joinable()) {
126 fLoadingThread->join();
127 }
128 }
129 }
130
131 /// \brief Activate the loading process by starting the batchloader, and
132 /// spawning the loading thread.
133 void Activate()
134 {
135 if (fIsActive)
136 return;
137
138 {
139 std::lock_guard<std::mutex> lock(fIsActiveMutex);
140 fIsActive = true;
141 }
142
143 fFixedRng.seed(fFixedSeed);
144 fBatchLoader->Activate();
145 // fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
146 if (fNotFiltered) {
147 fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunksNoFilters, this);
148 } else {
149 fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunksFilters, this);
150 }
151 }
152
153 /// \brief Returns the next batch of training data if available.
154 /// Returns empty RTensor otherwise.
155 /// \return
157 {
158 // Get next batch if available
159 return fBatchLoader->GetTrainBatch();
160 }
161
162 /// \brief Returns the next batch of validation data if available.
163 /// Returns empty RTensor otherwise.
164 /// \return
166 {
167 // Get next batch if available
168 return fBatchLoader->GetValidationBatch();
169 }
170
172 {
173 std::size_t entriesForTraining =
176
177 if (fDropRemainder || !(entriesForTraining % fBatchSize)) {
178 return entriesForTraining / fBatchSize;
179 }
180
181 return entriesForTraining / fBatchSize + 1;
182 }
183
184 /// @brief Return number of training remainder rows
185 /// @return
186 std::size_t TrainRemainderRows()
187 {
188 std::size_t entriesForTraining =
191
192 if (fDropRemainder || !(entriesForTraining % fBatchSize)) {
193 return 0;
194 }
195
196 return entriesForTraining % fBatchSize;
197 }
198
199 /// @brief Calculate number of validation batches and return it
200 /// @return
202 {
203 std::size_t entriesForValidation = (fNumEntries / fChunkSize) * floor(fChunkSize * fValidationSplit) +
205
206 if (fDropRemainder || !(entriesForValidation % fBatchSize)) {
207
208 return entriesForValidation / fBatchSize;
209 }
210
211 return entriesForValidation / fBatchSize + 1;
212 }
213
214 /// @brief Return number of validation remainder rows
215 /// @return
217 {
218 std::size_t entriesForValidation = (fNumEntries / fChunkSize) * floor(fChunkSize * fValidationSplit) +
220
221 if (fDropRemainder || !(entriesForValidation % fBatchSize)) {
222
223 return 0;
224 }
225
226 return entriesForValidation % fBatchSize;
227 }
228
229 /// @brief Load chunks when no filters are applied on rdataframe
231 {
// Keep loading chunks until either the chunk budget (fMaxChunks) is spent —
// unless fUseWholeFile — or all entries have been consumed
232 for (std::size_t currentChunk = 0, currentEntry = 0;
233 ((currentChunk < fMaxChunks) || fUseWholeFile) && currentEntry < fNumEntries; currentChunk++) {
234
235 // stop the loop when the loading is not active anymore
236 {
237 std::lock_guard<std::mutex> lock(fIsActiveMutex);
238 if (!fIsActive)
239 return;
240 }
241
242 // Number of entries processed while loading this chunk
243 std::size_t report = std::get<std::shared_ptr<RChunkLoader<Args...>>>(fChunkLoader)->LoadChunk(currentEntry);
244 currentEntry += report;
245
246 CreateBatches(report);
247 }
248
// Unless remainders are dropped, emit the final partial batches as well
249 if (!fDropRemainder) {
250 fBatchLoader->LastBatches();
251 }
252
253 fBatchLoader->DeActivate();
254 }
255
257 {
258 std::size_t currentChunk = 0;
// Loop until the chunk budget is spent (unless fUseWholeFile) or every
// entry that passes the filters has been processed
259 for (std::size_t processedEvents = 0, currentRow = 0;
260 ((currentChunk < fMaxChunks) || fUseWholeFile) && processedEvents < fNumEntries; currentChunk++) {
261
262 // stop the loop when the loading is not active anymore
263 {
264 std::lock_guard<std::mutex> lock(fIsActiveMutex);
265 if (!fIsActive)
266 return;
267 }
268
269 // A pair of the processed and passed events while loading the chunk
270 std::pair<std::size_t, std::size_t> report =
271 std::get<std::shared_ptr<RChunkLoaderFilters<Args...>>>(fChunkLoader)->LoadChunk(currentRow);
272
273 currentRow += report.first;
274 processedEvents += report.second;
275
276 CreateBatches(report.second);
277 }
278
// NOTE(review): presumably flushes rows the filtered loader still holds
// after the loop (its last, partial chunk) — confirm against RChunkLoaderFilters
279 if (currentChunk < fMaxChunks || fUseWholeFile) {
280 CreateBatches(std::get<std::shared_ptr<RChunkLoaderFilters<Args...>>>(fChunkLoader)->LastChunk());
281 }
282
// Unless remainders are dropped, emit the final partial batches as well
283 if (!fDropRemainder) {
284 fBatchLoader->LastBatches();
285 }
286
287 fBatchLoader->DeActivate();
288 }
289
290 /// \brief Create batches
291 /// \param processedEvents
292 void CreateBatches(std::size_t processedEvents)
293 {
294 auto &&[trainingIndices, validationIndices] = createIndices(processedEvents);
295
296 fBatchLoader->CreateTrainingBatches(trainingIndices);
297 fBatchLoader->CreateValidationBatches(validationIndices);
298 }
299
300 /// \brief split the events of the current chunk into training and validation events, shuffle if needed
301 /// \param events
302 std::pair<std::vector<std::size_t>, std::vector<std::size_t>> createIndices(std::size_t events)
303 {
304 // Create a vector of number 1..events
305 std::vector<std::size_t> row_order = std::vector<std::size_t>(events);
306 std::iota(row_order.begin(), row_order.end(), 0);
307
308 if (fShuffle) {
309 // Shuffle the entry indices at every new epoch
310 std::shuffle(row_order.begin(), row_order.end(), fFixedRng);
311 }
312
313 // calculate the number of events used for validation
314 std::size_t num_validation = floor(events * fValidationSplit);
315
316 // Devide the vector into training and validation and return
317 std::vector<std::size_t> trainingIndices =
318 std::vector<std::size_t>({row_order.begin(), row_order.end() - num_validation});
319 std::vector<std::size_t> validationIndices =
320 std::vector<std::size_t>({row_order.end() - num_validation, row_order.end()});
321
322 if (fShuffle) {
323 std::shuffle(trainingIndices.begin(), trainingIndices.end(), fRng);
324 }
325
326 return std::make_pair(trainingIndices, validationIndices);
327 }
328
329 bool IsActive() { return fIsActive; }
330};
331
332} // namespace Internal
333} // namespace Experimental
334} // namespace TMVA
335
336#endif // TMVA_RBATCHGENERATOR
The public interface to the RDataFrame federation of classes.
RResultPtr< ULong64_t > Count()
Return the number of entries processed (lazy action).
RResultPtr< RCutFlowReport > Report()
Gather filtering statistics.
std::vector< std::string > GetFilterNames()
Returns the names of the filters created.
std::size_t ValidationRemainderRows()
Return number of validation remainder rows.
std::size_t TrainRemainderRows()
Return number of training remainder rows.
void LoadChunksNoFilters()
Load chunks when no filters are applied on rdataframe.
std::unique_ptr< std::thread > fLoadingThread
const TMVA::Experimental::RTensor< float > & GetTrainBatch()
Returns the next batch of training data if available.
void Activate()
Activate the loading process by starting the batchloader, and spawning the loading thread.
std::size_t NumberOfValidationBatches()
Calculate number of validation batches and return it.
std::pair< std::vector< std::size_t >, std::vector< std::size_t > > createIndices(std::size_t events)
split the events of the current chunk into training and validation events, shuffle if needed
const TMVA::Experimental::RTensor< float > & GetValidationBatch()
Returns the next batch of validation data if available.
void DeActivate()
De-activate the loading process by deactivating the batchgenerator and joining the loading thread.
std::variant< std::shared_ptr< RChunkLoader< Args... > >, std::shared_ptr< RChunkLoaderFilters< Args... > > > fChunkLoader
std::unique_ptr< RBatchLoader > fBatchLoader
std::random_device::result_type fFixedSeed
void CreateBatches(std::size_t processedEvents)
Create batches.
std::unique_ptr< TMVA::Experimental::RTensor< float > > fChunkTensor
RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t batchSize, const std::vector< std::string > &cols, const std::size_t numColumns, const std::vector< std::size_t > &vecSizes={}, const float vecPadding=0.0, const float validationSplit=0.0, const std::size_t maxChunks=0, bool shuffle=true, bool dropRemainder=true)
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
RVec< PromoteType< T > > ceil(const RVec< T > &v)
Definition RVec.hxx:1867
create variable transformations