14 changes: 14 additions & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -109,6 +109,20 @@ bool CUDAHooks::supportsDilatedConvolutionWithCuDNN() const {
#endif
}

bool CUDAHooks::supportsDepthwiseConvolutionWithCuDNN() const {
#if AT_CUDNN_ENABLED()
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
// Check for Volta (compute capability 7.0) or newer
return prop->major >= 7;
#else
return false;
#endif
}

long CUDAHooks::versionCuDNN() const {
#if AT_CUDNN_ENABLED()
return CUDNN_VERSION;
1 change: 1 addition & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -21,6 +21,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool compiledWithCuDNN() const override;
bool compiledWithMIOpen() const override;
bool supportsDilatedConvolutionWithCuDNN() const override;
bool supportsDepthwiseConvolutionWithCuDNN() const override;
long versionCuDNN() const override;
std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override;
4 changes: 4 additions & 0 deletions aten/src/ATen/detail/CUDAHooksInterface.h
@@ -98,6 +98,10 @@ struct CAFFE2_API CUDAHooksInterface {
return false;
}

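// Default for builds without the ATen_cuda backend; the CUDAHooks
// implementation in aten/src/ATen/cuda overrides this with a device check.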
virtual bool supportsDepthwiseConvolutionWithCuDNN() const {
return false;
}

virtual long versionCuDNN() const {
AT_ERROR("Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
147 changes: 146 additions & 1 deletion aten/src/ATen/native/Convolution.cpp
@@ -31,6 +31,7 @@ struct ConvParams {
bool is_stride_neg() const;
void view1d_as_2d();
bool use_cudnn(const at::Tensor& input) const;
bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
bool use_miopen(const at::Tensor& input) const;
bool use_mkldnn(const at::Tensor& input) const;
bool use_nnpack(const at::Tensor& input) const;
@@ -187,6 +188,143 @@ auto ConvParams::is_depthwise(
weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels
}

// Check workload to activate fast depthwise FP16 cudnn conv kernels
bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) {
int w = input.size(3); // spatial width (assumed equal to the height)
int ch = input.size(1);
int bs = input.size(0);
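// Each branch below encodes a (batch size, channels, width) regime in which
// the cuDNN kernels are expected to beat the THNN depthwise implementation.
// For example, an input of [32, 256, 56, 56] with stride 1 hits the
// bs >= 32, ch >= 256, w >= 14 branch and returns true, while
// [1, 64, 28, 28] falls through every branch and returns false.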
if (stride==1) {
if (w >= 7) {
// All batch sizes and nb_channels
if (w >= 112) {
return true;
}

// large nb_channels
if (ch >= 1024) {
if (w >= 56) {
return true;
} else if (bs >= 32) {
return true;
}
}

// batch_size specific
if (bs >= 128) {
if (ch >= 512) {
return true;
} else if (ch >= 64) {
if (w >= 14) {
return true;
}
} else if ((ch >= 32) && (w >= 28)) {
return true;
}
} else if (bs >= 64) {
if ((ch >= 256) && (w >= 14)) {
return true;
} else if ((ch >= 32) && (w >= 28)) {
return true;
}
} else if (bs >= 32) {
if ((ch >= 256) && (w >= 14)) {
return true;
} else if ((ch >= 128) && (w >= 28)) {
return true;
} else if ((ch >= 32) && (w >= 56)) {
return true;
}
} else if (bs >= 16) {
if ((ch >= 1024) && (w >= 14)) {
return true;
} else if ((ch >= 256) && (w >= 28)) {
return true;
} else if ((ch >= 32) && (w >= 56)) {
return true;
}
} else if (bs >= 8) {
if ((ch >= 512) && (w >= 28)) {
return true;
} else if ((ch >= 64) && (w >= 56)) {
return true;
}
}
}
} else if (stride==2) {
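// The stride-2 fast path is never taken below 256 channels.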
if (ch < 256) {
return false;
}

if (w >= 7) {
if (bs >= 128) {
if (ch >= 1024) {
return true;
} else if ((ch >= 512) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 64) {
if ((ch >= 512) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 32) {
if ((ch >= 1024) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 16) {
if ((ch >= 512) && (w >= 28)) {
return true;
} else if (w >= 56) {
return true;
}
} else if (bs >= 8) {
if ((ch >= 1024) && (w >= 28)) {
return true;
} else if (w >= 56) {
return true;
}
} else if (bs >= 1) {
if ((ch >= 512) && (w >= 112)) {
return true;
}
}
}
}
return false;
}

// Use cudnn for FP16 depthwise convolutions
auto ConvParams::use_cudnn_depthwise(
const at::Tensor& input, const at::Tensor& weight) const -> bool {
if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
long cudnn_version = detail::getCUDAHooks().versionCuDNN();
bool kernel_cond = (cudnn_version >= 7600 &&
use_cudnn(input) &&
input.scalar_type() == kHalf && // only for FP16
weight.scalar_type() == kHalf &&
is_depthwise(input, weight) &&
weight.size(2) == weight.size(3) && // only square kernels
input.size(2) >= 7 && // min width/height 7
!is_dilated() && // no dilation supported
stride[0] == stride[1] && // equal strides
((weight.size(3) == 3) || (weight.size(3) == 1)) &&
input.size(1) >= 32); // min 32 channels supported
if (kernel_cond) {
return check_cudnn_depthwise_workload(input, stride[0]);
} else {
return false;
}
} else {
return false;
}
}

static void check_shape_forward(const at::Tensor& input,
const at::Tensor& weight, const at::Tensor& bias,
const ConvParams& params, bool input_is_mkldnn) {
@@ -407,7 +545,14 @@ at::Tensor _convolution(
auto stride = params.stride;
auto padding = params.padding;
auto dilation = params.dilation;
if (params.use_cudnn_depthwise(input, weight)) {
  output = at::cudnn_convolution(
      input, weight, bias,
      padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
} else {
  output = at::thnn_conv_depthwise2d(input, weight, kernel_size, bias, stride, padding, dilation);
}
} else if (params.use_cudnn(input)) {
TORCH_CHECK(input.type() == weight.type(),
"Input type (", input.type().toString(), ") and weight type (", weight.type().toString(),
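For reference, a minimal sketch (not part of this PR; the shapes and the standalone harness are illustrative assumptions) of a depthwise convolution that satisfies every condition above: FP16 tensors on CUDA, a square 3x3 kernel, groups equal to the channel count, 256 channels, a 56x56 input, and batch size 32. On a Volta-class GPU with cuDNN >= 7.6, _convolution would then dispatch to at::cudnn_convolution instead of at::thnn_conv_depthwise2d:

#include <ATen/ATen.h>

int main() {
  // Depthwise convolution: groups == in_channels, weight shape [C, 1, k, k].
  auto opts   = at::device(at::kCUDA).dtype(at::kHalf);
  auto input  = at::randn({32, 256, 56, 56}, opts);  // [N, C, H, W]
  auto weight = at::randn({256, 1, 3, 3}, opts);     // square 3x3 kernel
  // stride 1, padding 1, dilation 1, groups == 256: is_depthwise() holds,
  // and check_cudnn_depthwise_workload(input, 1) accepts this workload.
  auto output = at::conv2d(input, weight, /*bias=*/{}, /*stride=*/1,
                           /*padding=*/1, /*dilation=*/1, /*groups=*/256);
  return 0;
}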