Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
OutputFunctions.cu
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 11/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12////////////////////////////////////////////////////////////////
13// Output functions (Sigmoid, Softmax) of the CUDA            //
14// architecture class template TCuda<AFloat>.                 //
15////////////////////////////////////////////////////////////////
16
19#include "Kernels.cuh"
20
21namespace TMVA
22{
23namespace DNN
24{
25
////////////////////////////////////////////////////////////////////////////////
/// Launch the ::TMVA::DNN::Cuda::Sigmoid kernel reading from \p A and writing
/// into \p B.
///
/// The kernel is enqueued on the compute stream of \p A. That same stream is
/// then stored as the compute stream of \p B. Kernel launches are
/// asynchronous, so this function returns without waiting for the result.
///
/// \param B Output matrix; the 2D launch grid is sized from \p B.
/// \param A Input matrix; its row/column counts are passed to the kernel.
///
/// NOTE(review): assumes \p A and \p B have the same shape — not checked here.
template<typename AFloat>
void TCuda<AFloat>::Sigmoid(TCudaMatrix<AFloat> & B,
                            const TCudaMatrix<AFloat> & A)
{
   // Issue the work on A's stream; B adopts that stream once the launch is queued.
   cudaStream_t stream = A.GetComputeStream();

   const int nRows = static_cast<int>(A.GetNrows());
   const int nCols = static_cast<int>(A.GetNcols());

   const dim3 threadsPerBlock = TDevice::BlockDims2D();
   const dim3 blocksPerGrid   = TDevice::GridDims2D(B);

   ::TMVA::DNN::Cuda::Sigmoid<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
      B.GetDataPointer(), A.GetDataPointer(), nRows, nCols);

   B.SetComputeStream(stream);
}
39
//______________________________________________________________________________
////////////////////////////////////////////////////////////////////////////////
/// Launch the ::TMVA::DNN::Cuda::Softmax kernel reading from \p A and writing
/// into \p B.
///
/// This mirrors TCuda<AFloat>::Sigmoid, except that it uses a 1D launch
/// configuration (TDevice::BlockDims1D / GridDims1D(B)) instead of a 2D one.
/// The kernel is enqueued on the compute stream of \p A, and that stream is
/// then stored as the compute stream of \p B. The launch is asynchronous.
///
/// \param B Output matrix; the 1D launch grid is sized from \p B.
/// \param A Input matrix; its row/column counts are passed to the kernel.
///
/// NOTE(review): presumably the kernel normalizes each row of \p A
/// independently — confirm against Kernels.cuh. Also assumes \p A and \p B
/// have the same shape, which is not checked here.
template<typename AFloat>
void TCuda<AFloat>::Softmax(TCudaMatrix<AFloat> & B,
                            const TCudaMatrix<AFloat> & A)
{
   dim3 blockDims = TDevice::BlockDims1D();
   dim3 gridDims  = TDevice::GridDims1D(B);
   cudaStream_t s = A.GetComputeStream();
   // The kernel needs both the output and the input buffer, as in Sigmoid.
   ::TMVA::DNN::Cuda::Softmax<<<gridDims, blockDims, 0, s>>>(B.GetDataPointer(),
                                                             A.GetDataPointer(),
                                                             (int) A.GetNrows(),
                                                             (int) A.GetNcols());
   // Record the stream the kernel was queued on so later operations on B
   // use the same stream (same convention as Sigmoid).
   B.SetComputeStream(s);
}
54
55} // namespace DNN
56} // namespace TMVA
TCudaMatrix Class.
Definition CudaMatrix.h:103
size_t GetNcols() const
Definition CudaMatrix.h:160
void SetComputeStream(cudaStream_t stream)
Definition CudaMatrix.h:275
cudaStream_t GetComputeStream() const
Definition CudaMatrix.h:268
const AFloat * GetDataPointer() const
Definition CudaMatrix.h:163
size_t GetNrows() const
Definition CudaMatrix.h:159
static void Sigmoid(Tensor_t &B)
static void Softmax(Matrix_t &YHat, const Matrix_t &)
static dim3 BlockDims2D()
Definition Device.h:55
static dim3 GridDims2D(int nrows, int ncols)
Definition Device.h:74
static dim3 BlockDims1D()
Definition Device.h:48
static dim3 GridDims1D(const AMatrix &A)
Definition Device.h:63
create variable transformations