Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ActivationFunctions.cu
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 13/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12 //////////////////////////////////////////////////////////////////
13 // Implementation of the activation functions for the TCuda //
14 // implementation of the low-level interface. //
15 //////////////////////////////////////////////////////////////////
16
19#include "TMVA/DNN/Functions.h"
20#include "Kernels.cuh"
21
22namespace TMVA
23{
24namespace DNN
25{
26//______________________________________________________________________________
// Forward evaluation of an activation function on a tensor.
// NOTE(review): the line carrying the return type, function name and the
// leading parameters (X, activFunct) is absent from this listing; this is
// presumably TCuda<AFloat>::ActivationFunctionForward — confirm against the
// original file.
// The descriptor, coef, alpha and beta parameters are left unnamed, so they
// are not used by the body.
27template<typename AFloat>
29 const ActivationDescriptor_t /* activationDescr */,
30 const double /* coef */, const AFloat /*alpha */, const AFloat /*beta*/)
31{
32 // scaling and translation are not yet implemented
 // Hand X and the selected function to the generic TMVA::DNN::evaluate helper
 // (presumably applies the function to X in place — confirm in Functions.h).
33 TMVA::DNN::evaluate<TCuda<AFloat>>( X, activFunct);
34}
35//______________________________________________________________________________
// Backward pass of an activation function: fills dX with f'(X) and then
// multiplies it element-wise by dY, as described by the comments in the body.
// NOTE(review): the line carrying the return type, function name and the
// leading parameters (dX and, presumably, Y) is absent from this listing;
// this is presumably TCuda<AFloat>::ActivationFunctionBackward — confirm
// against the original file.
// The descriptor, alpha and beta parameters are left unnamed, so they are
// not used by the body.
36template<typename AFloat>
38 const Tensor_t & dY, const Tensor_t & X,
39 EActivationFunction activFunct,
40 const ActivationDescriptor_t /* activationDescr */,
41 const AFloat /* alpha */, const AFloat /* beta */)
42{
43 // scaling and translation not yet implemented
44 // output tensor (Y) could also be used to speed up derivative calculation
45 // compute dx = f'(x)
46 TMVA::DNN::evaluateDerivative<TCuda<AFloat>>(dX, activFunct, X);
47 // Compute element-wise product. dx = f'(x) * dY
48 Hadamard(dX, dY);
49}
50
51//______________________________________________________________________________
// Host-side launcher for the Cuda::IdentityDerivative kernel.
// NOTE(review): the signature line declaring B is absent from this listing;
// this is presumably TCuda<AFloat>::IdentityDerivative(B, A) — confirm
// against the original file.
// The kernel is launched with TDevice's 2D block shape, a grid derived from
// B, and A's compute stream; B's row and column counts are passed as ints.
52template<typename AFloat>
54 const TCudaTensor<AFloat> & A)
55{
56 dim3 blockDims = TDevice::BlockDims2D();
57 dim3 gridDims = TDevice::GridDims2D(B);
 // Queue the kernel on the stream currently associated with A.
58 cudaStream_t s = A.GetComputeStream();
59 ::TMVA::DNN::Cuda::IdentityDerivative<<<gridDims, blockDims, 0, s>>>(
 // NOTE(review): the listing skips a line here (likely the data-pointer
 // argument) — restore it from the original file.
61 (int) B.GetNrows(),
62 (int) B.GetNcols());
 // NOTE(review): the listing skips another line after the launch — check the
 // original file for the statement that belongs here.
64}
65
66//______________________________________________________________________________
// Host-side launcher for the Cuda::Relu kernel.
// NOTE(review): the signature line declaring A is absent from this listing;
// this is presumably TCuda<AFloat>::Relu — confirm against the original file.
// The kernel is launched with TDevice's 2D block shape, a grid derived from
// A, and A's compute stream; A's row and column counts are passed as ints.
67template<typename AFloat>
69{
70 dim3 blockDims = TDevice::BlockDims2D();
71 dim3 gridDims = TDevice::GridDims2D(A);
 // Queue the kernel on the stream currently associated with A.
72 cudaStream_t s = A.GetComputeStream();
73 ::TMVA::DNN::Cuda::Relu<<<gridDims, blockDims, 0, s>>>(
 // NOTE(review): the listing skips a line here (likely the data-pointer
 // argument) — restore it from the original file.
75 (int) A.GetNrows(),
76 (int) A.GetNcols());
77}
78
79//______________________________________________________________________________
// Host-side launcher for the Cuda::ReluDerivative kernel.
// NOTE(review): the signature line declaring B is absent from this listing;
// this is presumably TCuda<AFloat>::ReluDerivative(B, A) — confirm against
// the original file.
// Asserts that B and A have the same row and column counts, then launches
// the kernel with TDevice's 2D block shape, a grid derived from B, and A's
// compute stream.
80template<typename AFloat>
82 const TCudaTensor<AFloat> & A)
83{
 // Both tensors must have identical dimensions.
84 assert(B.GetNrows() == A.GetNrows() && B.GetNcols() == A.GetNcols());
85 dim3 blockDims = TDevice::BlockDims2D();
86 dim3 gridDims = TDevice::GridDims2D(B);
 // Queue the kernel on the stream currently associated with A.
87 cudaStream_t s = A.GetComputeStream();
88 ::TMVA::DNN::Cuda::ReluDerivative<<<gridDims, blockDims, 0, s>>>(
 // NOTE(review): the listing skips lines here (likely the data-pointer
 // arguments) — restore them from the original file.
91 (int) A.GetNrows(),
92 (int) A.GetNcols());
 // NOTE(review): the listing skips another line after the launch — check the
 // original file for the statement that belongs here.
94}
95
96//______________________________________________________________________________
// Host-side launcher for the Cuda::Sigmoid kernel.
// NOTE(review): the signature line declaring A is absent from this listing;
// this is presumably TCuda<AFloat>::Sigmoid — confirm against the original
// file.
// Passes A's data pointer and its row/column counts to the kernel, launched
// with TDevice's 2D block shape and a grid derived from A.
97template<typename AFloat>
99{
100 dim3 blockDims = TDevice::BlockDims2D();
101 dim3 gridDims = TDevice::GridDims2D(A);
 // Queue the kernel on the stream currently associated with A.
102 cudaStream_t s = A.GetComputeStream();
103 ::TMVA::DNN::Cuda::Sigmoid<<<gridDims, blockDims, 0, s>>>(
104 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
105 (int) A.GetNrows(),
106 (int) A.GetNcols());
107}
108
109//______________________________________________________________________________
// Host-side launcher for the Cuda::SigmoidDerivative kernel.
// NOTE(review): the signature line declaring B is absent from this listing;
// this is presumably TCuda<AFloat>::SigmoidDerivative(B, A), with B the
// tensor receiving the result — confirm against the original file.
// Asserts matching dimensions, launches the kernel on A's compute stream
// with both data pointers, then assigns that stream to B.
110template<typename AFloat>
112 const TCudaTensor<AFloat> & A)
113{
 // Both tensors must have identical dimensions.
114 assert(B.GetNrows() == A.GetNrows() && B.GetNcols() == A.GetNcols());
115 dim3 blockDims = TDevice::BlockDims2D();
116 dim3 gridDims = TDevice::GridDims2D(B);
 // Queue the kernel on the stream currently associated with A.
117 cudaStream_t s = A.GetComputeStream();
118 ::TMVA::DNN::Cuda::SigmoidDerivative<<<gridDims, blockDims, 0, s>>>(
119 B.GetDataPointer(),
120 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
121 (int) A.GetNrows(),
122 (int) A.GetNcols());
 // Record on B the stream the kernel was queued on.
123 B.SetComputeStream(s);
124}
125
126//______________________________________________________________________________
// Host-side launcher for the Cuda::Tanh kernel.
// NOTE(review): the signature line declaring A is absent from this listing;
// this is presumably TCuda<AFloat>::Tanh — confirm against the original file.
// Passes A's data pointer and its row/column counts to the kernel, launched
// with TDevice's 2D block shape and a grid derived from A.
127template<typename AFloat>
129{
130 dim3 blockDims = TDevice::BlockDims2D();
131 dim3 gridDims = TDevice::GridDims2D(A);
 // Queue the kernel on the stream currently associated with A.
132 cudaStream_t s = A.GetComputeStream();
133 ::TMVA::DNN::Cuda::Tanh<<<gridDims, blockDims, 0, s>>>(
134 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
135 (int) A.GetNrows(),
136 (int) A.GetNcols());
137}
138
139//______________________________________________________________________________
// Host-side launcher for the Cuda::TanhDerivative kernel.
// NOTE(review): the signature line declaring B is absent from this listing;
// this is presumably TCuda<AFloat>::TanhDerivative(B, A), with B the tensor
// receiving the result — confirm against the original file.
// Asserts matching dimensions, launches the kernel on A's compute stream
// with both data pointers, then assigns that stream to B.
140template<typename AFloat>
142 const TCudaTensor<AFloat> & A)
143{
 // Both tensors must have identical dimensions.
144 assert(B.GetNrows() == A.GetNrows() && B.GetNcols() == A.GetNcols());
145 dim3 blockDims = TDevice::BlockDims2D();
146 dim3 gridDims = TDevice::GridDims2D(B);
 // Queue the kernel on the stream currently associated with A.
147 cudaStream_t s = A.GetComputeStream();
148 ::TMVA::DNN::Cuda::TanhDerivative<<<gridDims, blockDims, 0, s>>>(
149 B.GetDataPointer(),
150 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
151 (int) A.GetNrows(),
152 (int) A.GetNcols());
 // Record on B the stream the kernel was queued on.
153 B.SetComputeStream(s);
154}
155
156//______________________________________________________________________________
// Host-side launcher for the Cuda::SymmetricRelu kernel.
// NOTE(review): the signature line declaring A is absent from this listing;
// this is presumably TCuda<AFloat>::SymmetricRelu — confirm against the
// original file.
// Passes A's data pointer and its row/column counts to the kernel, launched
// with TDevice's 2D block shape and a grid derived from A.
157template<typename AFloat>
159{
160 dim3 blockDims = TDevice::BlockDims2D();
161 dim3 gridDims = TDevice::GridDims2D(A);
 // Queue the kernel on the stream currently associated with A.
162 cudaStream_t s = A.GetComputeStream();
163 ::TMVA::DNN::Cuda::SymmetricRelu<<<gridDims, blockDims, 0, s>>>(
164 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
165 (int) A.GetNrows(),
166 (int) A.GetNcols());
167}
168
169//______________________________________________________________________________
// Host-side launcher for the Cuda::SymmetricReluDerivative kernel.
// NOTE(review): the signature line declaring B is absent from this listing;
// this is presumably TCuda<AFloat>::SymmetricReluDerivative(B, A), with B
// the tensor receiving the result — confirm against the original file.
// Asserts matching dimensions, launches the kernel on A's compute stream
// with both data pointers, then assigns that stream to B.
170template<typename AFloat>
172 const TCudaTensor<AFloat> & A)
173{
 // Both tensors must have identical dimensions.
174 assert(B.GetNrows() == A.GetNrows() && B.GetNcols() == A.GetNcols());
175 dim3 blockDims = TDevice::BlockDims2D();
176 dim3 gridDims = TDevice::GridDims2D(B);
 // Queue the kernel on the stream currently associated with A.
177 cudaStream_t s = A.GetComputeStream();
178 ::TMVA::DNN::Cuda::SymmetricReluDerivative<<<gridDims, blockDims, 0, s>>>(
179 B.GetDataPointer(),
180 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
181 (int) A.GetNrows(),
182 (int) A.GetNcols());
 // Record on B the stream the kernel was queued on.
183 B.SetComputeStream(s);
184}
185
186//______________________________________________________________________________
// Host-side launcher for the Cuda::SoftSign kernel.
// NOTE(review): the signature line declaring A is absent from this listing;
// this is presumably TCuda<AFloat>::SoftSign — confirm against the original
// file.
// Passes A's data pointer and its row/column counts to the kernel, launched
// with TDevice's 2D block shape and a grid derived from A.
187template<typename AFloat>
189{
190 dim3 blockDims = TDevice::BlockDims2D();
191 dim3 gridDims = TDevice::GridDims2D(A);
 // Queue the kernel on the stream currently associated with A.
192 cudaStream_t s = A.GetComputeStream();
193 ::TMVA::DNN::Cuda::SoftSign<<<gridDims, blockDims, 0, s>>>(
194 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
195 (int) A.GetNrows(),
196 (int) A.GetNcols());
197}
198
199//______________________________________________________________________________
// Host-side launcher for the Cuda::SoftSignDerivative kernel.
// NOTE(review): the signature line declaring B is absent from this listing;
// this is presumably TCuda<AFloat>::SoftSignDerivative(B, A), with B the
// tensor receiving the result — confirm against the original file.
// Asserts matching dimensions, launches the kernel on A's compute stream
// with both data pointers, then assigns that stream to B.
200template<typename AFloat>
202 const TCudaTensor<AFloat> & A)
203{
 // Both tensors must have identical dimensions.
204 assert(B.GetNrows() == A.GetNrows() && B.GetNcols() == A.GetNcols());
205 dim3 blockDims = TDevice::BlockDims2D();
206 dim3 gridDims = TDevice::GridDims2D(B);
 // Queue the kernel on the stream currently associated with A.
207 cudaStream_t s = A.GetComputeStream();
208 ::TMVA::DNN::Cuda::SoftSignDerivative<<<gridDims, blockDims, 0, s>>>(
209 B.GetDataPointer(),
210 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
211 (int) A.GetNrows(),
212 (int) A.GetNcols());
 // Record on B the stream the kernel was queued on.
213 B.SetComputeStream(s);
214}
215
216//______________________________________________________________________________
// Host-side launcher for the Cuda::Gauss kernel.
// NOTE(review): the signature line declaring A is absent from this listing;
// this is presumably TCuda<AFloat>::Gauss — confirm against the original
// file.
// Passes A's data pointer and its row/column counts to the kernel, launched
// with TDevice's 2D block shape and a grid derived from A.
217template<typename AFloat>
219{
220 dim3 blockDims = TDevice::BlockDims2D();
221 dim3 gridDims = TDevice::GridDims2D(A);
 // Queue the kernel on the stream currently associated with A.
222 cudaStream_t s = A.GetComputeStream();
223 ::TMVA::DNN::Cuda::Gauss<<<gridDims, blockDims, 0, s>>>(
224 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
225 (int) A.GetNrows(),
226 (int) A.GetNcols());
227}
228
229//______________________________________________________________________________
// Host-side launcher for the Cuda::GaussDerivative kernel.
// NOTE(review): the signature line declaring B is absent from this listing;
// this is presumably TCuda<AFloat>::GaussDerivative(B, A), with B the tensor
// receiving the result — confirm against the original file.
// Asserts matching dimensions, launches the kernel on A's compute stream
// with both data pointers, then assigns that stream to B.
230template<typename AFloat>
232 const TCudaTensor<AFloat> & A)
233{
 // Both tensors must have identical dimensions.
234 assert(B.GetNrows() == A.GetNrows() && B.GetNcols() == A.GetNcols());
235 dim3 blockDims = TDevice::BlockDims2D();
236 dim3 gridDims = TDevice::GridDims2D(B);
 // Queue the kernel on the stream currently associated with A.
237 cudaStream_t s = A.GetComputeStream();
238 ::TMVA::DNN::Cuda::GaussDerivative<<<gridDims, blockDims, 0, s>>>(
239 B.GetDataPointer(),
240 A.GetDataPointer(),
 // GetNrows()/GetNcols() return size_t; the kernel is given ints.
241 (int) A.GetNrows(),
242 (int) A.GetNcols());
 // Record on B the stream the kernel was queued on.
243 B.SetComputeStream(s);
244}
245
246} // namespace DNN
247} // namespace TMVA
TCudaTensor Class.
Definition CudaTensor.h:84
size_t GetNrows() const
Definition CudaTensor.h:299
cudaStream_t GetComputeStream() const
Definition CudaTensor.h:213
size_t GetNcols() const
Definition CudaTensor.h:300
const AFloat * GetDataPointer() const
Definition CudaTensor.h:194
void SetComputeStream(cudaStream_t stream)
Definition CudaTensor.h:216
static void SoftSignDerivative(Tensor_t &B, const Tensor_t &A)
static void SymmetricReluDerivative(Tensor_t &B, const Tensor_t &A)
static void IdentityDerivative(Tensor_t &B, const Tensor_t &A)
static void ActivationFunctionForward(Tensor_t &X, EActivationFunction activFunct, const ActivationDescriptor_t activationDescr, const double coef=0.0, const AFloat alpha=1, const AFloat beta=0)
static void SoftSign(Tensor_t &B)
static void Gauss(Tensor_t &B)
static void Sigmoid(Tensor_t &B)
static void Tanh(Tensor_t &B)
static void ActivationFunctionBackward(Tensor_t &dX, const Tensor_t &Y, const Tensor_t &dY, const Tensor_t &X, EActivationFunction activFunct, const ActivationDescriptor_t activationDescr, const AFloat alpha=1, const AFloat beta=0)
Computes the gradient of the activation function.
static void ReluDerivative(Tensor_t &B, const Tensor_t &A)
static void GaussDerivative(Tensor_t &B, const Tensor_t &A)
static void Relu(Tensor_t &B)
static void SymmetricRelu(Tensor_t &B)
static void SigmoidDerivative(Tensor_t &B, const Tensor_t &A)
static void TanhDerivative(Tensor_t &B, const Tensor_t &A)
static dim3 BlockDims2D()
Definition Device.h:55
static dim3 GridDims2D(int nrows, int ncols)
Definition Device.h:74
EActivationFunction
Enum that represents layer activation functions.
Definition Functions.h:32
create variable transformations