Conversation

@ifedan
Contributor

@ifedan ifedan commented Sep 6, 2019

@jacobrgardner #15253 (comment) proposed a way to speed up Euclidean distance calculation. This PR implements that solution for both the normal and batch versions.

@simonepri also provided performance metrics in #15253 (comment):
[image: performance metrics]

The current implementation is faster than @jacobrgardner's approach:
[image: speedup comparison]

@ifedan ifedan marked this pull request as ready for review September 6, 2019 21:26
@jacobrgardner

Awesome!

Just for completeness, here is a self-contained version of our current distance function in gpytorch, which is about 2x faster than the original snippet I posted in that issue:

import torch

def fast_cdist(x1, x2):
    adjustment = x1.mean(-2, keepdim=True)
    x1 = x1 - adjustment
    x2 = x2 - adjustment  # x1 and x2 should be identical in all dims except -2 at this point

    # Compute squared distance matrix using quadratic expansion
    # But be clever and do it with a single matmul call
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x1_pad = torch.ones_like(x1_norm)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    x2_pad = torch.ones_like(x2_norm)
    x1_ = torch.cat([-2. * x1, x1_norm, x1_pad], dim=-1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], dim=-1)
    res = x1_.matmul(x2_.transpose(-2, -1))

    # Clamp negative values (floating point error can make tiny squared distances negative) before sqrt
    res.clamp_min_(1e-30).sqrt_()
    return res

This uses the same basic strategy as the original method I'd written, but does everything in a single matmul call because matmul calls are really just that fast.

That being said, it has some drawbacks for inclusion upstream in pytorch proper over my original snippet. Namely (1) the additional working memory overhead, and (2) it's the same basic strategy, so a proper implementation of the first snippet I'd linked may erase the gap anyways.
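For illustration only (not part of the original comment), a quick sanity check that the snippet above matches torch.cdist on random inputs:

import torch

torch.manual_seed(0)
x1, x2 = torch.randn(8, 5), torch.randn(10, 5)
# fast_cdist is the function defined above; expect True up to float32 rounding
print(torch.allclose(fast_cdist(x1, x2), torch.cdist(x1, x2), atol=1e-4))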

@ifedan
Contributor Author

ifedan commented Sep 6, 2019


@jacobrgardner Thanks, let me check the performance metrics.

@simonepri

simonepri commented Sep 6, 2019

Great work!

FYI, I just added a reference to an implementation for the batched version in #15253 (comment)
The performance metrics were for that version. See facebookresearch/PyTorch-BigGraph#67 (comment).

cc: @lerks

} else {
    result = at::baddbmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), 1, -2);
}
result.add_(x1_norm);
Collaborator

@ngimel ngimel Sep 12, 2019

Implementations in #15253 had a clamp call filtering negative values, and it is necessary: it is possible to get negative (squared) distances with this approach because of floating point accuracy.
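For reference, a Python sketch of the quadratic expansion this hunk computes (an illustration, not the PR's actual code), with the clamp applied before the square root:

import torch

def cdist_via_expansion(x1, x2):
    # ||x1_i - x2_j||^2 = ||x1_i||^2 - 2 <x1_i, x2_j> + ||x2_j||^2
    x1_norm = x1.pow(2).sum(-1, keepdim=True)   # (..., n, 1)
    x2_norm = x2.pow(2).sum(-1, keepdim=True)   # (..., m, 1)
    res = x1_norm - 2.0 * x1.matmul(x2.transpose(-2, -1)) + x2_norm.transpose(-2, -1)
    # Rounding can make tiny squared distances slightly negative, which would
    # turn into NaN under sqrt, hence the clamp.
    return res.clamp_min_(0).sqrt_()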

Contributor Author

Yep, I changed this.

Balandat added a commit to cornellius-gp/gpytorch that referenced this pull request Sep 12, 2019
`torch.cdist` has been a pain for a long time, it's buggy and slow.
A more fundamental issue is that we use `torch.cdist(x1, x2).pow(2)` in the cdist code path: https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/kernels/kernel.py#L35
While the squared distance is differentiable even when the points are the same, this is not true for the distance itself. But autograd doesn't know that, so it can completely mess things up if there are repeated points or if we test at a train point.

Let's kill it for now until a reasonable implementation (including of the squared distances) is available.
There is work on improving perf here: pytorch/pytorch#25799; will add a request for returning the squared distances.
@Balandat
Contributor

@ifedan, can we also add a squared kwarg (or maybe a torch.sqcdist function) that returns the squared distances? This is important for kernel methods, where we want to differentiate through some function of the squared distance between the same point (i.e. at zero).

While the squared distance is differentiable at zero, this is not true for the distance. But autograd doesn't know that, so doing cdist(x1, x2).pow(2) is not an option and we need the squared distances directly.
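A minimal sketch of the problem (assuming cdist's backward of that era, which divides by the distance); squaring after the fact does not recover a finite gradient at coincident points:

import torch

x = torch.zeros(1, 3, requires_grad=True)
d2 = torch.cdist(x, x.detach()).pow(2)  # squared distance between identical points
d2.sum().backward()
print(x.grad)  # NaN, even though the true gradient of the squared distance is 0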

@pytorchbot pytorchbot added module: docs Related to our documentation, both in docs/ and docblocks module: internals Related to internal abstractions in c10 and ATen labels Sep 20, 2019
@ifedan
Contributor Author

ifedan commented Sep 27, 2019

@pytorchbot retest this please

@Balandat
Contributor

Balandat commented Oct 1, 2019

@ifedan any thoughts on being able to return the squared (or, more generally, p-th-power) distance?

That is, have an option to return \sum_k (x_{ik} - x_{jk})^p instead of (\sum_k (x_{ik} - x_{jk})^p)^{1/p}.

This is not as easy as just squaring the result of cdist, since cdist is not differentiable at zero.
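For reference, a naive sketch of that quantity (illustrative only; it materializes all pairwise differences, so it is memory-hungry):

import torch

def cdist_pow(x1, x2, p=2.0):
    # sum_k |x1_ik - x2_jk|^p, without the final 1/p root
    diff = x1.unsqueeze(-2) - x2.unsqueeze(-3)  # (..., n, m, d) pairwise differences
    # for p=2 this stays well-behaved under autograd at zero distance, unlike sqrt-based cdist
    return diff.abs().pow(p).sum(dim=-1)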

@ngimel
Collaborator

ngimel commented Oct 16, 2019

Did you move cdist to functional.py because optional kwargs are not handled well in native_functions?
cc @houseroad for backward compatibility (an optional arg is added to cdist)

Member

@houseroad houseroad left a comment

Adding an optional parameter with default value is backward compatible. Looks good to me.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ifedan
Contributor Author

ifedan commented Oct 17, 2019

@pytorchbot retest this please

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 17, 2019
Summary:
jacobrgardner pytorch/pytorch#15253 (comment) proposed a way to speed up Euclidean distance calculation. This PR implements that solution for both the normal and batch versions.

Also simonepri provided performance metrics pytorch/pytorch#15253 (comment)
![image](https://user-images.githubusercontent.com/12058312/64460756-44a24580-d0c9-11e9-9f7f-a5942f4c832d.png)

The current implementation is faster than jacobrgardner's approach:
![image](https://user-images.githubusercontent.com/12058312/64461495-5553bb00-d0cb-11e9-87e6-302b8cc7e12b.png)
Pull Request resolved: pytorch/pytorch#25799

Differential Revision: D17964982

Pulled By: ifedan

fbshipit-source-id: bf7bd0dbfca51fd39e667da55139347480f30a2f
@facebook-github-bot
Contributor

@ifedan merged this pull request in 12dde7f.

thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
@buttercutter

What do you guys think about https://stackoverflow.com/questions/61241523/replacing-torch-cdist-function-to-eliminate-gpu-out-of-memory-runtime-error ?

import numpy as np  # needed for np.sqrt in the custom backward
import torch

def new_cdist(p, eta):
    class cdist(torch.autograd.Function):
        @staticmethod
        def forward(ctx, W, X):
            ctx.save_for_backward(W, X)
            out = -torch.cdist(W, X, p)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            W, X = ctx.saved_tensors
            grad_W = grad_X = None
            if ctx.needs_input_grad[0]:
                _temp1 = torch.unsqueeze(X, 2).expand(X.shape[0], X.shape[1], W.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(W.transpose(0, 1), 1)
                _temp = torch.cdist(_temp1, _temp2, p).squeeze().transpose(0, 1)
                grad_W = torch.matmul(grad_output, _temp)
                # print('before norm: ', torch.norm(grad_W))
                grad_W = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W) * grad_W
                print('after norm: ', torch.norm(grad_W))
            if ctx.needs_input_grad[1]:
                _temp1 = torch.unsqueeze(W, 2).expand(W.shape[0], W.shape[1], X.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(X.transpose(0, 1), 1)
                _temp = torch.cdist(_temp1, _temp2, p).squeeze().transpose(0, 1)
                _temp = torch.nn.functional.hardtanh(_temp, min_val=-1., max_val=1.)
                grad_X = torch.matmul(grad_output.transpose(0, 1), _temp)
            return grad_W, grad_X
    return cdist().apply
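For context, a hypothetical usage sketch (shapes and values are illustrative; p is the distance order and eta scales the weight gradient):

import torch

dist_fn = new_cdist(p=1, eta=0.1)
W = torch.randn(4, 16, requires_grad=True)  # e.g. a weight matrix
X = torch.randn(8, 16, requires_grad=True)  # a batch of inputs
out = dist_fn(W, X)                         # (4, 8) tensor of negated distances
out.sum().backward()                        # runs the handwritten backward above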

@ngimel
Collaborator

ngimel commented Apr 17, 2020

This (or very similar) was implemented in #31167

@buttercutter

@ngimel are you saying that I should install the nightly version of pytorch for #31167?

@jacobrgardner As for #25799 (comment), how can the torch.cat() calls, which materialize contiguous tensors, be replaced?

x1_ = torch.cat([-2. * x1, x1_norm, x1_pad], dim=-1)
x2_ = torch.cat([x2, x2_pad, x2_norm], dim=-1)
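One way to avoid materializing those concatenated tensors (an illustrative sketch, not from this thread, reusing x1, x2, x1_norm and x2_norm from fast_cdist above) is to fold the padding columns back into broadcasting:

# Equivalent to x1_.matmul(x2_.transpose(-2, -1)) without building x1_ / x2_:
res = x1_norm + x2_norm.transpose(-2, -1) - 2.0 * x1.matmul(x2.transpose(-2, -1))
res = res.clamp_min(1e-30).sqrt()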

@ngimel
Collaborator

ngimel commented Apr 26, 2020

You can install either nightly or 1.5.
