Conversation

@killeent (Contributor) commented Oct 10, 2017

Partially addresses #1708. Currently works for Conv2D but needs testing.

TODO:

  • Extend gradWeight kernel to handle padding, strides, and dilation
  • Extend all kernels to handle a depthwise multiplier, i.e. a depthwise convolution with multiple filters per input channel (see the sketch after this list)
  • Add support for bias
  • Make gradWeight (and other kernels) more efficient by leveraging better CUDA programming paradigms
  • Benchmark performance on individual layers and models that leverage Depthwise Convolution, see if any other perf wins can be made
  • Finish writing unit tests
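
For reference, here is how the depthwise convolution (with a multiplier) that these kernels target looks at the nn.Conv2d level. This is a minimal illustration, not code from this PR, and the layer sizes are arbitrary:

```python
import torch
import torch.nn as nn

in_channels, multiplier = 32, 2
depthwise = nn.Conv2d(
    in_channels,
    in_channels * multiplier,  # depthwise multiplier: 2 filters per input channel
    kernel_size=3,
    stride=2,
    padding=1,
    dilation=1,
    groups=in_channels,        # groups == in_channels is what makes the conv depthwise
    bias=True,
)

x = torch.randn(8, in_channels, 112, 112)
print(depthwise(x).shape)      # torch.Size([8, 64, 56, 56])
```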

@killeent changed the title from "[WIP] 2D Depthwise Convolution on the GPU" to "[WIP] Spatial Depthwise Convolution on the GPU" on Oct 10, 2017
@killeent (Contributor, Author) commented:

Some preliminary results (the layer and input parameters are taken from MobileNet). All times are for 50 iterations of forward/backward; a sketch of the timing setup follows the table. The trends:

  • As the number of channels (and thus groups) increases, the new code becomes relatively faster, because a per-group loop over an ever-larger number of groups is replaced with a single kernel call
  • As the batch size increases, these gains are dampened, but remain non-trivial
| Batch Size | Input Channels | Height | Width | kH | kW | Stride | time (old) | time (new) | Speedup |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 32 | 112 | 112 | 1 | 1 | 1 | 0.2129 | 0.0076 | 28x |
| 1 | 64 | 112 | 112 | 3 | 3 | 2 | 0.1937 | 0.0143 | 13.5x |
| 1 | 128 | 56 | 56 | 3 | 3 | 1 | 0.4229 | 0.0241 | 17.5x |
| 1 | 128 | 56 | 56 | 3 | 3 | 2 | 0.3264 | 0.0112 | 29x |
| 1 | 256 | 28 | 28 | 3 | 3 | 1 | 0.7379 | 0.0207 | 35.5x |
| 1 | 256 | 28 | 28 | 3 | 3 | 2 | 0.7522 | 0.0059 | 127x |
| 1 | 512 | 14 | 14 | 3 | 3 | 1 | 1.406 | 0.0097 | 145x |
| 1 | 512 | 14 | 14 | 3 | 3 | 2 | 1.494 | 0.0052 | 287x |
| 1 | 1024 | 7 | 7 | 3 | 3 | 2 | 2.919 | 0.0053 | 550x |
| 64 | 32 | 112 | 112 | 1 | 1 | 1 | 3.603 | 0.3102 | 11.5x |
| 64 | 64 | 112 | 112 | 3 | 3 | 2 | 2.261 | 0.5031 | 4.5x |
| 64 | 128 | 56 | 56 | 3 | 3 | 1 | 4.343 | 0.9285 | 4.5x |
| 64 | 128 | 56 | 56 | 3 | 3 | 2 | 1.445 | 0.2310 | 6x |
| 64 | 256 | 28 | 28 | 3 | 3 | 1 | 3.586 | 0.4187 | 8.5x |
| 64 | 256 | 28 | 28 | 3 | 3 | 2 | 1.043 | 0.1168 | 9x |
| 64 | 512 | 14 | 14 | 3 | 3 | 1 | 2.489 | 0.1957 | 12.5x |
| 64 | 512 | 14 | 14 | 3 | 3 | 2 | 1.283 | 0.0676 | 18.5x |
| 64 | 1024 | 7 | 7 | 3 | 3 | 2 | 2.238 | 0.0686 | 32.5x |
| 128 | 32 | 112 | 112 | 1 | 1 | 1 | 7.008 | 0.5918 | 11.5x |
| 128 | 64 | 112 | 112 | 3 | 3 | 2 | 3.825 | 0.9979 | 3.5x |
| 128 | 128 | 56 | 56 | 3 | 3 | 1 | 8.922 | 1.844 | 4.5x |
| 128 | 128 | 56 | 56 | 3 | 3 | 2 | 2.419 | 0.4574 | 5x |
| 128 | 256 | 28 | 28 | 3 | 3 | 1 | 9.471 | 0.8292 | 11x |
| 128 | 256 | 28 | 28 | 3 | 3 | 2 | 1.352 | 0.2224 | 6x |
| 128 | 512 | 14 | 14 | 3 | 3 | 1 | 7.997 | 0.3735 | 21.5x |
| 128 | 512 | 14 | 14 | 3 | 3 | 2 | 1.371 | 0.1128 | 12x |
| 128 | 1024 | 7 | 7 | 3 | 3 | 2 | 2.383 | 0.0998 | 24x |
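
A minimal sketch of the timing methodology described above (this is not the author's benchmark script; it assumes a CUDA device, picks one configuration from the table, and uses torch.cuda.synchronize for accurate GPU timing):

```python
import time
import torch
import torch.nn as nn

def time_layer(conv, x, iters=50):
    """Time `iters` forward/backward passes of `conv` on `x`, in seconds."""
    x = x.clone().requires_grad_(True)
    conv(x).sum().backward()        # warm-up pass, excluded from the measurement
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        conv.zero_grad()
        conv(x).sum().backward()
    torch.cuda.synchronize()
    return time.time() - start

# e.g. batch 64, 128 channels, 56x56 input, 3x3 kernel, stride 2
x = torch.randn(64, 128, 56, 56, device="cuda")
depthwise = nn.Conv2d(128, 128, 3, stride=2, padding=1, groups=128).cuda()
print(time_layer(depthwise, x))
```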

@killeent changed the title from "[WIP] Spatial Depthwise Convolution on the GPU" to "Spatial Depthwise Convolution on the GPU" on Oct 13, 2017
@killeent (Contributor, Author) commented:

Marking this as ready to review. I still need to write tests, but while I do that I think it's worth looking at the kernels, and also the integration, which is a little messy.


@killeent (Contributor, Author) commented:

Okay, I addressed everything except:

  • Half-precision unit tests; I will need @colesbury to add support for binding THCUNN half operations in ATen, or do that myself
  • I made the kernels use multiplication instead of division and accumulation types where appropriate, but did not address any other performance work. If you would like further perf improvements, I would rather do those in a separate PR

auto padding = vecToInt64(this->padding);
auto dilation = vecToInt64(this->dilation);

at::conv_depthwise2d_forward_out(output, input, weight, kernel_size, bias, stride, padding, dilation);

if (output_mask[2]) {
  grad_bias = bias.type().tensor();
  grad_bias.resize_as_(bias).zero_();
  update_grad_bias(grad_output, grad_bias);
}

value = THCNumerics<AccT>::add(
    value,
    ScalarConvert<T, AccT>::to(
        THCNumerics<T>::mul(weight.data()[weightOffset], input.data()[offset])));
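
The excerpt above converts each half-precision product to the accumulation type AccT before adding it into the running value. The Python sketch below is an illustration only (unrelated to the PR's code) of the failure mode that a wider accumulation type avoids:

```python
import torch

# fp16 has a 10-bit mantissa, so consecutive integers above 2048 are not representable:
# once the running sum reaches 2048, adding 1.0 in fp16 rounds away to nothing.
one = torch.tensor(1.0, dtype=torch.float16)
acc_half = torch.tensor(0.0, dtype=torch.float16)   # accumulate in the input type
acc_float = torch.tensor(0.0, dtype=torch.float32)  # accumulate in a wider type ("AccT")

for _ in range(4096):
    acc_half = acc_half + one
    acc_float = acc_float + one

print(acc_half.item())          # 2048.0 -- the fp16 sum stalls
print(acc_float.half().item())  # 4096.0 -- accumulate in fp32, convert once at the end
```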


@killeent (Contributor, Author) commented:

Okay, I added support for Half-Precision. Let me know if there is anything else I need to do.

@ngimel (Collaborator) commented Oct 16, 2017

LGTM. Did you figure out support for binding THCUNN half operations in ATen?

@killeent (Contributor, Author) commented:

@ngimel - yeah, we just needed to add an extra parameter to @colesbury's nn_parse script - ATen itself already had support for handling half.

@ngimel (Collaborator) commented Oct 16, 2017

Cool! Will it also fix #2435? It is still broken.

@killeent (Contributor, Author) commented:

@ngimel I ran that repro script and it didn't crash, so I think so.

@KeCh96 commented Feb 2, 2018

I have upgraded my PyTorch to 0.3.0, but I found that m = nn.Conv2d(128, 256, kernel_size=3, groups=128) is still about 2x slower than m = nn.Conv2d(128, 256, kernel_size=3). I am really confused by this; do I need to upgrade PyTorch to another version, or use a different operation?

@cddlyf commented Feb 6, 2018

I have tried depthwise convolution with nn.Conv2d(64, 64, 3, 1, 1, groups=64), but it is only around 2x faster than nn.Conv2d(64, 64, 3, 1, 1). The input size is 1x64x256x256; could you tell me what's wrong? My PyTorch version is 0.3.0.post4.
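
For scale, the two layers being compared perform very different amounts of arithmetic, so the wall-clock ratio on its own is hard to interpret. A back-of-the-envelope multiply-accumulate count (illustrative arithmetic only, not part of the thread; it assumes the padding of 1 keeps the spatial size unchanged):

```python
# MAC counts for nn.Conv2d(64, 64, 3, 1, 1) vs. the same layer with groups=64,
# on a 1x64x256x256 input (stride 1, padding 1, so H and W are preserved).
N, C, H, W = 1, 64, 256, 256
kH = kW = 3

dense_macs = N * C * H * W * C * kH * kW   # each output channel reads all 64 input channels
depthwise_macs = N * C * H * W * kH * kW   # groups=64: each output channel reads one input channel

print(dense_macs // depthwise_macs)        # 64: the depthwise layer does 64x fewer MACs, so the
                                           # observed speedup also depends on how well each kernel
                                           # keeps the GPU busy, not on the MAC ratio alone
```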

@cddlyf commented Feb 7, 2018

@killeent @ngimel could you give me some hints? How do I use the optimized depthwise convolution? Does it require the latest PyTorch or cuDNN?

@ouceduxzk commented:

@cddlyf me too; I found that most existing separable_conv code still uses conv with groups.
