27 #ifndef MAXPOOLLAYER_H_
28 #define MAXPOOLLAYER_H_
// Fragment of the TMaxPoolLayer class template declaration. The class head and
// most member declarations are not included in this extract; visible here are
// the backend type aliases and the main constructor declaration.
58 template <
typename Architecture_t>
// Tensor / matrix / scalar types forwarded from the Architecture_t backend.
62 using Tensor_t =
typename Architecture_t::Tensor_t;
63 using Matrix_t =
typename Architecture_t::Matrix_t;
64 using Scalar_t =
typename Architecture_t::Scalar_t;
/// Construct a max-pooling layer.
/// @param BatchSize          batch size; forwarded to the base layer and used as
///                           the first dimension of the index tensor.
/// @param InputDepth         depth of the input; also used as the filter depth.
/// @param InputHeight        input height; together with FilterHeight and
///                           StrideRows it determines the pooled output height.
/// @param InputWidth         input width; together with FilterWidth and
///                           StrideCols it determines the pooled output width.
/// @param FilterHeight       height of the pooling filter.
/// @param FilterWidth        width of the pooling filter.
/// @param StrideRows         stride of the filter along the rows.
/// @param StrideCols         stride of the filter along the columns.
/// @param DropoutProbability dropout setting stored on the layer; a value of
///                           exactly 1.0 disables dropout in Forward/Backward.
107 TMaxPoolLayer(
size_t BatchSize,
size_t InputDepth,
size_t InputHeight,
size_t InputWidth,
size_t FilterHeight,
108 size_t FilterWidth,
size_t StrideRows,
size_t StrideCols,
Scalar_t DropoutProbability);
// Out-of-class definition of the main TMaxPoolLayer constructor. Several lines
// of this definition (the signature head, the first parameters and the opening
// of the base-class initializer) are not included in this extract.
167 template <
typename Architecture_t>
169 size_t filterHeight,
size_t filterWidth,
size_t strideRows,
170 size_t strideCols,
Scalar_t dropoutProbability)
// Base-layer arguments (partially visible). inputDepth is passed a second time
// where, presumably, the output depth is expected (max pooling keeps the depth)
// -- confirm against the VGeneralLayer constructor signature. The two derived
// dimensions come from TConvLayer::calculateDimension with the filter size, a
// third argument of 0 (presumably zero padding) and the stride.
172 batchSize, inputDepth, inputHeight, inputWidth, inputDepth,
173 TConvLayer<Architecture_t>::calculateDimension(inputHeight, filterHeight, 0, strideRows),
174 TConvLayer<Architecture_t>::calculateDimension(inputWidth, filterWidth, 0, strideCols), 0, 0, 0, 0, 0,
// Further base-layer arguments: batch size, depth and the number of local views
// computed by TConvLayer::calculateNLocalViews with the same geometry.
176 batchSize, inputDepth,
177 TConvLayer<Architecture_t>::calculateNLocalViews(inputHeight, filterHeight, 0, strideRows, inputWidth,
178 filterWidth, 0, strideCols),
// Pooling-specific members. The filter depth is taken from the input depth.
180 fFilterDepth(inputDepth), fFilterHeight(filterHeight), fFilterWidth(filterWidth), fStrideRows(strideRows),
181 fStrideCols(strideCols),
// Same calculateNLocalViews call as the one passed to the base layer above.
182 fNLocalViews(
TConvLayer<Architecture_t>::calculateNLocalViews(inputHeight, filterHeight, 0, strideRows,
183 inputWidth, filterWidth, 0, strideCols)),
// fIndexTensor is sized (batchSize, inputDepth, fNLocalViews); it is later
// handed to MaxPoolLayerBackward.
// NOTE(review): this initializer reads fNLocalViews, which is only valid if
// fNLocalViews is declared before fIndexTensor in the class (members are
// initialized in declaration order, not initializer-list order). The member
// declarations are not visible here -- confirm.
184 fDropoutProbability(dropoutProbability), fIndexTensor(batchSize, inputDepth, fNLocalViews)
// Constructor that copies the configuration of another max-pool layer accessed
// through `layer->` (presumably a pointer; the signature line is not included
// in this extract). The base part is built from `layer`, and the pooling
// hyper-parameters are read through the source layer's getters.
191 template <
typename Architecture_t>
193 :
VGeneralLayer<Architecture_t>(layer), fFilterDepth(layer->GetFilterDepth()),
194 fFilterHeight(layer->GetFilterHeight()), fFilterWidth(layer->GetFilterWidth()),
195 fStrideRows(layer->GetStrideRows()), fStrideCols(layer->GetStrideCols()), fNLocalViews(layer->GetNLocalViews()),
// NOTE(review): fIndexTensor is constructed from the *shape* of the source
// layer's index tensor only; its contents are not copied here.
196 fDropoutProbability(layer->GetDropoutProbability()), fIndexTensor(layer->GetIndexTensor().GetShape())
// Constructor that copies the configuration of another max-pool layer passed
// as an object/reference (the signature line is not included in this extract).
// Unlike the pointer-based variant, this reads the source's data members
// directly.
203 template <
typename Architecture_t>
205 :
VGeneralLayer<Architecture_t>(layer), fFilterDepth(layer.fFilterDepth), fFilterHeight(layer.fFilterHeight),
206 fFilterWidth(layer.fFilterWidth), fStrideRows(layer.fStrideRows), fStrideCols(layer.fStrideCols),
207 fNLocalViews(layer.fNLocalViews), fDropoutProbability(layer.fDropoutProbability),
// NOTE(review): as in the pointer-based copy, only the index tensor's shape is
// taken from the source layer; its contents are not copied here.
208 fIndexTensor(layer.GetIndexTensor().GetShape())
// Cleanup routine, presumably the destructor (its signature is not included in
// this extract). It calls ReleaseDescriptors() and afterwards resets
// fDescriptors and fWorkspace to nullptr. The lines between these statements
// are not included in this extract.
215 template <
typename Architecture_t>
219 ReleaseDescriptors();
221 fDescriptors =
nullptr;
227 fWorkspace =
nullptr;
// Forward pass of the max-pool layer (the signature is not included in this
// extract).
232 template <
typename Architecture_t>
// Dropout is applied to `input` only when `applyDropout` is set and the dropout
// probability is not exactly 1.0; a value of 1.0 therefore disables dropout.
235 if (applyDropout && (this->GetDropoutProbability() != 1.0)) {
236 Architecture_t::DropoutForward(input, fDescriptors, fWorkspace, this->GetDropoutProbability());
// Max-pool downsampling via the backend. Some arguments of this call are not
// included in this extract; it ends with the filter size and the strides.
239 Architecture_t::Downsample(
242 this->GetFilterHeight(), this->GetFilterWidth(), this->GetStrideRows(), this->GetStrideCols());
// Backward pass of the max-pool layer (the signature is not included in this
// extract).
246 template <
typename Architecture_t>
// NOTE(review): DropoutBackward runs whenever the dropout probability is not
// 1.0, without checking whether Forward was called with applyDropout set. It
// also acts on this layer's activation gradients, whereas Forward applied
// dropout to its `input`. Confirm that the dropout masks and tensors line up.
251 if (this->GetDropoutProbability() != 1.0) {
252 Architecture_t::DropoutBackward(this->GetActivationGradients(), fDescriptors, fWorkspace);
// Backend max-pool backward. It receives the output-gradient buffer, this
// layer's activation gradients, fIndexTensor, the forward activations and
// output, the filter geometry and the number of local views. Some arguments of
// this call are not included in this extract.
254 Architecture_t::MaxPoolLayerBackward(
255 gradients_backward, this->GetActivationGradients(), fIndexTensor, activations_backward, this->GetOutput(),
258 this->GetFilterHeight(), this->GetFilterWidth(), this->GetStrideRows(), this->GetStrideCols(),
259 this->GetNLocalViews());
// Print a one-line summary of the layer to std::cout. It shows the layer's
// width/height/depth, the pooling filter width/height and, only when the
// output tensor is non-empty, its four dimensions (first, C, H, W). The
// signature is not included in this extract.
263 template <
typename Architecture_t>
266 std::cout <<
" POOL Layer: \t";
267 std::cout <<
"( W = " << this->GetWidth() <<
" , ";
268 std::cout <<
" H = " << this->GetHeight() <<
" , ";
269 std::cout <<
" D = " << this->GetDepth() <<
" ) ";
// Pooling filter size.
271 std::cout <<
"\t Filter ( W = " << this->GetFilterWidth() <<
" , ";
272 std::cout <<
" H = " << this->GetFilterHeight() <<
" ) ";
// The output shape is only printed once the output tensor has been allocated
// (non-zero size).
274 if (this->GetOutput().GetSize() > 0) {
275 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput().GetCSize()
276 <<
" , " << this->GetOutput().GetHSize() <<
" , " << this->GetOutput().GetWSize() <<
" ) ";
// std::endl terminates the line and also flushes std::cout.
278 std::cout << std::endl;
// Template header of a TMaxPoolLayer member definition whose body is not
// included in this extract.
282 template <
typename Architecture_t>
// Template header of a TMaxPoolLayer member definition whose body is not
// included in this extract.
296 template <
typename Architecture_t>
// Descriptor setup: delegates to the backend's InitializePoolDescriptors,
// passing fDescriptors and this layer (the signature is not included in this
// extract).
303 template <
typename Architecture_t>
305 Architecture_t::InitializePoolDescriptors(fDescriptors,
this);
// Descriptor teardown: delegates to the backend's ReleasePoolDescriptors for
// fDescriptors (the signature is not included in this extract).
308 template <
typename Architecture_t>
310 Architecture_t::ReleasePoolDescriptors(fDescriptors);
// Template header of a TMaxPoolLayer member definition whose body is not
// included in this extract.
314 template <
typename Architecture_t>
325 template <
typename Architecture_t>