Add matmul optimization for the case A.ndim <= 2 && B.ndim >= 3 #20448
Conversation
This PR implements more or less @colesbury's suggestion C from #18862. Suggestion C ends up calling the existing optimization for the mirrored case (dim_tensor1 >= 3, dim_tensor2 <= 2), which folds the batch dimensions into a single matrix multiply.

[Memory usage (as reported by …) and timings, each tabulated as Test case / Before / After; the figures did not survive extraction.]
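The trick behind suggestion C can be sketched in a few lines of Python (an illustrative sketch only; the actual change is in ATen's C++ matmul, and the helper name here is made up):

```python
import torch

def matmul_small_first(t1, t2):
    """Hypothetical helper: reduce the case t1.dim() <= 2, t2.dim() >= 3
    to the already-optimized mirrored case via the identity
    A @ B == (B^T @ A^T)^T applied per batch."""
    assert t1.dim() <= 2 and t2.dim() >= 3
    t1_2d = t1 if t1.dim() == 2 else t1.unsqueeze(0)  # vector -> 1 x m matrix
    # Swapping the operands puts the batched tensor first, so matmul takes
    # the existing fast path that folds the batch dims into a single mm.
    res = torch.matmul(t2.transpose(-1, -2), t1_2d.t()).transpose(-1, -2)
    return res.squeeze(-2) if t1.dim() == 1 else res

A = torch.randn(4, 5)
B = torch.randn(7, 5, 6)
assert torch.allclose(matmul_small_first(A, B), A @ B)
```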
For small matrices the speed is about the same:

[Timing table (Before / After); figures not recoverable.]
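A minimal harness along these lines could produce such timings (a sketch; the shapes and repeat count are assumptions, and the Before/After comparison requires running it against two builds):

```python
import timeit
import torch

def bench(shape1, shape2, repeats=1000):
    a = torch.randn(*shape1)
    b = torch.randn(*shape2)
    return timeit.timeit(lambda: torch.matmul(a, b), number=repeats) / repeats

# Hypothetical cases this PR targets: a small (<= 2-d) left operand
# against a batched (>= 3-d) right operand.
for shape1, shape2 in (((64,), (128, 64, 64)), ((64, 64), (128, 64, 64))):
    print("%s @ %s: %.2f us/call" % (shape1, shape2, bench(shape1, shape2) * 1e6))
```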
@pytorchbot retest this please.
The PR is conservative in not reusing the …
@pytorchbot retest this please.
Hey @colesbury, do you think you can review this? If not, I'll look.
colesbury left a comment:
Looks good. Please make sure there are correctness tests in test_torch.py and test_autograd.py that cover this case. Specifically:
- dim_tensor1 = 1, dim_tensor2 >= 3
- dim_tensor1 = 2, dim_tensor2 >= 3
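For illustration, a compact correctness check covering exactly those two cases might look like this (a sketch against a slice-by-slice reference, not the tests that were actually committed):

```python
import itertools
import torch

# Compare torch.matmul against a slice-by-slice reference for
# dim_tensor1 in {1, 2} and dim_tensor2 in {3, 4}.
def check(dim1, dim2):
    batch = (3,) * (dim2 - 2)
    t1 = torch.randn(5) if dim1 == 1 else torch.randn(4, 5)
    t2 = torch.randn(*batch, 5, 6)
    # Reference: multiply t1 against every (5, 6) slice of t2.
    expected = torch.stack([torch.matmul(t1, s) for s in t2.reshape(-1, 5, 6)])
    expected = expected.reshape(*batch, *expected.shape[1:])
    assert torch.allclose(torch.matmul(t1, t2), expected), (dim1, dim2)

for dim1, dim2 in itertools.product((1, 2), (3, 4)):
    check(dim1, dim2)
```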
I've added some basic tests (line 4137 in 8e26759). Also I've tested the code with a couple of throwaway scripts like this one, but they may be too long for the unit tests:

```python
import torch
import numpy as np

# 2d, 3d
for N in range(1, 20):
    for M in range(1, 20):
        for P in range(1, 20):
            for O in range(1, 20):
                x = torch.arange(N * M).reshape(N, M)
                y = torch.arange(O * M * P).reshape(O, M, P)
                expected = torch.bmm(x.unsqueeze(0).expand(O, N, M), y)
                z = torch.matmul(x, y)
                if not torch.equal(z, expected):
                    raise RuntimeError("different results: %s %s %s %s" % (N, M, P, O))
                # Check contiguity flags via numpy.
                ex = np.array(expected, copy=False)
                zz = np.array(z, copy=False)
                if ex.flags != zz.flags or ex[0].flags != zz[0].flags:
                    raise RuntimeError("different flags: %s %s %s %s" % (N, M, P, O))

# 1d, 3d
N = 1
for M in range(1, 20):
    for P in range(1, 20):
        for O in range(1, 20):
            x = torch.arange(M)
            y = torch.arange(O * M * P).reshape(O, M, P)
            expected = torch.bmm(x.expand(O, N, M), y).reshape(O, P)
            z = torch.matmul(x, y)
            if not torch.equal(z, expected):
                raise RuntimeError("different results: %s %s %s %s" % (N, M, P, O))
            # Check contiguity flags via numpy.
            ex = np.array(expected, copy=False)
            zz = np.array(z, copy=False)
            if ex.flags != zz.flags or ex[0].flags != zz[0].flags:
                raise RuntimeError("different flags: %s %s %s %s" % (N, M, P, O))
```
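For the test_autograd.py side, a gradient check along these lines would exercise the same dispatch path (a sketch using torch.autograd.gradcheck; double precision is needed for the numerical comparison):

```python
import torch
from torch.autograd import gradcheck

# Cover both requested cases: dim_tensor1 in {1, 2} with a 3-d second operand.
for shape1 in ((5,), (4, 5)):
    x = torch.randn(*shape1, dtype=torch.double, requires_grad=True)
    y = torch.randn(3, 5, 6, dtype=torch.double, requires_grad=True)
    assert gradcheck(torch.matmul, (x, y))
```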
@pytorchbot retest this please.
@colesbury Thanks for the comments, I think they have all been addressed.
That's what the …
OK thanks, actually I can add that test with … Apparently …
facebook-github-bot left a comment:
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot left a comment:
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This addresses #18862.
Pull Request resolved: pytorch/pytorch#20448
Differential Revision: D15393465
Pulled By: ezyang
fbshipit-source-id: 87e5b0ed8253ea00365f420d98ac96dd4e934028