ROOT Reference Guide
RecurrentPropagation.cu
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Lorenzo Moneta 2020

/*************************************************************************
 * Copyright (C) 2017, Saurav Shekhar                                    *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

////////////////////////////////////////////////////////////////////////////
//  Implementation of the functions required for the forward and backward
//  propagation of activations through a recurrent neural network for CUDA
//  architectures.
////////////////////////////////////////////////////////////////////////////

#include "TMVA/DNN/Architectures/TCudnn.h"

namespace TMVA
{
namespace DNN
{
template <typename AFloat>
template <typename RNNLayer>
void TCudnn<AFloat>::InitializeRecurrentTensors(RNNLayer *layer)
{
   // initialization of the RNN tensors for setting the right layout (row-major)
   size_t timeSteps = (layer->DoesReturnSequence()) ? layer->GetTimeSteps() : 1;
   layer->GetOutput() =
      Tensor_t(layer->GetOutput().GetDeviceBuffer(),
               {layer->GetBatchSize(), timeSteps, layer->GetStateSize()}, GetTensorLayout());
   layer->GetActivationGradients() =
      Tensor_t(layer->GetActivationGradients().GetDeviceBuffer(),
               {layer->GetBatchSize(), timeSteps, layer->GetStateSize()}, GetTensorLayout());

   // make the weight tensors in the right layout (row-major)
   for (size_t i = 0; i < layer->GetWeights().size(); ++i) {
      auto &w = layer->GetWeightsAt(i);

      w = Tensor_t(layer->GetWeightsAt(i).GetDeviceBuffer(),
                   {layer->GetWeightsAt(i).GetNrows(), layer->GetWeightsAt(i).GetNcols()}, GetTensorLayout());
   }
   // now the biases
   for (size_t i = 0; i < layer->GetBiases().size(); ++i) {

      // reshape tensors
      auto &b = layer->GetBiasesAt(i);
      b = Tensor_t(layer->GetBiasesAt(i).GetDeviceBuffer(), {layer->GetStateSize(), 1}, GetTensorLayout(), 0, 0);
   }

   // layer->GetWeightsState() = Tensor_t(layer->GetWeightsState().GetDeviceBuffer(),
   //                                     {layer->GetStateSize(), layer->GetStateSize()}, GetTensorLayout());
   // layer->GetWeightsInput() = Tensor_t(layer->GetWeightsInput().GetDeviceBuffer(),
   //                                     {layer->GetStateSize(), layer->GetInputSize()}, GetTensorLayout());
   // layer->GetBiasesState() = Tensor_t(layer->GetBiasesState().GetDeviceBuffer(),
   //                                    {layer->GetStateSize(), 1}, GetTensorLayout());

   layer->GetX() = Tensor_t({layer->GetTimeSteps(), layer->GetBatchSize(), layer->GetInputSize()}, GetTensorLayout());
   layer->GetY() = Tensor_t({layer->GetTimeSteps(), layer->GetBatchSize(), layer->GetStateSize()}, GetTensorLayout());

   layer->GetDX() = Tensor_t({layer->GetTimeSteps(), layer->GetBatchSize(), layer->GetInputSize()}, GetTensorLayout());
   layer->GetDY() = Tensor_t({layer->GetTimeSteps(), layer->GetBatchSize(), layer->GetStateSize()}, GetTensorLayout());
}
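
//____________________________________________________________________________
// Illustrative sketch (not part of the original file): the layer output above keeps a
// row-major {batch, time, state} layout, while the X/Y buffers handed to cuDNN are laid
// out as {time, batch, state}. The two hypothetical helpers below only show how a flat
// element index is computed for each layout.
inline size_t IndexBTS(size_t b, size_t t, size_t s, size_t T, size_t S)
{
   return (b * T + t) * S + s; // row-major {B, T, S}
}
inline size_t IndexTBS(size_t t, size_t b, size_t s, size_t B, size_t S)
{
   return (t * B + b) * S + s; // row-major {T, B, S}, as consumed by cuDNN
}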
//____________________________________________________________________________
template <typename AFloat>
template <typename RNNLayer>
void TCudnn<AFloat>::InitializeRecurrentDescriptors(TDescriptors *&descriptors, RNNLayer *layer)
{
   auto rnnDescriptors = new RNNDescriptors_t();
   CUDNNCHECK(cudnnCreateRNNDescriptor(&rnnDescriptors->LayerDescriptor));

   CUDNNCHECK(cudnnCreateDropoutDescriptor(&rnnDescriptors->HelperDescriptor));

   enum RNNType { kRNN, kLSTM, kGRU };
   RNNType rnn_type = kRNN;
   if (std::is_same<RNNLayer, LSTMLayer_t>::value) rnn_type = kLSTM;
   if (std::is_same<RNNLayer, GRULayer_t>::value) rnn_type = kGRU;

   cudnnHandle_t handle = layer->GetOutput().GetCudnnHandle();
   float dropoutProb = 0.0; // layer->GetDropoutProbability();

   void *dropoutStates = nullptr; // random generator states ??
   size_t dropoutStateSize = 0;

   // get the size of the dropout states
   CUDNNCHECK(cudnnDropoutGetStatesSize(handle, &dropoutStateSize));

   // unsigned long long seed = GetRandomGenerator().Integer(INT_MAX);
   // use GetSeed to avoid consuming other random numbers, which would break the sequence
   unsigned long long seed = GetRandomGenerator().GetSeed();

   CUDNNCHECK(cudnnSetDropoutDescriptor(rnnDescriptors->HelperDescriptor, handle, dropoutProb, dropoutStates,
                                        dropoutStateSize, seed));
   // cudnnDropoutDescriptor_t dropoutDesc,
   // cudnnHandle_t handle,
   // float dropout,
   // void *states,
   // size_t stateSizeInBytes,
   // unsigned long long seed)

   int hiddenSize = layer->GetStateSize();
   int numLayers = 1; // this is not the number of time steps, but the number of stacked layers // layer->GetTimeSteps();
   // cudnnRNNInputMode_t inputMode = CUDNN_SKIP_INPUT; // the leading dimension of x must be equal to hiddenSize
   cudnnRNNInputMode_t inputMode = CUDNN_LINEAR_INPUT; // this is a vanilla RNN

   cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL; // can be CUDNN_BIDIRECTIONAL
   bool bidirectional = (direction == CUDNN_BIDIRECTIONAL);

   cudnnRNNMode_t mode = CUDNN_RNN_TANH; // can be CUDNN_RNN_RELU, CUDNN_LSTM, CUDNN_GRU
   if (rnn_type == kLSTM) mode = CUDNN_LSTM; // LSTM case
   if (rnn_type == kGRU) mode = CUDNN_GRU;

   cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; // can also be CUDNN_RNN_ALGO_PERSIST_STATIC or CUDNN_RNN_ALGO_PERSIST_DYNAMIC

   // this determines the number of weight matrices (linear layers) per recurrent layer
   int numLinearLayers = 0;
   if (mode == CUDNN_RNN_RELU || mode == CUDNN_RNN_TANH) {
      numLinearLayers = 2;
   }
   if (mode == CUDNN_GRU) {
      numLinearLayers = 6;
   }
   if (mode == CUDNN_LSTM) {
      numLinearLayers = 8;
   }
   // this should be the size of the weights vector
   assert(numLinearLayers == (int) layer->GetWeights().size());

   cudnnDataType_t mathPrec = CUDNN_DATA_FLOAT;
   if (std::is_same<AFloat, double>::value) { mathPrec = CUDNN_DATA_DOUBLE; }

#if (CUDNN_VERSION >= 8000)
   CUDNNCHECK(cudnnSetRNNDescriptor_v6(handle, rnnDescriptors->LayerDescriptor, hiddenSize, numLayers, rnnDescriptors->HelperDescriptor,
#else
   CUDNNCHECK(cudnnSetRNNDescriptor(handle, rnnDescriptors->LayerDescriptor, hiddenSize, numLayers, rnnDescriptors->HelperDescriptor,
#endif
                                    inputMode, direction, mode, algo, mathPrec));

   // set the bias mode
   cudnnRNNBiasMode_t biasMode = CUDNN_RNN_NO_BIAS;
   if (layer->GetBiases().size() > 0)
      biasMode = CUDNN_RNN_SINGLE_INP_BIAS;
   // biasMode = CUDNN_RNN_SINGLE_REC_BIAS; // the difference matters only for GRU

   CUDNNCHECK(cudnnSetRNNBiasMode(rnnDescriptors->LayerDescriptor, biasMode));

   // define the tensor descriptors for the RNN

   int dimA[3];
   int strideA[3];
   int seqLength = layer->GetTimeSteps();

   rnnDescriptors->xDesc.resize(seqLength);
   rnnDescriptors->yDesc.resize(seqLength);
   rnnDescriptors->dxDesc.resize(seqLength);
   rnnDescriptors->dyDesc.resize(seqLength);
   TensorDescriptor_t *xDesc = rnnDescriptors->xDesc.data();
   TensorDescriptor_t *yDesc = rnnDescriptors->yDesc.data();
   TensorDescriptor_t *dxDesc = rnnDescriptors->dxDesc.data();
   TensorDescriptor_t *dyDesc = rnnDescriptors->dyDesc.data();

   for (int i = 0; i < seqLength; i++) {
      CUDNNCHECK(cudnnCreateTensorDescriptor(&xDesc[i]));
      CUDNNCHECK(cudnnCreateTensorDescriptor(&yDesc[i]));
      CUDNNCHECK(cudnnCreateTensorDescriptor(&dxDesc[i]));
      CUDNNCHECK(cudnnCreateTensorDescriptor(&dyDesc[i]));

      dimA[0] = layer->GetBatchSize();
      dimA[1] = layer->GetInputSize();
      dimA[2] = 1;

      strideA[0] = dimA[2] * dimA[1];
      strideA[1] = dimA[2];
      strideA[2] = 1;

      CUDNNCHECK(cudnnSetTensorNdDescriptor(xDesc[i], mathPrec, 3, dimA, strideA));
      CUDNNCHECK(cudnnSetTensorNdDescriptor(dxDesc[i], mathPrec, 3, dimA, strideA));

      dimA[0] = layer->GetBatchSize();
      dimA[1] = bidirectional ? hiddenSize * 2 : hiddenSize;
      dimA[2] = 1;

      strideA[0] = dimA[2] * dimA[1];
      strideA[1] = dimA[2];
      strideA[2] = 1;

      CUDNNCHECK(cudnnSetTensorNdDescriptor(yDesc[i], mathPrec, 3, dimA, strideA));
      CUDNNCHECK(cudnnSetTensorNdDescriptor(dyDesc[i], mathPrec, 3, dimA, strideA));
   }

   // weight descriptor
   CUDNNCHECK(cudnnCreateFilterDescriptor(&rnnDescriptors->WeightsDescriptor));
   CUDNNCHECK(cudnnCreateFilterDescriptor(&rnnDescriptors->WeightsGradDescriptor));

   // set the filter parameters
   size_t weightsSize = 0;
   CUDNNCHECK(cudnnGetRNNParamsSize(handle, rnnDescriptors->LayerDescriptor, xDesc[0], &weightsSize, mathPrec));

   int dimW[3];
   dimW[0] = (mathPrec == CUDNN_DATA_DOUBLE) ? weightsSize / sizeof(double) : weightsSize / sizeof(float);
   dimW[1] = 1;
   dimW[2] = 1;

   CUDNNCHECK(cudnnSetFilterNdDescriptor(rnnDescriptors->WeightsDescriptor, mathPrec, CUDNN_TENSOR_NCHW, 3, dimW));
   CUDNNCHECK(cudnnSetFilterNdDescriptor(rnnDescriptors->WeightsGradDescriptor, mathPrec, CUDNN_TENSOR_NCHW, 3, dimW));

   // now resize the weight tensors
   auto &weightTensor = layer->GetWeightsTensor();
   auto &weightGradTensor = layer->GetWeightGradientsTensor();

   weightTensor = Tensor_t({(size_t) dimW[0], 1, 1}, GetTensorLayout(), 0, 0);
   weightGradTensor = Tensor_t({(size_t) dimW[0], 1, 1}, GetTensorLayout(), 0, 0);

   // now initialize the RNN weights from RNNLayer::WeightInput, RNNLayer::WeightState and RNNLayer::BiasesState

   // only a single layer and no bidirectional mode are supported for now
   int nL = (!bidirectional) ? numLayers : 2 * numLayers; // for bidirectional nL = 2 * numLayers
   for (int ilayer = 0; ilayer < nL; ilayer++) {
      for (int linLayerID = 0; linLayerID < numLinearLayers; linLayerID++) {
         cudnnFilterDescriptor_t linLayerMatDesc;
         CUDNNCHECK(cudnnCreateFilterDescriptor(&linLayerMatDesc));
         AFloat *linLayerMat;

         CUDNNCHECK(cudnnGetRNNLinLayerMatrixParams(handle, rnnDescriptors->LayerDescriptor, ilayer, rnnDescriptors->xDesc.data()[0],
                                                    rnnDescriptors->WeightsDescriptor, weightTensor.GetDataPointer(),
                                                    linLayerID, linLayerMatDesc, (void **)&linLayerMat));

         cudnnDataType_t dataType;
         cudnnTensorFormat_t format;
         int nbDims;
         int filterDimA[3];
         CUDNNCHECK(cudnnGetFilterNdDescriptor(linLayerMatDesc, 3, &dataType, &format, &nbDims, filterDimA));

         // RNN:  linLayerID = 0    : input weight
         //                  = 1    : state (recurrent) weight
         //
         // LSTM: linLayerID = 0, 4 : input gate (input weight + state weight)
         //                  = 1, 5 : forget gate weight
         //                  = 2, 6 : new memory gate weight
         //                  = 3, 7 : output gate weight
         //
         // fortunately the same convention is used in the RNN layer's GetWeights()[ID]

         // copy the layer weights into linLayerMat
         // if (linLayerID == 0)
         // {
         // copy from GetStateWeights (tensor is state x state)
         int wsize = layer->GetWeightsAt(linLayerID).GetSize();

         // std::cout << "input weight size = " << wsize << " { " << layer->GetWeightsInput().GetNrows() << " "
         //           << layer->GetWeightsInput().GetNcols() << "} should be " << filterDimA[1] << " x "
         //           << filterDimA[2] << std::endl;

         assert(wsize == filterDimA[1] * filterDimA[2]);
         cudaMemcpyAsync(linLayerMat, layer->GetWeightsAt(linLayerID).GetDataPointer(), wsize * sizeof(AFloat),
                         cudaMemcpyDeviceToDevice, layer->GetWeightsAt(linLayerID).GetComputeStream());

         CUDNNCHECK(cudnnDestroyFilterDescriptor(linLayerMatDesc));

         cudnnFilterDescriptor_t linLayerBiasDesc;
         CUDNNCHECK(cudnnCreateFilterDescriptor(&linLayerBiasDesc));
         AFloat *linLayerBias;

         CUDNNCHECK(cudnnGetRNNLinLayerBiasParams(handle, rnnDescriptors->LayerDescriptor, ilayer,
                                                  rnnDescriptors->xDesc.data()[0], rnnDescriptors->WeightsDescriptor,
                                                  weightTensor.GetDataPointer(), linLayerID, linLayerBiasDesc,
                                                  (void **)&linLayerBias));

         CUDNNCHECK(cudnnGetFilterNdDescriptor(linLayerBiasDesc, 3, &dataType, &format, &nbDims, filterDimA));

         // bias handling: the standard mode is the single input-bias mode

         // linLayerID = 0 (RNN); 0,1,2,3 (LSTM); 0,1,2 (GRU) in CUDNN_RNN_SINGLE_INP_BIAS mode
         int biasID = linLayerID;
         if (biasMode == CUDNN_RNN_SINGLE_REC_BIAS) {
            // case of the state (recurrent) bias
            // linLayerID = 1 (RNN); 4,5,6,7 (LSTM); 3,4,5 (GRU)
            biasID = linLayerID - 1;
            if (mode == CUDNN_LSTM) biasID = linLayerID - 4;
            if (mode == CUDNN_GRU) biasID = linLayerID - 3;
         }

         if (filterDimA[0] > 0) {

            // check that the definitions above are valid
            assert(biasID >= 0);

            // copy the layer bias values
            int wsize = layer->GetBiasesAt(biasID).GetSize();

            // std::cout << "state bias " << wsize << " bias ID " << biasID << " { " <<
            //    layer->GetBiasesAt(biasID).GetNrows() << " "
            //    << layer->GetBiasesAt(biasID).GetNcols() << "} should be " << filterDimA[1] << " x " << filterDimA[2]
            //    << std::endl;

            // PrintTensor(layer->GetBiasesState(), "Bias state");

            assert(wsize == filterDimA[1]);
            cudaMemcpyAsync(linLayerBias, layer->GetBiasesAt(biasID).GetDataPointer(), wsize * sizeof(AFloat),
                            cudaMemcpyDeviceToDevice, layer->GetBiasesAt(biasID).GetComputeStream());

            // PrintTensor(weightTensor, "After biasW WeightTensor");
         }

         CUDNNCHECK(cudnnGetFilterNdDescriptor(linLayerBiasDesc, 3, &dataType, &format, &nbDims, filterDimA));

         // initGPUData(linLayerBias, filterDimA[0] * filterDimA[1] * filterDimA[2], 1.f);

         CUDNNCHECK(cudnnDestroyFilterDescriptor(linLayerBiasDesc));
      }

   }

   // PrintTensor(weightTensor, "Full WeightTensor");

   // the weight tensor in cuDNN is stored as
   // weights input + weights state + bias state

   size_t offset = 0;
   for (size_t i = 0; i < layer->GetWeights().size(); ++i) {
      auto &w = layer->GetWeightsAt(i);
      auto &dw = layer->GetWeightGradientsAt(i);
      assert(weightTensor(offset, 0, 0) == w(0, 0));

      // reshape tensors
      w = Tensor_t(weightTensor.GetDeviceBuffer().GetSubBuffer(offset, w.GetSize()), w.GetShape(),
                   GetTensorLayout(), 0, 0);
      dw = Tensor_t(weightGradTensor.GetDeviceBuffer().GetSubBuffer(offset, w.GetSize()), w.GetShape(),
                    GetTensorLayout(), 0, 0);

      offset += w.GetSize();
   }
   // now the biases
   for (size_t i = 0; i < layer->GetBiases().size(); ++i) {
      auto &b = layer->GetBiasesAt(i);
      auto &db = layer->GetBiasGradientsAt(i);
      assert(weightTensor(offset, 0, 0) == b(0, 0));

      // reshape tensors
      b = Tensor_t(weightTensor.GetDeviceBuffer().GetSubBuffer(offset, b.GetSize()), b.GetShape(), GetTensorLayout(), 0, 0);
      db = Tensor_t(weightGradTensor.GetDeviceBuffer().GetSubBuffer(offset, b.GetSize()), b.GetShape(), GetTensorLayout(), 0, 0);

      offset += b.GetSize();
   }

   // auto &weightsInput = layer->GetWeightsInput();
   // auto &weightsState = layer->GetWeightsState();
   // auto &biasesState = layer->GetBiasesState();

   // auto &weightInputGrad = layer->GetWeightInputGradients();
   // auto &weightStateGrad = layer->GetWeightStateGradients();
   // auto &biasStateGrad = layer->GetBiasStateGradients();

   // size_t offset_state = weightsInput.GetSize();
   // size_t offset_bias_state = offset_state + weightsState.GetSize();

   // assert(weightTensor(0,0,0) == weightsInput(0,0));
   // assert(weightTensor(offset_state,0,0) == weightsState(0,0));
   // assert(weightTensor(offset_bias_state,0,0) == biasesState(0,0));

   // // now we set the right buffers for the tensor weights and gradients
   // weightsInput = Tensor_t(weightTensor.GetDeviceBuffer().GetSubBuffer(0, weightsInput.GetSize()),
   //                         weightsInput.GetShape(), GetTensorLayout(), 0, 0);
   // weightsState = Tensor_t(weightTensor.GetDeviceBuffer().GetSubBuffer(offset_state, weightsState.GetSize()),
   //                         weightsState.GetShape(), GetTensorLayout(), 0, 0);
   // biasesState = Tensor_t(weightTensor.GetDeviceBuffer().GetSubBuffer(offset_bias_state, biasesState.GetSize()),
   //                        biasesState.GetShape(), GetTensorLayout(), 0, 0);

   // weightInputGrad = Tensor_t(weightGradTensor.GetDeviceBuffer().GetSubBuffer(0, weightInputGrad.GetSize()),
   //                            weightInputGrad.GetShape(), GetTensorLayout(), 0, 0);
   // weightStateGrad =
   //    Tensor_t(weightGradTensor.GetDeviceBuffer().GetSubBuffer(offset_state, weightStateGrad.GetSize()),
   //             weightStateGrad.GetShape(), GetTensorLayout(), 0, 0);
   // biasStateGrad =
   //    Tensor_t(weightGradTensor.GetDeviceBuffer().GetSubBuffer(offset_bias_state, biasStateGrad.GetSize()),
   //             biasStateGrad.GetShape(), GetTensorLayout(), 0, 0);

   descriptors = rnnDescriptors;
}
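
//____________________________________________________________________________
// Minimal sketch (not part of the original file) of the cuDNN descriptor setup
// condensed from the function above, assuming a valid cudnnHandle_t. The helper
// name SetupVanillaRNNDescriptorSketch is hypothetical and descriptor cleanup is
// omitted for brevity.
inline void SetupVanillaRNNDescriptorSketch(cudnnHandle_t handle, int hiddenSize)
{
   // dropout descriptor with probability 0 (no dropout, no random-state buffer)
   cudnnDropoutDescriptor_t dropoutDesc;
   CUDNNCHECK(cudnnCreateDropoutDescriptor(&dropoutDesc));
   size_t dropoutStateSize = 0;
   CUDNNCHECK(cudnnDropoutGetStatesSize(handle, &dropoutStateSize));
   CUDNNCHECK(cudnnSetDropoutDescriptor(dropoutDesc, handle, 0.f, nullptr, dropoutStateSize, 0ULL));

   // single-layer, unidirectional, tanh RNN with the standard algorithm
   cudnnRNNDescriptor_t rnnDesc;
   CUDNNCHECK(cudnnCreateRNNDescriptor(&rnnDesc));
#if (CUDNN_VERSION >= 8000)
   CUDNNCHECK(cudnnSetRNNDescriptor_v6(handle, rnnDesc, hiddenSize, 1, dropoutDesc, CUDNN_LINEAR_INPUT,
                                       CUDNN_UNIDIRECTIONAL, CUDNN_RNN_TANH, CUDNN_RNN_ALGO_STANDARD,
                                       CUDNN_DATA_FLOAT));
#else
   CUDNNCHECK(cudnnSetRNNDescriptor(handle, rnnDesc, hiddenSize, 1, dropoutDesc, CUDNN_LINEAR_INPUT,
                                    CUDNN_UNIDIRECTIONAL, CUDNN_RNN_TANH, CUDNN_RNN_ALGO_STANDARD,
                                    CUDNN_DATA_FLOAT));
#endif
   // the per-time-step x/y tensor descriptors and the packed weight filter
   // would be created next, exactly as done in the function above
}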

//____________________________________________________________________________
template <typename AFloat>
void TCudnn<AFloat>::ReleaseRNNDescriptors(TDescriptors *descriptors)
{
   auto rnnDescriptors = static_cast<RNNDescriptors_t *>(descriptors);
   CUDNNCHECK(cudnnDestroyRNNDescriptor(rnnDescriptors->LayerDescriptor));

   ReleaseDescriptor(rnnDescriptors->HelperDescriptor);
   ReleaseDescriptor(rnnDescriptors->WeightsDescriptor);
   ReleaseDescriptor(rnnDescriptors->WeightsGradDescriptor);

   // need to delete the vectors of tensor descriptors
   for (size_t i = 0; i < rnnDescriptors->xDesc.size(); i++) {
      cudnnDestroyTensorDescriptor(rnnDescriptors->xDesc.data()[i]);
      cudnnDestroyTensorDescriptor(rnnDescriptors->yDesc.data()[i]);

      cudnnDestroyTensorDescriptor(rnnDescriptors->dxDesc.data()[i]);
      cudnnDestroyTensorDescriptor(rnnDescriptors->dyDesc.data()[i]);
   }
}


//____________________________________________________________________________
template <typename AFloat>
template <typename RNNLayer>
void TCudnn<AFloat>::InitializeRecurrentWorkspace(TWorkspace *&workspace, TDescriptors *&descriptors, RNNLayer *layer)
{
   auto rnnWorkspace = new RNNWorkspace_t();
   auto rnnDescriptors = static_cast<RNNDescriptors_t *>(descriptors);

   cudnnHandle_t handle = layer->GetOutput().GetCudnnHandle();

   bool bidirectional = false;

   size_t numLayers = 1; // only a single layer is supported for now
   if (bidirectional) numLayers *= 2; // a bidirectional RNN is like having two layers

   // redefine the shape of the layer state tensor
   Tensor_t &stateTensor = layer->GetState();
   stateTensor = Tensor_t(stateTensor.GetDeviceBuffer(), {numLayers, layer->GetBatchSize(), layer->GetStateSize()},
                          GetTensorLayout(), 0, 0);

   if (layer->GetCell().GetSize() > 0) { // in case of LSTM
      Tensor_t &cellStateTensor = layer->GetCell();
      cellStateTensor = Tensor_t(cellStateTensor.GetDeviceBuffer(), {numLayers, layer->GetBatchSize(), layer->GetStateSize()},
                                 GetTensorLayout(), 0, 0);
   }

   // get the workspace size

   // xDesc must contain one input tensor descriptor per time step
   CUDNNCHECK(cudnnGetRNNWorkspaceSize(handle, rnnDescriptors->LayerDescriptor, layer->GetTimeSteps(),
                                       rnnDescriptors->xDesc.data(), &rnnWorkspace->ForwardWorkspaceSize));

   // note: the queried sizes are already in bytes, so multiplying by sizeof(AFloat) over-allocates (harmless)
   if (rnnWorkspace->ForwardWorkspaceSize) cudaMalloc(&rnnWorkspace->ForwardWorkspace, rnnWorkspace->ForwardWorkspaceSize * sizeof(AFloat));
   if (rnnWorkspace->ForwardWorkspaceSize > 0 && rnnWorkspace->ForwardWorkspace == nullptr) {
      std::cerr << "Error allocating RNN workspace of size " << rnnWorkspace->ForwardWorkspaceSize
                << " - probably running out of memory on the GPU" << std::endl;
      std::cout << " layer input shape is { " << layer->GetBatchSize() << " , " << layer->GetTimeSteps() << " , "
                << layer->GetStateSize() << " } " << std::endl;

      R__ASSERT(false);
   }

   CUDNNCHECK(cudnnGetRNNTrainingReserveSize(handle, rnnDescriptors->LayerDescriptor, layer->GetTimeSteps(),
                                             rnnDescriptors->xDesc.data(), &rnnWorkspace->HelperWorkspaceSize));

   if (rnnWorkspace->HelperWorkspaceSize) cudaMalloc(&rnnWorkspace->HelperWorkspace, rnnWorkspace->HelperWorkspaceSize * sizeof(AFloat));
   if (rnnWorkspace->HelperWorkspaceSize > 0 && rnnWorkspace->HelperWorkspace == nullptr) {
      std::cerr << "Error allocating RNN reserve workspace of size " << rnnWorkspace->HelperWorkspaceSize
                << " - probably running out of memory on the GPU" << std::endl;
      std::cout << " layer input shape is { " << layer->GetBatchSize() << " , " << layer->GetTimeSteps() << " , "
                << layer->GetStateSize() << " } " << std::endl;

      R__ASSERT(false);
   }
   workspace = rnnWorkspace;
}
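
//____________________________________________________________________________
// Minimal usage sketch (not part of the original file): querying the two buffer sizes
// needed by the cuDNN RNN API. Both sizes are returned in bytes. rnnDesc, the
// per-time-step xDescs array and seqLength are assumed to have been prepared as in
// InitializeRecurrentDescriptors above; the helper name is hypothetical.
inline void QueryRNNBufferSizesSketch(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int seqLength,
                                      const cudnnTensorDescriptor_t *xDescs)
{
   size_t workspaceBytes = 0; // scratch used by the forward and backward calls
   size_t reserveBytes = 0;   // training-only buffer written in forward, read in backward
   CUDNNCHECK(cudnnGetRNNWorkspaceSize(handle, rnnDesc, seqLength, xDescs, &workspaceBytes));
   CUDNNCHECK(cudnnGetRNNTrainingReserveSize(handle, rnnDesc, seqLength, xDescs, &reserveBytes));

   void *workspace = nullptr;
   void *reserve = nullptr;
   if (workspaceBytes > 0) cudaMalloc(&workspace, workspaceBytes);
   if (reserveBytes > 0) cudaMalloc(&reserve, reserveBytes);
   // ... the buffers would be used in the forward/backward calls ...
   cudaFree(workspace);
   cudaFree(reserve);
}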

//____________________________________________________________________________
template <typename AFloat>
void TCudnn<AFloat>::FreeRNNWorkspace(TWorkspace *workspace) {
   if (!workspace) return;
   auto rnnWorkspace = static_cast<RNNWorkspace_t *>(workspace);

   if (rnnWorkspace->ForwardWorkspace) cudaFree(rnnWorkspace->ForwardWorkspace);
   if (rnnWorkspace->HelperWorkspace) cudaFree(rnnWorkspace->HelperWorkspace);
}

//____________________________________________________________________________
template <typename AFloat>
void TCudnn<AFloat>::RNNForward(const Tensor_t &x, const Tensor_t &hx, const Tensor_t &cx, const Tensor_t &weights, Tensor_t &y,
                                Tensor_t &hy, Tensor_t &cy, const RNNDescriptors_t &desc, RNNWorkspace_t &workspace, bool isTraining)
{
   bool rememberState = false;
   cudnnHandle_t cudnnHandle = x.GetCudnnHandle();

   int seqLength = x.GetShape()[0]; // time steps
   cudnnRNNDescriptor_t rnnDesc = desc.LayerDescriptor;

   // the initial state and cell state will be set to zero
   bool isLSTM = (cx.GetSize() > 0) && rememberState;

   // perform forward training
   if (isTraining) {
      cudnnStatus_t status = cudnnRNNForwardTraining(
         cudnnHandle, rnnDesc, seqLength, desc.xDesc.data(), x.GetDataPointer(), hx.GetTensorDescriptor(),
         (rememberState) ? hx.GetDataPointer() : nullptr,
         (isLSTM) ? cx.GetTensorDescriptor() : hx.GetTensorDescriptor(), (isLSTM) ? cx.GetDataPointer() : nullptr,
         desc.WeightsDescriptor, weights.GetDataPointer(), desc.yDesc.data(), y.GetDataPointer(),
         hy.GetTensorDescriptor(), hy.GetDataPointer(),
         (isLSTM) ? cy.GetTensorDescriptor() : hy.GetTensorDescriptor(), (isLSTM) ? cy.GetDataPointer() : nullptr,
         workspace.ForwardWorkspace, workspace.ForwardWorkspaceSize,
         workspace.HelperWorkspace, workspace.HelperWorkspaceSize);

      assert(status == CUDNN_STATUS_SUCCESS);
      CUDNNCHECK(status);
   }
   else {
      // perform inference
      cudnnStatus_t status = cudnnRNNForwardInference(
         cudnnHandle, rnnDesc, seqLength, desc.xDesc.data(), x.GetDataPointer(), hx.GetTensorDescriptor(),
         (rememberState) ? hx.GetDataPointer() : nullptr,
         (isLSTM) ? cx.GetTensorDescriptor() : hx.GetTensorDescriptor(), (isLSTM) ? cx.GetDataPointer() : nullptr,
         desc.WeightsDescriptor, weights.GetDataPointer(), desc.yDesc.data(), y.GetDataPointer(),
         hy.GetTensorDescriptor(), hy.GetDataPointer(), (isLSTM) ? cy.GetTensorDescriptor() : hy.GetTensorDescriptor(),
         (isLSTM) ? cy.GetDataPointer() : nullptr, workspace.ForwardWorkspace, workspace.ForwardWorkspaceSize);

      assert(status == CUDNN_STATUS_SUCCESS);
      CUDNNCHECK(status);
   }
}
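
//____________________________________________________________________________
// Hedged sketch (not part of the original file): the only difference between the two
// branches above is that the training path also fills the reserve buffer that
// cudnnRNNBackwardData / cudnnRNNBackwardWeights read later. The wrapper name
// ForwardVanillaRNNSketch and its arguments are hypothetical; the initial hidden
// (and cell) state pointers are passed as nullptr, which cuDNN treats as zero.
inline cudnnStatus_t ForwardVanillaRNNSketch(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int seqLength,
                                             const cudnnTensorDescriptor_t *xDescs, const void *x,
                                             cudnnTensorDescriptor_t hxDesc, cudnnFilterDescriptor_t wDesc,
                                             const void *w, const cudnnTensorDescriptor_t *yDescs, void *y,
                                             cudnnTensorDescriptor_t hyDesc, void *hy, void *workspace,
                                             size_t workspaceBytes, void *reserve, size_t reserveBytes, bool training)
{
   if (training)
      return cudnnRNNForwardTraining(handle, rnnDesc, seqLength, xDescs, x, hxDesc, nullptr, hxDesc, nullptr,
                                     wDesc, w, yDescs, y, hyDesc, hy, hyDesc, nullptr,
                                     workspace, workspaceBytes, reserve, reserveBytes);
   return cudnnRNNForwardInference(handle, rnnDesc, seqLength, xDescs, x, hxDesc, nullptr, hxDesc, nullptr,
                                   wDesc, w, yDescs, y, hyDesc, hy, hyDesc, nullptr,
                                   workspace, workspaceBytes);
}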

//____________________________________________________________________________
template <typename AFloat>
void TCudnn<AFloat>::RNNBackward(const Tensor_t &x, const Tensor_t &hx, const Tensor_t &cx, const Tensor_t &y,
                                 const Tensor_t &dy, const Tensor_t &dhy, const Tensor_t &dcy, const Tensor_t &weights,
                                 Tensor_t &dx, Tensor_t &dhx, Tensor_t &dcx, Tensor_t &dw, const RNNDescriptors_t &desc,
                                 RNNWorkspace_t &workspace)
{
   bool rememberState = false;
   bool rememberStateGrad = false;
   bool isLSTM = (cx.GetSize() > 0) && rememberState;
   int seqLength = x.GetShape()[0];
   cudnnRNNDescriptor_t rnnDesc = desc.LayerDescriptor;
   cudnnHandle_t cudnnHandle = x.GetCudnnHandle();

   // compute the data gradients first (if dx is a dummy tensor this is the first layer and the data gradients could be skipped)
   // if (dx.GetSize() > 0) {
   // cuDNN needs the backward-data call in any case for the backward-weights pass to work
   // cudnnStatus_t status;
   cudnnStatus_t status = cudnnRNNBackwardData(
      cudnnHandle, rnnDesc, seqLength, desc.yDesc.data(), y.GetDataPointer(), desc.dyDesc.data(), dy.GetDataPointer(),
      dhy.GetTensorDescriptor(), (rememberStateGrad) ? dhy.GetDataPointer() : nullptr,
      (isLSTM) ? dcy.GetTensorDescriptor() : dhy.GetTensorDescriptor(), (isLSTM) ? dcy.GetDataPointer() : nullptr, // dcy
      desc.WeightsDescriptor, weights.GetDataPointer(), hx.GetTensorDescriptor(),
      (rememberState) ? hx.GetDataPointer() : nullptr, (isLSTM) ? cx.GetTensorDescriptor() : hx.GetTensorDescriptor(),
      (isLSTM) ? cx.GetDataPointer() : nullptr, // cx
      desc.dxDesc.data(), dx.GetDataPointer(), dhx.GetTensorDescriptor(),
      (rememberState) ? dhx.GetDataPointer() : nullptr,
      (isLSTM) ? dcx.GetTensorDescriptor() : dhx.GetTensorDescriptor(),
      (isLSTM) ? dcx.GetDataPointer() : nullptr, // dcx
      workspace.ForwardWorkspace, workspace.ForwardWorkspaceSize, workspace.HelperWorkspace,
      workspace.HelperWorkspaceSize);

   assert(status == CUDNN_STATUS_SUCCESS);
   CUDNNCHECK(status);

   // now the weight gradients
   // PrintTensor(dw, "weight grad before");
   // std::cout << "RNN Backward weights !!! - remember state " << rememberState << std::endl;
   // PrintTensor(x, "x");
   // PrintTensor(hx, "hx");
   // PrintTensor(y, "y");
   // PrintTensor(dx, "dx");
   // PrintTensor(dw, "dw");

   status = cudnnRNNBackwardWeights(cudnnHandle, rnnDesc, seqLength, desc.xDesc.data(), x.GetDataPointer(),
                                    hx.GetTensorDescriptor(), (rememberState) ? hx.GetDataPointer() : nullptr,
                                    desc.yDesc.data(), y.GetDataPointer(), workspace.ForwardWorkspace,
                                    workspace.ForwardWorkspaceSize, desc.WeightsGradDescriptor, dw.GetDataPointer(),
                                    workspace.HelperWorkspace, workspace.HelperWorkspaceSize);

   assert(status == CUDNN_STATUS_SUCCESS);
   CUDNNCHECK(status);

   // PrintTensor(dw, "weight grad after");
}
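
//____________________________________________________________________________
// Hedged usage sketch (not part of the original file): the backward pass must follow
// the order used above - cudnnRNNForwardTraining first (it fills the reserve buffer),
// then cudnnRNNBackwardData, and only then cudnnRNNBackwardWeights. The latter
// accumulates into the packed weight-gradient buffer instead of overwriting it, so
// that buffer is expected to be cleared once per batch. The helper below is hypothetical.
inline void ClearPackedWeightGradients(void *dw, size_t weightsSizeInBytes, cudaStream_t stream = 0)
{
   // asynchronous zeroing on the stream used by the subsequent backward calls
   cudaMemsetAsync(dw, 0, weightsSizeInBytes, stream);
}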


template <typename AFloat>
void TCudnn<AFloat>::Rearrange(Tensor_t &y, const Tensor_t &x) {

   AFloat alpha = 1;
   AFloat beta = 0;
   cudnnHandle_t cudnnHandle = x.GetCudnnHandle();
   // x can be a tensor of dimension 3 or dimension 4
   Tensor_t tmp = x;
   TensorDescriptor_t d = tmp.GetTensorDescriptor();
   int n = 0;
   int dims[4];
   int strides[4];
   cudnnDataType_t dataType;
   cudnnGetTensorNdDescriptor(d, tmp.GetNDim(), &dataType, &n, dims, strides);
   assert(n >= 3);

   // assume the x shape is B x T x S or B x T x 1 x S and the y shape is T x B x S
   const int xNdim = 3;
   auto outputShape = y.GetShape();
   assert(xNdim == (int) y.GetNDim());
   // swap the first two dimensions going from x to y
   assert(outputShape[0] == dims[1]); // T
   assert(outputShape[1] == dims[0]); // B
   assert(outputShape[2] == ((n == 4) ? dims[3] : dims[2])); // S
   if (n == 4) assert(dims[2] == 1);

   // the input stride of T is S and that of B is T x S
   int xStrides[xNdim] = {(int) outputShape[2], (int) (outputShape[2] * outputShape[0]), 1};
   int xDims[xNdim];
   for (int i = 0; i < xNdim; ++i)
      xDims[i] = outputShape[i];

   cudnnStatus_t status = cudnnSetTensorNdDescriptor(d, dataType, xNdim, xDims, xStrides);
   assert(status == CUDNN_STATUS_SUCCESS);
   CUDNNCHECK(status);
   status = cudnnTransformTensor(cudnnHandle, &alpha, d, x.GetDataPointer(), &beta,
                                 y.GetTensorDescriptor(), y.GetDataPointer());
   assert(status == CUDNN_STATUS_SUCCESS);
   CUDNNCHECK(status);

   // reset the original descriptor in tensor x
   status = cudnnSetTensorNdDescriptor(d, dataType, n, dims, strides);
   assert(status == CUDNN_STATUS_SUCCESS);

   // PrintTensor(x, "x as B x T x S");
   // PrintTensor(y, "y as T x B x S");
}
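
//____________________________________________________________________________
// Illustrative host-side reference (not part of the original file) for the device
// transform above: permuting a contiguous B x T x S buffer into T x B x S order.
// A hypothetical sketch one could use to validate cudnnTransformTensor on small inputs.
template <typename AReal>
void RearrangeReferenceBTStoTBS(const AReal *in, AReal *out, size_t B, size_t T, size_t S)
{
   for (size_t b = 0; b < B; ++b)
      for (size_t t = 0; t < T; ++t)
         for (size_t s = 0; s < S; ++s)
            out[(t * B + b) * S + s] = in[(b * T + t) * S + s];
}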

} // namespace DNN
} // namespace TMVA