
Conversation

ifedan (Contributor) commented May 16, 2019

Fix based on #15253

pytorchbot added the module: cpu, module: cuda, and module: operators labels May 16, 2019
ifedan (Contributor, Author) commented May 16, 2019

Current CPU implementation:

import timeit
SETUP_CODE = '''
import torch
from scipy.spatial.distance import cdist
a = torch.randn(100, 2)
b = torch.randn(200, 2)'''

TEST_CODE = '''torch.cdist(a, b)'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10000)
[4.789991227999998, 4.700422251000006, 5.000459026999991]

TEST_CODE = '''cdist(a, b)'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10000)
[0.6307194809999999, 0.5901044249999998, 0.5976278489999984]


SETUP_CODE = '''
import torch
from scipy.spatial.distance import cdist
a = torch.randn(2, 200)
b = torch.randn(2, 200)'''

TEST_CODE = '''torch.cdist(a, b)'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10000)
[0.027029143000000033, 0.02519462499999925, 0.0251151329999999]

TEST_CODE = '''cdist(a, b)'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10000)
[0.18516300100000294, 0.16735686699999874, 0.18243102399999955]

Improved CPU implementation:

import timeit
SETUP_CODE = '''
import torch
from scipy.spatial.distance import cdist
a = torch.randn(100, 2)
b = torch.randn(200, 2)'''

TEST_CODE = '''torch.cdist(a, b)'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10000)
[0.5074237539993192, 0.4883652280004753, 0.48943868999958795]

TEST_CODE = '''cdist(a, b)'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10000)
[0.6026854669999011, 0.6011174650011526, 0.5975362090011913]


SETUP_CODE = '''
import torch
from scipy.spatial.distance import cdist
a = torch.randn(2, 200)
b = torch.randn(2, 200)'''

TEST_CODE = '''torch.cdist(a, b)'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10000)
[0.04526426500160596, 0.03518587399958051, 0.029068126999845845]

TEST_CODE = '''cdist(a, b)'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10000)
[0.18631748199914, 0.17298765699888463, 0.17551641600039147]

gchanan self-requested a review May 16, 2019 21:21
ifedan (Contributor, Author) commented May 16, 2019

Current GPU implementation:

import timeit
SETUP_CODE = '''
import torch
from gpytorch.kernels.kernel import Distance
dist = Distance()
a = torch.randn(10000, 9).cuda()
b = torch.randn(30000, 9).cuda()'''
TEST_CODE = '''D1 = torch.cdist(a, b); print(D1[0, 0])'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10)
[47.1045588250272, 47.20872523600701, 47.47344793193042]

SETUP_CODE = '''
import torch
from gpytorch.kernels.kernel import Distance
dist = Distance()
a = torch.randn(9, 10000).cuda()
b = torch.randn(9, 10000).cuda()'''
TEST_CODE = '''D1 = torch.cdist(a, b); print(D1[0, 0])'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10)
[0.009627211024053395, 0.007205849979072809, 0.007276762975379825]

Improved GPU implementation:

import timeit
SETUP_CODE = '''
import torch
from gpytorch.kernels.kernel import Distance
dist = Distance()
a = torch.randn(10000, 9).cuda()
b = torch.randn(30000, 9).cuda()'''
TEST_CODE = '''D1 = torch.cdist(a, b); print(D1[0, 0])'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10)
[7.402155780000612, 7.373479370959103, 7.379088997957297]

SETUP_CODE = '''
import torch
from gpytorch.kernels.kernel import Distance
dist = Distance()
a = torch.randn(9, 10000).cuda()
b = torch.randn(9, 10000).cuda()'''
TEST_CODE = '''D1 = torch.cdist(a, b); print(D1[0, 0])'''
timeit.repeat(setup=SETUP_CODE, stmt=TEST_CODE, repeat=3, number=10)
[0.010255609988234937, 0.00916548096574843, 0.008794950088486075]

facebook-github-bot (Contributor) left a comment

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ifedan (Contributor, Author) commented May 17, 2019

@pytorchbot retest this please

gchanan requested a review from umanwizard May 20, 2019 20:26
gchanan (Contributor) commented May 20, 2019

An intuitive explanation for why this is faster in those cases would be nice.

umanwizard (Contributor) commented May 20, 2019

If I understand this correctly, it looks like you are doing two different things:

(1) Removing the division from the tight loop.
(2) Making it un-vectorized.

I understand why (1) is a win, but not (2)...

ifedan (Contributor, Author) commented May 23, 2019

Each row of size M in the first tensor (a) is combined with each row of size M in the second tensor (b) to calculate cdist:

[image]

On CPU I vectorize the data and parallelize it as follows, then use a map-reduce to compute the corresponding result (a sketch follows the figure below).

[image]
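Roughly, the CPU structure is: split the R1*R2 row pairs across threads, and reduce each pair over its M elements. A minimal sketch of that structure (not the PR's actual ATen code; the inner loop is shown as the plain scalar variant discussed below, with the 1/p root taken once per pair rather than inside the tight loop):

#include <ATen/Parallel.h>
#include <cmath>
#include <cstdint>

// Sketch: p-norm cdist over row-major a (R1 x M) and b (R2 x M), writing
// result (R1 x R2). at::parallel_for splits the R1*R2 pairs across threads;
// each pair is then reduced with a scalar loop over its M elements.
void cdist_cpu_sketch(const double* a, const double* b, double* result,
                      int64_t R1, int64_t R2, int64_t M, double p) {
  at::parallel_for(0, R1 * R2, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
    for (int64_t idx = begin; idx < end; ++idx) {
      const double* row_a = a + (idx / R2) * M;
      const double* row_b = b + (idx % R2) * M;
      double agg = 0.0;
      for (int64_t k = 0; k < M; ++k) {            // "map" over the M elements
        agg += std::pow(std::abs(row_a[k] - row_b[k]), p);
      }
      result[idx] = std::pow(agg, 1.0 / p);        // root applied once per pair
    }
  });
}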

On GPU the grid size is R1*R2 and the number of threads per block is 256. I use a map-reduce within warpSize to compute the result (a sketch follows the figure below).

[image]
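An illustrative CUDA sketch of that layout (not the PR's actual kernel): one block per (i, j) pair, so gridDim.x = R1*R2; the block's threads stride over the M elements, partial sums are reduced within each warp via shuffles, and the per-warp results are combined through shared memory. A launch would look like cdist_kernel_sketch<<<R1 * R2, 256>>>(a, b, result, R2, M, p).

// Sketch only: assumes blockDim.x is a multiple of warpSize, at most 256.
__global__ void cdist_kernel_sketch(const float* a, const float* b, float* result,
                                    int64_t R2, int64_t M, float p) {
  const int64_t i = blockIdx.x / R2;   // row index into a (R1 x M)
  const int64_t j = blockIdx.x % R2;   // row index into b (R2 x M)
  const float* row_a = a + i * M;
  const float* row_b = b + j * M;

  // Map: each thread accumulates a strided slice of the M elements.
  float agg = 0.0f;
  for (int64_t k = threadIdx.x; k < M; k += blockDim.x) {
    agg += powf(fabsf(row_a[k] - row_b[k]), p);
  }

  // Reduce within each warp using shuffles.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    agg += __shfl_down_sync(0xffffffff, agg, offset);
  }

  // Combine the per-warp partials through shared memory (256 threads = 8 warps).
  __shared__ float warp_sums[8];
  const int lane = threadIdx.x % warpSize;
  const int warp = threadIdx.x / warpSize;
  if (lane == 0) warp_sums[warp] = agg;
  __syncthreads();

  if (threadIdx.x == 0) {
    float total = 0.0f;
    for (int w = 0; w < blockDim.x / warpSize; ++w) total += warp_sums[w];
    result[blockIdx.x] = powf(total, 1.0f / p);    // root applied once per pair
  }
}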

There are two issues:

  1. On GPU, as you can see, if M is smaller than the number of threads, we end up in a situation where many threads do nothing. To avoid this I will change the number of threads per block based on M (see the sketch after this list); it will not be smaller than warpSize.
  2. On CPU, there is overhead in the vector reduce stage: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cpu/vec256/functional.h#L14, https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cpu/vec256/functional.h#L21
     Based on the perf tests, avoiding vectorization speeds this process up.
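For point 1, a hypothetical helper (the actual follow-up heuristic may differ): grow the block size with M, from warpSize up to the current 256-thread maximum, so that a small M no longer leaves most of the block idle.

#include <cstdint>

constexpr int kWarpSize = 32;            // never go below a warp
constexpr int kMaxThreadsPerBlock = 256; // current fixed block size

// Hypothetical: choose the number of threads per block from M so that small M
// does not leave most of a 256-thread block doing nothing.
int threads_per_block_for(int64_t M) {
  int threads = kWarpSize;
  while (threads < M && threads < kMaxThreadsPerBlock) {
    threads *= 2;   // round up to the next power of two
  }
  return threads;   // always in [kWarpSize, kMaxThreadsPerBlock]
}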

facebook-github-bot (Contributor) left a comment

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ifedan (Contributor, Author) commented May 24, 2019

@pytorchbot retest this please

facebook-github-bot (Contributor): @ifedan merged this pull request in a2328a2.

zdevito pushed a commit to zdevito/ATen that referenced this pull request May 24, 2019
Summary:
Fix based on pytorch/pytorch#15253
Pull Request resolved: pytorch/pytorch#20605

Differential Revision: D15396123

Pulled By: ifedan

fbshipit-source-id: 3ed373e68339a35360f083d4aad1b655abcaf97e
ifedan deleted the cdist_perf branch September 6, 2019 20:36