Conversation

skrah commented May 13, 2019

This addresses #18862.

skrah added the module: performance, module: operators, module: memory usage, and triaged labels on May 13, 2019

skrah commented May 13, 2019

This PR implements more or less @colesbury's suggestion C from #18862. Suggestion C ends up calling the existing optimization for A.ndim >= 3 && B.ndim <= 2 after transposing the inner dimensions of the arguments and swapping them.
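
Roughly, the old path expands the 2-D argument into a full batch for bmm; for the test case below that is a (192, 4096, 4096) float tensor, i.e. 192 * 4096 * 4096 * 4 B ≈ 12.9 GB, which lines up with the ~13 GB peak in the "Before" trace. A minimal Python-level sketch of the transpose-and-swap idea (a hypothetical helper for illustration, not the ATen code):

import torch

# (A @ B)^T == B^T @ A^T per batch, so the A.ndim <= 2, B.ndim >= 3 case can be
# routed through the existing optimization for A.ndim >= 3 && B.ndim <= 2.
def matmul_via_swap(a, b):
    a2 = a.unsqueeze(0) if a.dim() == 1 else a           # treat a 1-D lhs as a row vector
    res = torch.matmul(b.transpose(-1, -2), a2.transpose(-1, -2))
    res = res.transpose(-1, -2)
    return res.squeeze(-2) if a.dim() == 1 else res      # undo the inserted dimension

x = torch.randn(64, 64)
y = torch.randn(8, 64, 3)
assert torch.allclose(matmul_via_swap(x, y), torch.matmul(x, y))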

Memory usage

Memory usage as reported by valgrind --tool=massif.

Test case

x = torch.randn(4096, 4096)
y = torch.randn(192, 4096, 1)
z = torch.matmul(x, y)

Before

--------------------------------------------------------------------------------
  n        time(i)         total(B)   useful-heap(B) extra-heap(B)    stacks(B)
--------------------------------------------------------------------------------
 51 18,481,148,839   12,980,531,144   12,978,674,895     1,856,249            0
 52 18,557,183,694   12,980,531,144   12,978,674,895     1,856,249            0
 53 18,633,218,549   12,980,531,144   12,978,674,895     1,856,249            0
 54 18,688,264,583       17,333,504       15,568,202     1,765,302            0

After

--------------------------------------------------------------------------------
  n        time(i)         total(B)   useful-heap(B) extra-heap(B)    stacks(B)
--------------------------------------------------------------------------------
 76  5,124,473,874       96,953,336       95,088,436     1,864,900            0
 77  5,139,417,005       95,591,328       93,726,580     1,864,748            0
 78  5,155,408,421       22,011,848       20,148,194     1,863,654            0
 79  5,171,971,669       21,145,472       19,304,096     1,841,376            0
 80  5,180,817,596       17,866,168       16,086,322     1,779,846            0
 81  5,189,662,538        7,099,760        6,442,537       657,223            0

Timings

Test case

x = torch.randn(4096, 4096)
y = torch.randn(192, 4096, 1)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    z = torch.matmul(x, y)

print(prof)

Before

Self CPU time total: 7.430s
CUDA time total: 7.430s

After

Self CPU time total: 445.886ms
CUDA time total: 445.883ms

skrah commented May 13, 2019

For small matrices the speed is about the same:

x = torch.randn(64)
y = torch.randn(2, 64, 1)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for i in range(10000):
        z = torch.matmul(x, y)

Before

Self CPU time total: 1.076s
CUDA time total: 1.081s

After

Self CPU time total: 939.033ms
CUDA time total: 940.369ms

skrah commented May 13, 2019

@pytorchbot retest this please.

skrah commented May 14, 2019

The PR is conservative in not reusing the out_opt argument in the recursive call, but I don't see a big difference in the timings (both the old and the new code have the same timings with or without an explicit out arg).
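
For context, a small illustration of the two call styles compared here (the out_opt reuse itself happens inside the C++ matmul implementation; this only shows calling matmul with and without an explicit out argument):

import torch

x = torch.randn(4096, 4096)
y = torch.randn(192, 4096, 1)

# Without an explicit output tensor: matmul allocates the result itself.
z = torch.matmul(x, y)

# With an explicit output tensor: the result is written into a preallocated buffer.
out = torch.empty(192, 4096, 1)
torch.matmul(x, y, out=out)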

skrah commented May 14, 2019

@pytorchbot retest this please.

skrah changed the title from "[WIP] Add matmul optimization for the case A.ndim <= 2 && B.ndim >= 3" to "Add matmul optimization for the case A.ndim <= 2 && B.ndim >= 3" on May 14, 2019
ezyang requested a review from colesbury on May 14, 2019 20:07

ezyang commented May 14, 2019

Hey @colesbury, do you think you can review this? If not, I'll look.

colesbury (Member) left a comment

Looks good. Please make sure there are correctness tests in test_torch.py and test_autograd.py that cover this case. Specifically:

dim_tensor1=1, dim_tensor2 >= 3
dim_tensor1=2, dim_tensor2 >= 3
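
A minimal gradcheck-style sketch of those two combinations (hypothetical test code for illustration, not the tests that were actually added):

import torch

def check_case(a_shape, b_shape):
    # Double precision so gradcheck's finite differences are reliable.
    a = torch.randn(*a_shape, dtype=torch.double, requires_grad=True)
    b = torch.randn(*b_shape, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(torch.matmul, (a, b))

check_case((5,), (3, 5, 4))       # dim_tensor1 = 1, dim_tensor2 = 3
check_case((2, 5), (3, 5, 4))     # dim_tensor1 = 2, dim_tensor2 = 3
check_case((2, 5), (6, 3, 5, 4))  # dim_tensor1 = 2, dim_tensor2 = 4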

pytorchbot added the module: tests label on May 15, 2019

skrah commented May 15, 2019

test_torch.py already has tests here (they fail if an error is introduced in the new code):

result = maybe_squeeze_result(l, r, l_matmul_fn(r))

I've added some basic test_autograd.py tests.

Also I've tested the code with a couple of throwaway scripts like this one, but they may be too long for the unit tests:

import torch
import numpy as np

# 2d, 3d
for N in range(1, 20):
  for M in range(1, 20):
    for P in range(1, 20):
      for O in range(1, 20):
        x = torch.arange(N*M).reshape(N, M)
        y = torch.arange(O*M*P).reshape(O, M, P)
        expected = torch.bmm(x.unsqueeze(0).expand(O, N, M), y)
        z = torch.matmul(x, y)
        if not torch.equal(z, expected):
          raise RuntimeError("different results: %s %s %s %s" % (N, M, P, O))
        # Check contiguity flags via numpy.
        ex = np.array(expected, copy=False)
        zz = np.array(z, copy=False)
        if ex.flags != zz.flags or ex[0].flags != zz[0].flags:
          raise RuntimeError("different flags: %s %s %s %s" % (N, M, P, O))

# 1d, 3d
N = 1
for M in range(1, 20):
  for P in range(1, 20):
    for O in range(1, 20):
      x = torch.arange(M)
      y = torch.arange(O*M*P).reshape(O, M, P)
      expected = torch.bmm(x.expand(O, N, M), y).reshape(O, P)
      z = torch.matmul(x, y)
      if not torch.equal(z, expected):
        raise RuntimeError("different results: %s %s %s %s" % (N, M, P, O))
      # Check contiguity flags via numpy.
      ex = np.array(expected, copy=False)
      zz = np.array(z, copy=False)
      if ex.flags != zz.flags or ex[0].flags != zz[0].flags:
        raise RuntimeError("different flags: %s %s %s %s" % (N, M, P, O))

skrah commented May 15, 2019

@pytorchbot retest this please.

skrah commented May 15, 2019

@colesbury Thanks for the comments; I think they have all been addressed.

skrah requested a review from colesbury on May 15, 2019 13:08

ezyang commented May 15, 2019

Also I've tested the code with a couple of throwaway scripts like this one, but they may be too long for the unit tests

That's what the @slowTest decorator is for :)

skrah commented May 15, 2019

That's what the @slowTest decorator is for :)

OK, thanks. Actually I can add that test with range(1, 10); then it just takes 1s (4s in a debug build).

Apparently test_matmul_4d_4d was banned at some point; it is in a list called THESE_TAKE_WAY_TOO_LONG, and I can no longer find the test itself. :)
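
For reference, a sketch of how such a brute-force test might be marked slow (assuming the slowTest decorator and TestCase helpers from the test suite's common_utils; the test name and body are illustrative only):

import torch
from common_utils import TestCase, run_tests, slowTest

class TestMatmulBruteForce(TestCase):
    @slowTest
    def test_matmul_2d_3d_small(self):
        # Compare matmul against an explicit expand + bmm reference for small shapes.
        for n in range(1, 10):
            for m in range(1, 10):
                for p in range(1, 10):
                    for o in range(1, 10):
                        x = torch.randn(n, m)
                        y = torch.randn(o, m, p)
                        expected = torch.bmm(x.unsqueeze(0).expand(o, n, m), y)
                        self.assertEqual(torch.matmul(x, y), expected)

if __name__ == '__main__':
    run_tests()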

facebook-github-bot left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request May 17, 2019
Summary:
This addresses #18862.
Pull Request resolved: pytorch/pytorch#20448

Differential Revision: D15393465

Pulled By: ezyang

fbshipit-source-id: 87e5b0ed8253ea00365f420d98ac96dd4e934028

@ezyang merged this pull request in 8c9f4c5.
