Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
DataLoader.h
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 06/06/17
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/////////////////////////////////////////////////////////////////////
13// Partial specialization of the TDataLoader class to adapt it to //
14// the TMatrix class. Also the data transfer is kept simple, since //
15// this implementation (being intended as reference and fallback) //
16// is not optimized for performance. //
17/////////////////////////////////////////////////////////////////////
18
19#ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
20#define TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
21
22#include "TMVA/DNN/DataLoader.h"
23
24#include <random>
#include <stdexcept>
25
26namespace TMVA {
27namespace DNN {
28
29template <typename AReal>
30class TReference;
31
32template <typename AData, typename AReal>
34private:
36
37 const AData &fData; ///< Non-owning reference to the data set; the data must outlive this loader.
38
39 size_t fNSamples; ///< Number of samples, i.e. the number of entries in fSampleIndices.
40 size_t fBatchSize; ///< Number of samples per batch (row count of the batch matrices).
44
48
49 std::vector<size_t> fSampleIndices; ///< Ordering of the samples in the epoch.
50
51public:
52 TDataLoader(const AData &data, size_t nSamples, size_t batchSize, size_t nInputFeatures, size_t nOutputFeatures,
53 size_t nthreads = 1);
54 TDataLoader(const TDataLoader &) = default;
55 TDataLoader(TDataLoader &&) = default;
56 TDataLoader &operator=(const TDataLoader &) = default;
58
59 /** Copy input matrix into the given host buffer. Function to be specialized by
60 * the architecture-specific backend. */
62 /** Copy output matrix into the given host buffer. Function to be specialized
63 * by the architecture-specific backend. */
65 /** Copy weight matrix into the given host buffer. Function to be specialized
66 * by the architecture-specific backend. */
68
71
72 /** Shuffle the order of all samples in the epoch. The shuffling is indirect,
73 * i.e. only the indices in fSampleIndices are permuted. No input data is moved by this
74 * routine. */
75 void Shuffle();
76
77 /** Return the next batch from the training set. The TDataLoader object
78 * keeps an internal counter that cycles over the batches in the training
79 * set. Only complete batches of fBatchSize samples are served; trailing
 * samples that do not fill a whole batch are not visited in an epoch. */
81};
82
83template <typename AData, typename AReal>
84TDataLoader<AData, TReference<AReal>>::TDataLoader(const AData &data, size_t nSamples, size_t batchSize,
85 size_t nInputFeatures, size_t nOutputFeatures, size_t /*nthreads*/)
86 : fData(data), fNSamples(nSamples), fBatchSize(batchSize), fNInputFeatures(nInputFeatures),
87 fNOutputFeatures(nOutputFeatures), fBatchIndex(0), inputMatrix(batchSize, nInputFeatures),
88 outputMatrix(batchSize, nOutputFeatures), weightMatrix(batchSize, 1), fSampleIndices()
89{
91 for (size_t i = 0; i < fNSamples; i++) {
92 fSampleIndices.push_back(i);
93 }
94}
95
96template <typename AData, typename AReal>
98{
99 fBatchIndex %= (fNSamples / fBatchSize); // Cycle through samples.
100
101 size_t sampleIndex = fBatchIndex * fBatchSize;
102 IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
103
104 CopyInput(inputMatrix, sampleIndexIterator);
105 CopyOutput(outputMatrix, sampleIndexIterator);
106 CopyWeights(weightMatrix, sampleIndexIterator);
107
108 fBatchIndex++;
109
110 return TBatch<TReference<AReal>>(inputMatrix, outputMatrix, weightMatrix);
111}
112
113//______________________________________________________________________________
114template <typename AData, typename AReal>
116{
117 std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), std::default_random_engine{});
118}
119
120} // namespace DNN
121} // namespace TMVA
122
123#endif // TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
void CopyWeights(TMatrixT< AReal > &matrix, IndexIterator_t begin)
Copy weight matrix into the given host buffer.
TDataLoader & operator=(const TDataLoader &)=default
void CopyOutput(TMatrixT< AReal > &matrix, IndexIterator_t begin)
Copy output matrix into the given host buffer.
void CopyInput(TMatrixT< AReal > &matrix, IndexIterator_t begin)
Copy input matrix into the given host buffer.
TDataLoader & operator=(TDataLoader &&)=default
std::vector< size_t > fSampleIndices
Ordering of the samples in the epoch.
Definition DataLoader.h:49
TBatchIterator< Data_t, AArchitecture > BatchIterator_t
Definition DataLoader.h:135
BatchIterator_t begin()
Definition DataLoader.h:170
std::vector< size_t > fSampleIndices
Ordering of the samples in the epoch.
Definition DataLoader.h:149
TBatch< AArchitecture > GetBatch()
Return the next batch from the training set.
Definition DataLoader.h:228
void Shuffle()
Shuffle the order of the samples in the batch.
Definition DataLoader.h:269
The reference architecture class.
Definition Reference.h:53
TMatrixT.
Definition TMatrixT.h:39
typename std::vector< size_t >::iterator IndexIterator_t
Definition DataLoader.h:42
create variable transformations