Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA::DNN::TDataLoader< AData, TReference< AReal > > Class Template Reference

template<typename AData, typename AReal>
class TMVA::DNN::TDataLoader< AData, TReference< AReal > >

Definition at line 33 of file DataLoader.h.

Public Member Functions

 TDataLoader (const AData &data, size_t nSamples, size_t batchSize, size_t nInputFeatures, size_t nOutputFeatures, size_t nthreads=1)
 
 TDataLoader (const TDataLoader &)=default
 
 TDataLoader (TDataLoader &&)=default
 
BatchIterator_t begin ()
 
void CopyInput (TMatrixT< AReal > &matrix, IndexIterator_t begin)
 Copy input matrix into the given host buffer.
 
void CopyOutput (TMatrixT< AReal > &matrix, IndexIterator_t begin)
 Copy output matrix into the given host buffer.
 
void CopyWeights (TMatrixT< AReal > &matrix, IndexIterator_t begin)
 Copy weight matrix into the given host buffer.
 
BatchIterator_t end ()
 
TBatch< TReference< AReal > > GetBatch ()
 Return the next batch from the training set.
 
TDataLoader & operator= (const TDataLoader &)=default
 
TDataLoader & operator= (TDataLoader &&)=default
 
void Shuffle ()
 Shuffle the order of the samples in the batch.
 

Private Types

using BatchIterator_t = TBatchIterator< AData, TReference< AReal > >
 

Private Attributes

size_t fBatchIndex
 
size_t fBatchSize
 
const AData & fData
 
size_t fNInputFeatures
 
size_t fNOutputFeatures
 
size_t fNSamples
 
std::vector< size_t > fSampleIndices
 Ordering of the samples in the epoch.
 
TMatrixT< AReal > inputMatrix
 
TMatrixT< AReal > outputMatrix
 
TMatrixT< AReal > weightMatrix
 

#include <TMVA/DNN/Architectures/Reference/DataLoader.h>

Member Typedef Documentation

◆ BatchIterator_t

template<typename AData , typename AReal >
using TMVA::DNN::TDataLoader< AData, TReference< AReal > >::BatchIterator_t = TBatchIterator<AData, TReference<AReal> >
private

Definition at line 35 of file DataLoader.h.

Constructor & Destructor Documentation

◆ TDataLoader() [1/3]

template<typename AData , typename AReal >
TMVA::DNN::TDataLoader< AData, TReference< AReal > >::TDataLoader ( const AData &  data,
size_t  nSamples,
size_t  batchSize,
size_t  nInputFeatures,
size_t  nOutputFeatures,
size_t  nthreads = 1 
)

Definition at line 84 of file DataLoader.h.

◆ TDataLoader() [2/3]

template<typename AData , typename AReal >
TMVA::DNN::TDataLoader< AData, TReference< AReal > >::TDataLoader ( const TDataLoader< AData, TReference< AReal > > &  )
default

◆ TDataLoader() [3/3]

template<typename AData , typename AReal >
TMVA::DNN::TDataLoader< AData, TReference< AReal > >::TDataLoader ( TDataLoader< AData, TReference< AReal > > &&  )
default

Member Function Documentation

◆ begin()

template<typename AData , typename AReal >
BatchIterator_t TMVA::DNN::TDataLoader< AData, TReference< AReal > >::begin ( )
inline

Definition at line 69 of file DataLoader.h.

◆ CopyInput()

template<typename AData , typename AReal >
void TMVA::DNN::TDataLoader< AData, TReference< AReal > >::CopyInput ( TMatrixT< AReal > &  matrix,
IndexIterator_t  begin 
)

Copy input matrix into the given host buffer.

Function to be specialized by the architecture-specific backend.

◆ CopyOutput()

template<typename AData , typename AReal >
void TMVA::DNN::TDataLoader< AData, TReference< AReal > >::CopyOutput ( TMatrixT< AReal > &  matrix,
IndexIterator_t  begin 
)

Copy output matrix into the given host buffer.

Function to be specialized by the architecture-specific backend.

