🐛 Bug
PyTorch 1.1.0a0+c3e3c5c. The GPU times reported are on a P100.
In the case below, matmul uses about 12 GB of memory when it shouldn't need more than ~3 MB, i.e. it uses 4096x more memory than necessary.
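For reference, here is a rough accounting of those figures, assuming float32 (PyTorch's default for randn, 4 bytes per element) and the shapes from snippet A below; the ~3 MB presumably refers to the size of the output z (y is the same size):
copied_x = 192 * 4096 * 4096 * 4  # fully materialized broadcast of x: 12,884,901,888 bytes = 12 GiB
output_z = 192 * 4096 * 1 * 4     # result z: 3,145,728 bytes = 3 MiB
print(copied_x // output_z)       # 4096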
A
import torch
x = torch.randn(4096, 4096)
y = torch.randn(192, 4096, 1)
z = torch.matmul(x, y)
Note that this is equivalent to the following memory-efficient operation:
B
x = torch.randn(4096, 4096)
y = torch.randn(192, 4096, 1)
z = torch.bmm(x.unsqueeze(0).expand(192, *x.shape), y)
It's also equivalent to the following, which is memory-efficient and faster, but it may require a copy of y, and without some extra work the output may be batched column-major:
C
x = torch.randn(4096, 4096)
y = torch.randn(192, 4096, 1)
z = torch.matmul(y.permute(0, 2, 1), x.t()).permute(0, 2, 1)
On GPU, A takes ~125 ms and uses 12 GB of memory, B takes ~22 ms, and C takes ~1 ms.
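As a quick sanity check of the equivalence claims (not part of the original report), the snippet below compares the three variants using smaller stand-in shapes, so that A's hidden broadcast copy stays tiny:
import torch

x = torch.randn(64, 64)    # stand-in for the 4096x4096 matrix
y = torch.randn(8, 64, 1)  # stand-in for the 192x4096x1 batch

z_a = torch.matmul(x, y)                                        # variant A
z_b = torch.bmm(x.unsqueeze(0).expand(8, *x.shape), y)          # variant B
z_c = torch.matmul(y.permute(0, 2, 1), x.t()).permute(0, 2, 1)  # variant C

print(torch.allclose(z_a, z_b, atol=1e-5))  # True
print(torch.allclose(z_a, z_c, atol=1e-5))  # True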
See also #13222, which may be related.
I believe the problem is the unnecessary contiguous() call here:
pytorch/aten/src/ATen/native/LinearAlgebra.cpp, lines 460 to 461 at 15b318d:
Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).contiguous().view(tensor1_bmm_view);
Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).contiguous().view(tensor2_bmm_view);
Instead of using contiguous() and view(), it may be possible to use reshape(). That might achieve the performance of B.
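A minimal sketch of the difference, written in Python rather than the ATen C++ above and with the bmm view shape hard-coded for the case in A (matmul computes these shapes generically, so whether reshape() avoids the copy for every shape it handles is a separate question):
import torch

x = torch.randn(4096, 4096)
expanded = x.expand(192, 4096, 4096)        # no copy: the new batch dimension has stride 0
print(expanded.stride())                    # (0, 4096, 1)

# .contiguous() here materializes 192 real copies of x (the ~12 GB reported above).
# .reshape() can instead alias the expanded tensor, since no stride change is needed
# for this target shape:
reshaped = expanded.reshape(192, 4096, 4096)
print(reshaped.data_ptr() == x.data_ptr())  # True: still backed by x's storage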