ROOT Reference Guide
RBatchGenerator.hxx
#ifndef TMVA_BATCHGENERATOR
#define TMVA_BATCHGENERATOR

#include <iostream>
#include <vector>
#include <thread>
#include <memory>
#include <cmath>
#include <mutex>
#include <string>
#include <algorithm> // std::shuffle
#include <numeric>   // std::iota

#include "TMVA/RTensor.hxx"
#include "TMVA/RChunkLoader.hxx"
#include "TMVA/RBatchLoader.hxx"
#include "TMVA/Tools.h"
#include "TFile.h"
#include "TTree.h"
#include "TRandom3.h"
#include "TROOT.h"

namespace TMVA {
namespace Experimental {
namespace Internal {

template <typename... Args>
class RBatchGenerator {
private:
   TMVA::RandomGenerator<TRandom3> fRng;

   std::string fFileName;
   std::string fTreeName;

   std::vector<std::string> fCols;
   std::string fFilters;

   std::size_t fChunkSize;
   std::size_t fMaxChunks;
   std::size_t fBatchSize;
   std::size_t fMaxBatches;
   std::size_t fNumColumns;
   std::size_t fNumEntries;
   std::size_t fCurrentRow = 0;

   float fValidationSplit;

   std::unique_ptr<TMVA::Experimental::Internal::RChunkLoader<Args...>> fChunkLoader;
   std::unique_ptr<TMVA::Experimental::Internal::RBatchLoader> fBatchLoader;

   std::unique_ptr<std::thread> fLoadingThread;

   bool fUseWholeFile = true;

   std::unique_ptr<TMVA::Experimental::RTensor<float>> fChunkTensor;
   std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;

   std::vector<std::vector<std::size_t>> fTrainingIdxs;
   std::vector<std::vector<std::size_t>> fValidationIdxs;

   // Lock protecting the fIsActive flag shared with the loading thread
   std::mutex fIsActiveLock;

   bool fShuffle = true;
   bool fIsActive = false;

   std::vector<std::size_t> fVecSizes;
   float fVecPadding;

public:
   RBatchGenerator(const std::string &treeName, const std::string &fileName, const std::size_t chunkSize,
                   const std::size_t batchSize, const std::vector<std::string> &cols, const std::string &filters = "",
                   const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
                   const float validationSplit = 0.0, const std::size_t maxChunks = 0, const std::size_t numColumns = 0,
                   bool shuffle = true)
      : fTreeName(treeName),
        fFileName(fileName),
        fChunkSize(chunkSize),
        fBatchSize(batchSize),
        fCols(cols),
        fFilters(filters),
        fVecSizes(vecSizes),
        fVecPadding(vecPadding),
        fValidationSplit(validationSplit),
        fMaxChunks(maxChunks),
        fNumColumns((numColumns != 0) ? numColumns : cols.size()),
        fShuffle(shuffle),
        fUseWholeFile(maxChunks == 0)
   {
      // Limit the number of batches that can be held in the batch queue, based on the chunk size
      fMaxBatches = ceil((fChunkSize / fBatchSize) * (1 - fValidationSplit));

      // Get the number of entries (fNumEntries) in the tree
      std::unique_ptr<TFile> f{TFile::Open(fFileName.c_str())};
      std::unique_ptr<TTree> t{f->Get<TTree>(fTreeName.c_str())};
      fNumEntries = t->GetEntries();

      fChunkLoader = std::make_unique<TMVA::Experimental::Internal::RChunkLoader<Args...>>(
         fTreeName, fFileName, fChunkSize, fCols, fFilters, fVecSizes, fVecPadding);
      fBatchLoader = std::make_unique<TMVA::Experimental::Internal::RBatchLoader>(fBatchSize, fNumColumns, fMaxBatches);

      // Create the tensor the current chunk is loaded into
      fChunkTensor =
         std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fChunkSize, fNumColumns});
   }

   ~RBatchGenerator() { DeActivate(); }

   /// \brief De-activate the loading process by deactivating the batchgenerator
   /// and joining the loading thread
   void DeActivate()
   {
      {
         std::lock_guard<std::mutex> lock(fIsActiveLock);
         fIsActive = false;
      }

      fBatchLoader->DeActivate();

      if (fLoadingThread) {
         if (fLoadingThread->joinable()) {
            fLoadingThread->join();
         }
      }
   }

   /// \brief Activate the loading process by starting the batchloader, and
   /// spawning the loading thread.
   void Activate()
   {
      if (fIsActive)
         return;

      {
         std::lock_guard<std::mutex> lock(fIsActiveLock);
         fIsActive = true;
      }

      fCurrentRow = 0;
      fBatchLoader->Activate();
      fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
   }

   /// \brief Returns the next batch of training data if available.
   /// Returns empty RTensor otherwise.
   /// \return Reference to the next training batch
   const TMVA::Experimental::RTensor<float> &GetTrainBatch()
   {
      // Get next batch if available
      return fBatchLoader->GetTrainBatch();
   }

   /// \brief Returns the next batch of validation data if available.
   /// Returns empty RTensor otherwise.
   /// \return Reference to the next validation batch
   const TMVA::Experimental::RTensor<float> &GetValidationBatch()
   {
      // Get next batch if available
      return fBatchLoader->GetValidationBatch();
   }

   bool HasTrainData() { return fBatchLoader->HasTrainData(); }

   bool HasValidationData() { return fBatchLoader->HasValidationData(); }

   /// \brief Load chunks from the input file until the end of the file
   /// (or the maximum number of chunks) is reached, creating batches for each chunk.
   void LoadChunks()
   {
      for (std::size_t current_chunk = 0; ((current_chunk < fMaxChunks) || fUseWholeFile) && fCurrentRow < fNumEntries;
           current_chunk++) {

         // Stop the loop when the loading is no longer active
         {
            std::lock_guard<std::mutex> lock(fIsActiveLock);
            if (!fIsActive)
               return;
         }

         // A pair consisting of the number of processed and passed events while loading the chunk
         std::pair<std::size_t, std::size_t> report = fChunkLoader->LoadChunk(*fChunkTensor, fCurrentRow);
         fCurrentRow += report.first;

         CreateBatches(current_chunk, report.second);

         // Stop loading if the number of processed events is smaller than the desired chunk size
         if (report.first < fChunkSize) {
            break;
         }
      }

      fBatchLoader->DeActivate();
   }

   /// \brief Create batches for the current chunk.
   /// \param currentChunk Index of the chunk being processed
   /// \param processedEvents Number of events that passed the filters for this chunk
   void CreateBatches(std::size_t currentChunk, std::size_t processedEvents)
   {
      // Check whether the indices of this chunk were already split into training and validation
      if (fTrainingIdxs.size() > currentChunk) {
         fBatchLoader->CreateTrainingBatches(*fChunkTensor, fTrainingIdxs[currentChunk], fShuffle);
      } else {
         // First pass over this chunk: create the index split, then both the training and validation batches
         createIdxs(processedEvents);
         fBatchLoader->CreateTrainingBatches(*fChunkTensor, fTrainingIdxs[currentChunk], fShuffle);
         fBatchLoader->CreateValidationBatches(*fChunkTensor, fValidationIdxs[currentChunk]);
      }
   }

   /// \brief Split the events of the current chunk into validation and training events
   /// \param processedEvents Number of events in the chunk to split
   void createIdxs(std::size_t processedEvents)
   {
      // Create a vector of the numbers 0 .. processedEvents - 1
      std::vector<std::size_t> row_order = std::vector<std::size_t>(processedEvents);
      std::iota(row_order.begin(), row_order.end(), 0);

      if (fShuffle) {
         std::shuffle(row_order.begin(), row_order.end(), fRng);
      }

      // Calculate the number of events used for validation
      std::size_t num_validation = ceil(processedEvents * fValidationSplit);

      // Divide the vector into training and validation indices
      std::vector<std::size_t> valid_idx({row_order.begin(), row_order.begin() + num_validation});
      std::vector<std::size_t> train_idx({row_order.begin() + num_validation, row_order.end()});

      fTrainingIdxs.push_back(train_idx);
      fValidationIdxs.push_back(valid_idx);
   }

   void StartValidation() { fBatchLoader->StartValidation(); }
   bool IsActive() { return fIsActive; }
};

} // namespace Internal
} // namespace Experimental
} // namespace TMVA

#endif // TMVA_BATCHGENERATOR
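
A possible usage sketch (not part of the header above): RBatchGenerator is an internal TMVA building block, normally driven through TMVA's Python batch-generator helpers, but its C++ interface can also be exercised directly. The sketch below assumes a hypothetical file data.root containing a TTree named tree with two float branches x and y; the file, tree, and column names, the chunk and batch sizes, and the 20% validation split are illustrative only.

#include <iostream>

#include "TMVA/RBatchGenerator.hxx"

int main()
{
   // Hypothetical input: "data.root" holds a TTree "tree" with float branches "x" and "y".
   TMVA::Experimental::Internal::RBatchGenerator<float, float> generator(
      /*treeName=*/"tree", /*fileName=*/"data.root",
      /*chunkSize=*/1000, /*batchSize=*/100,
      /*cols=*/{"x", "y"}, /*filters=*/"",
      /*vecSizes=*/{}, /*vecPadding=*/0.0,
      /*validationSplit=*/0.2);

   generator.Activate(); // spawns the background chunk-loading thread

   // Drain the training batches produced for this epoch.
   while (generator.HasTrainData()) {
      const TMVA::Experimental::RTensor<float> &batch = generator.GetTrainBatch();
      std::cout << "training batch rows: " << batch.GetShape()[0] << "\n";
   }

   // Then iterate over the validation events held out from each chunk.
   generator.StartValidation();
   while (generator.HasValidationData()) {
      const TMVA::Experimental::RTensor<float> &batch = generator.GetValidationBatch();
      std::cout << "validation batch rows: " << batch.GetShape()[0] << "\n";
   }

   generator.DeActivate(); // joins the loading thread (also done by the destructor)
   return 0;
}

Each returned batch is an RTensor of shape {batchSize, nColumns}, so the rows can be copied into whatever training framework consumes them.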