◆ CopyWeights()

template<typename AData , typename AReal >
void TMVA::DNN::TDataLoader< AData, TReference< AReal > >::CopyWeights ( TMatrixT< AReal > &  matrix,
IndexIterator_t  begin 
)

Copy weight matrix into the given host buffer.

Function to be specialized by the architecture-specific backend.

◆ end()

template<typename AData , typename AReal >
BatchIterator_t TMVA::DNN::TDataLoader< AData, TReference< AReal > >::end ( )
inline

Definition at line 70 of file DataLoader.h.

◆ GetBatch()

template<typename AData , typename AReal >
TBatch< TReference< AReal > > TMVA::DNN::TDataLoader< AData, TReference< AReal > >::GetBatch

Return the next batch from the training set.

The TDataLoader object keeps an internal counter that cycles over the batches in the training set.

Definition at line 97 of file DataLoader.h.

◆ operator=() [1/2]

template<typename AData , typename AReal >
TDataLoader & TMVA::DNN::TDataLoader< AData, TReference< AReal > >::operator= ( const TDataLoader< AData, TReference< AReal > > &  )
default

◆ operator=() [2/2]

template<typename AData , typename AReal >
TDataLoader & TMVA::DNN::TDataLoader< AData, TReference< AReal > >::operator= ( TDataLoader< AData, TReference< AReal > > &&  )
default

◆ Shuffle()

template<typename AData , typename AReal >
void TMVA::DNN::TDataLoader< AData, TReference< AReal > >::Shuffle

Shuffle the order of the samples in the batch.

The shuffling is indirect, i.e. only the indices are shuffled. No input data is moved by this routine.

Definition at line 115 of file DataLoader.h.

Member Data Documentation

◆ fBatchIndex

template<typename AData , typename AReal >
size_t TMVA::DNN::TDataLoader< AData, TReference< AReal > >::fBatchIndex
private

Definition at line 43 of file DataLoader.h.

◆ fBatchSize

template<typename AData , typename AReal >
size_t TMVA::DNN::TDataLoader< AData, TReference< AReal > >::fBatchSize
private

Definition at line 40 of file DataLoader.h.

◆ fData

template<typename AData , typename AReal >
const AData& TMVA::DNN::TDataLoader< AData, TReference< AReal > >::fData
private

Definition at line 37 of file DataLoader.h.

◆ fNInputFeatures

template<typename AData , typename AReal >
size_t TMVA::DNN::TDataLoader< AData, TReference< AReal > >::fNInputFeatures
private

Definition at line 41 of file DataLoader.h.

◆ fNOutputFeatures

template<typename AData , typename AReal >
size_t TMVA::DNN::TDataLoader< AData, TReference< AReal > >::fNOutputFeatures
private

Definition at line 42 of file DataLoader.h.

◆ fNSamples

template<typename AData , typename AReal >
size_t TMVA::DNN::TDataLoader< AData, TReference< AReal > >::fNSamples
private

Definition at line 39 of file DataLoader.h.

◆ fSampleIndices

template<typename AData , typename AReal >
std::vector<size_t> TMVA::DNN::TDataLoader< AData, TReference< AReal > >::fSampleIndices
private

Ordering of the samples in the epoch.

Definition at line 49 of file DataLoader.h.

◆ inputMatrix

template<typename AData , typename AReal >
TMatrixT<AReal> TMVA::DNN::TDataLoader< AData, TReference< AReal > >::inputMatrix
private

Definition at line 45 of file DataLoader.h.

◆ outputMatrix

template<typename AData , typename AReal >
TMatrixT<AReal> TMVA::DNN::TDataLoader< AData, TReference< AReal > >::outputMatrix
private

Definition at line 46 of file DataLoader.h.

◆ weightMatrix

template<typename AData , typename AReal >
TMatrixT<AReal> TMVA::DNN::TDataLoader< AData, TReference< AReal > >::weightMatrix
private

Definition at line 47 of file DataLoader.h.

  • tmva/tmva/inc/TMVA/DNN/Architectures/Reference/DataLoader.h