Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
LossFunctions.cu
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 13/07/16
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12///////////////////////////////////////////////////////////////////////
13// Implementation of the loss functions for the TCuda implementation //
14// of the low-level interface. //
15///////////////////////////////////////////////////////////////////////
16
19#include "Kernels.cuh"
20
// NOTE(review): this text is a Doxygen source-listing extraction, not the
// compilable file. Every code line is prefixed with its original line number,
// and several original lines (mostly the head of each function signature,
// some kernel arguments and the return statements) were dropped, so the
// definitions below are incomplete. Consult the real LossFunctions.cu before
// changing any of this code.
21namespace TMVA
22{
23namespace DNN
24{
25
26//____________________________________________________________________________
// MeanSquaredError: host-side wrapper around the Cuda::MeanSquaredError
// kernel. Per the symbol index at the end of this page it is declared as
//   static Scalar_t MeanSquaredError(const Matrix_t &Y, const Matrix_t &output,
//                                    const Matrix_t &weights)
// The visible body launches the kernel on Y's compute stream with a 2D launch
// configuration derived from Y (BlockDims2D / GridDims2D(Y)), passing
// output, weights and Y's row and column counts.
// NOTE(review): original lines 28-29, 35, 37-38 and 43 are missing here. The
// symbol index lists TCudaMatrix::ResetDeviceReturn, GetDeviceReturnPointer
// and GetDeviceReturn ("transfer the value in the device return buffer to the
// host"), which presumably appear in those lines: resetting the buffer,
// passing its pointer to the kernel, and returning its value. Confirm against
// the real source.
// NOTE(review): no cudaGetLastError() check is visible after the launch.
27template<typename AFloat>
30 const TCudaMatrix<AFloat> & weights)
31{
32 dim3 blockDims = TDevice::BlockDims2D();
33 dim3 gridDims = TDevice::GridDims2D(Y);
34 cudaStream_t s = Y.GetComputeStream();
36 ::TMVA::DNN::Cuda::MeanSquaredError<<<gridDims, blockDims, 0, s>>>(
39 output.GetDataPointer(),
40 weights.GetDataPointer(),
41 (int) Y.GetNrows(),
42 (int) Y.GetNcols());
44}
45
46//____________________________________________________________________________
// MeanSquaredErrorGradients: host-side wrapper around the
// Cuda::MeanSquaredErrorGradients kernel. Per the symbol index it is declared
//   static void MeanSquaredErrorGradients(Matrix_t &dY, const Matrix_t &Y,
//                                         const Matrix_t &output,
//                                         const Matrix_t &weights)
// The visible body launches the kernel on output's compute stream with a 2D
// launch configuration derived from Y. It passes dY's data pointer first; dY
// is the only non-const matrix, so presumably the kernel writes the gradients
// there. It then passes output, weights and Y's row and column counts.
// Afterwards dY is tagged with the same stream via SetComputeStream(s),
// presumably so that later work on dY is ordered after this kernel.
// NOTE(review): original lines 48, 50 and 58 are missing here. Line 58 is
// presumably the Y data-pointer argument; confirm.
// NOTE(review): the loss wrappers take their stream from Y, but this wrapper
// takes it from output. If Y and output can be bound to different streams,
// confirm the kernel is ordered after pending writes to Y.
// NOTE(review): no cudaGetLastError() check is visible after the launch.
47template<typename AFloat>
49 const TCudaMatrix<AFloat> & Y,
51 const TCudaMatrix<AFloat> &weights)
52{
53 dim3 blockDims = TDevice::BlockDims2D();
54 dim3 gridDims = TDevice::GridDims2D(Y);
55 cudaStream_t s = output.GetComputeStream();
56 ::TMVA::DNN::Cuda::MeanSquaredErrorGradients<<<gridDims, blockDims, 0, s>>>(
57 dY.GetDataPointer(),
59 output.GetDataPointer(),
60 weights.GetDataPointer(),
61 (int) Y.GetNrows(),
62 (int) Y.GetNcols());
63 dY.SetComputeStream(s);
64}
65
66//____________________________________________________________________________
// CrossEntropy: host-side wrapper around the Cuda::CrossEntropy kernel. Per
// the symbol index it is declared as
//   static Scalar_t CrossEntropy(const Matrix_t &Y, const Matrix_t &output,
//                                const Matrix_t &weights)
// and "Sigmoid transformation is implicitly applied, thus output should hold
// the linear activations". Callers should therefore not pass values that
// have already been through a sigmoid.
// The visible body launches the kernel on Y's compute stream with a 2D launch
// configuration derived from Y, passing output, weights and Y's row and
// column counts.
// NOTE(review): original lines 68-69, 74, 77-78 and 83 are missing here.
// They presumably hold the signature head, the device-return-buffer reset,
// the extra kernel arguments and the return statement; confirm against the
// real source.
// NOTE(review): no cudaGetLastError() check is visible after the launch.
67template<typename AFloat>
70 const TCudaMatrix<AFloat> &weights)
71{
72 dim3 blockDims = TDevice::BlockDims2D();
73 dim3 gridDims = TDevice::GridDims2D(Y);
75 cudaStream_t s = Y.GetComputeStream();
76 ::TMVA::DNN::Cuda::CrossEntropy<<<gridDims, blockDims, 0, s>>>(
79 output.GetDataPointer(),
80 weights.GetDataPointer(),
81 (int) Y.GetNrows(),
82 (int) Y.GetNcols());
84}
85
86//____________________________________________________________________________
// CrossEntropyGradients: host-side wrapper around the
// Cuda::CrossEntropyGradients kernel. Per the symbol index it is declared as
//   static void CrossEntropyGradients(Matrix_t &dY, const Matrix_t &Y,
//                                     const Matrix_t &output,
//                                     const Matrix_t &weights)
// The visible body launches the kernel on output's compute stream with a 2D
// launch configuration derived from Y. It passes dY's data pointer first
// (presumably the gradient destination), then output, weights and Y's row
// and column counts. It then tags dY with the same stream via
// SetComputeStream(s).
// NOTE(review): original lines 88, 90 and 98 are missing here. Line 98 is
// presumably the Y data-pointer argument; confirm.
// NOTE(review): the stream is taken from output, not Y (the loss wrappers use
// Y's stream); confirm ordering when the two differ.
// NOTE(review): no cudaGetLastError() check is visible after the launch.
87template<typename AFloat>
89 const TCudaMatrix<AFloat> & Y,
91 const TCudaMatrix<AFloat> &weights)
92{
93 dim3 blockDims = TDevice::BlockDims2D();
94 dim3 gridDims = TDevice::GridDims2D(Y);
95 cudaStream_t s = output.GetComputeStream();
96 ::TMVA::DNN::Cuda::CrossEntropyGradients<<<gridDims, blockDims, 0, s>>>(
97 dY.GetDataPointer(),
99 output.GetDataPointer(),
100 weights.GetDataPointer(),
101 (int) Y.GetNrows(),
102 (int) Y.GetNcols());
103 dY.SetComputeStream(s);
104}
105
106//____________________________________________________________________________
// SoftmaxCrossEntropy: host-side wrapper around the Cuda::SoftmaxCrossEntropy
// kernel. Per the symbol index it is declared as
//   static Scalar_t SoftmaxCrossEntropy(const Matrix_t &Y,
//                                       const Matrix_t &output,
//                                       const Matrix_t &weights)
// and "Softmax transformation is implicitly applied, thus output should hold
// the linear activations".
// Unlike the wrappers above, this one uses a 1D launch configuration
// (BlockDims1D / GridDims1D(Y)).
// NOTE(review): presumably the kernel assigns one thread per row of Y,
// because softmax couples the columns of a row; confirm in Kernels.cuh.
// The kernel runs on Y's compute stream and receives Y, output, weights and
// Y's row and column counts.
// NOTE(review): original lines 108-109, 114, 117 and 123 are missing here.
// They presumably hold the signature head, the device-return-buffer reset,
// its pointer argument and the return statement; confirm.
// NOTE(review): no cudaGetLastError() check is visible after the launch.
107template<typename AFloat>
110 const TCudaMatrix<AFloat> &weights)
111{
112 dim3 blockDims = TDevice::BlockDims1D();
113 dim3 gridDims = TDevice::GridDims1D(Y);
115 cudaStream_t s = Y.GetComputeStream();
116 ::TMVA::DNN::Cuda::SoftmaxCrossEntropy<<<gridDims, blockDims, 0, s>>>(
118 Y.GetDataPointer(),
119 output.GetDataPointer(),
120 weights.GetDataPointer(),
121 (int) Y.GetNrows(),
122 (int) Y.GetNcols());
124}
125
126//____________________________________________________________________________
// SoftmaxCrossEntropyGradients: host-side wrapper around the
// Cuda::SoftmaxCrossEntropyGradients kernel. Per the symbol index it is
// declared as
//   static void SoftmaxCrossEntropyGradients(Matrix_t &dY, const Matrix_t &Y,
//                                            const Matrix_t &output,
//                                            const Matrix_t &weights)
// The visible body uses a 1D launch configuration derived from Y on output's
// compute stream. It passes dY (presumably the gradient destination), Y,
// output, weights and Y's row and column counts. It then tags dY with the
// same stream via SetComputeStream(s).
// NOTE(review): original lines 128 and 130 (signature head) are missing here.
// NOTE(review): the stream is taken from output, not Y; confirm ordering
// when the two differ.
// NOTE(review): no cudaGetLastError() check is visible after the launch.
127template<typename AFloat>
129 const TCudaMatrix<AFloat> & Y,
131 const TCudaMatrix<AFloat> &weights)
132{
133 dim3 blockDims = TDevice::BlockDims1D();
134 dim3 gridDims = TDevice::GridDims1D(Y);
135 cudaStream_t s = output.GetComputeStream();
136 ::TMVA::DNN::Cuda::SoftmaxCrossEntropyGradients<<<gridDims, blockDims, 0, s>>>(
137 dY.GetDataPointer(),
138 Y.GetDataPointer(),
139 output.GetDataPointer(),
140 weights.GetDataPointer(),
141 (int) Y.GetNrows(),
142 (int) Y.GetNcols());
143 dY.SetComputeStream(s);
144}
145
146} // namespace DNN
147} // namespace TMVA
TCudaMatrix Class.
Definition CudaMatrix.h:103
size_t GetNcols() const
Definition CudaMatrix.h:160
static AFloat GetDeviceReturn()
Transfer the value in the device return buffer to the host.
Definition CudaMatrix.h:301
void SetComputeStream(cudaStream_t stream)
Definition CudaMatrix.h:275
cudaStream_t GetComputeStream() const
Definition CudaMatrix.h:268
static AFloat * GetDeviceReturnPointer()
Return device pointer to the device return buffer.
Definition CudaMatrix.h:151
static void ResetDeviceReturn(AFloat value=0.0)
Set the return buffer on the device to the specified value.
Definition CudaMatrix.h:293
const AFloat * GetDataPointer() const
Definition CudaMatrix.h:163
size_t GetNrows() const
Definition CudaMatrix.h:159
static Scalar_t CrossEntropy(const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
Sigmoid transformation is implicitly applied, thus output should hold the linear activations of the last layer in the net.
static void MeanSquaredErrorGradients(Matrix_t &dY, const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
static Scalar_t MeanSquaredError(const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
static Scalar_t SoftmaxCrossEntropy(const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
Softmax transformation is implicitly applied, thus output should hold the linear activations of the last layer in the net.
static void CrossEntropyGradients(Matrix_t &dY, const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
static void SoftmaxCrossEntropyGradients(Matrix_t &dY, const Matrix_t &Y, const Matrix_t &output, const Matrix_t &weights)
static dim3 BlockDims2D()
Definition Device.h:55
static dim3 GridDims2D(int nrows, int ncols)
Definition Device.h:74
static dim3 BlockDims1D()
Definition Device.h:48
static dim3 GridDims1D(const AMatrix &A)
Definition Device.h:63
create variable transformations
static void output()