Skip to content

Conversation

@eqy
Copy link
Collaborator

@eqy eqy commented Oct 30, 2021

It appears that most NVIDIA architectures (well, at least there haven't been many reports of this issue) don't do reduced precision reductions (e.g., reducing in fp16 given fp16 inputs), but this change attempts to ensure that a reduced precision reduction is never done. The included test case currently fails on Volta but passes on Pascal and Ampere; setting this flag causes the test to pass on all three.

CC @stas00 @ngimel @ptrblck

@pytorch-probot
Copy link

pytorch-probot bot commented Oct 30, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/eqy/pytorch/blob/e234557c09027412c9058e2a270c36dfd848e2ba/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-dynamic ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
docker-builds ciflow/all 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
linux-xenial-py3-clang5-mobile-code-analysis ciflow/all, ciflow/linux, ciflow/mobile 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Oct 30, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit e234557 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@facebook-github-bot
Copy link
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Oct 31, 2021
Summary:
It appears that most NVIDIA architectures (well, at least there haven't been many reports of this issue) don't do reduced precision reductions (e.g., reducing in fp16 given fp16 inputs), but this change attempts to ensure that a reduced precision reduction is never done. The included test case currently fails on Volta but passes on Pascal and Ampere; setting this flag causes the test to pass on all three.

CC stas00 ngimel ptrblck

Pull Request resolved: #67578

Reviewed By: mruberry

Differential Revision: D32046030

Pulled By: ngimel

fbshipit-source-id: ac9aa8489ad6835f34bd0300c5d6f4ea76f333d1
@stas00
Copy link
Contributor

stas00 commented Oct 31, 2021

Why is commit information hidden from the public? e.g if I click on https://github.com/pytorch/pytorch-canary/commit/4e6cedb340cdeff97af3baf72888dc4b20b639c7 it gives me 404 not-found, which usually means it's a private repo.

So how can a user see where did the PR go?

It's already super confusing that pytorch PRs get closed rather than merged, as far as github PR status goes, since it's hard to quickly tell whether PR was merged or not, but if we can't even see where they go, that's troublesome...

Thank you!

update: I found it. e01279c
It's in a totally unexpected place, it's in the "added a commit that referenced this issue 12 hours ago" which is not where one would expect to find the destination, as this is typically github's way of adding references to other people doing something with this PR for their own needs, but I guess this is how it's with pytorch. It's still not very intuitive, but at least now I know where to look for the commit.

@ngimel
Copy link
Collaborator

ngimel commented Oct 31, 2021

@stas00 you can also look at hud https://hud.pytorch.org/ci/pytorch/pytorch/master for the latest merged commits (although sometimes there's a half an hour or so delay), or at the pytorch repo history.

@eqy
Copy link
Collaborator Author

eqy commented Nov 5, 2021

Doing some belated benchmarking on V100:
GEMMs are in [m, k, n]
BADDBMMs are in [b, m, k, n]

[----------- bench_gemm ----------]	------]
                          |   old	  new 
1 threads: ------------------------	-------
      [1, 1, 1]           |    27.6	   28.8
      [1, 1, 8]           |    11.7	   10.9
      [1, 1, 64]          |    11.4	   10.9
      [1, 1, 512]         |    11.5	   11.0
      [1, 1, 4096]        |    11.8	   11.2
      [1, 8, 1]           |    11.4	   10.8
      [1, 8, 8]           |    11.3	   10.8
      [1, 8, 64]          |    17.8	   17.3
      [1, 8, 512]         |    24.9	   24.2
      [1, 8, 4096]        |    12.0	   11.3
      [1, 64, 1]          |    11.4	   10.8
      [1, 64, 8]          |    11.3	   11.1
      [1, 64, 64]         |    18.0	   17.5
      [1, 64, 512]        |    25.1	   24.5
      [1, 64, 4096]       |    11.7	   11.0
      [1, 512, 1]         |    17.0	   16.3
      [1, 512, 8]         |    17.3	   16.5
      [1, 512, 64]        |    17.7	   17.4
      [1, 512, 512]       |    11.3	   10.8
      [1, 512, 4096]      |    11.2	   10.9
      [1, 4096, 1]        |    16.8	   16.6
      [1, 4096, 8]        |    17.0	   16.8
      [1, 4096, 64]       |    17.7	   17.4
      [1, 4096, 512]      |    25.0	   24.7
      [1, 4096, 4096]     |    45.4	   46.5
      [8, 1, 1]           |    11.7	   10.8
      [8, 1, 8]           |    11.7	   10.7
      [8, 1, 64]          |    11.4	   10.7
      [8, 1, 512]         |    11.6	   10.8
      [8, 1, 4096]        |    11.5	   11.0
      [8, 8, 1]           |    16.8	   16.7
      [8, 8, 8]           |    11.2	   10.9
      [8, 8, 64]          |    11.3	   11.2
      [8, 8, 512]         |    11.4	   11.1
      [8, 8, 4096]        |    11.5	   11.5
      [8, 64, 1]          |    17.0	   16.7
      [8, 64, 8]          |    11.4	   10.9
      [8, 64, 64]         |    11.6	   11.1
      [8, 64, 512]        |    11.6	   10.9
      [8, 64, 4096]       |    17.1	   16.9
      [8, 512, 1]         |    17.2	   17.0
      [8, 512, 8]         |    12.7	   11.8
      [8, 512, 64]        |    12.7	   11.7
      [8, 512, 512]       |    25.7	   25.9
      [8, 512, 4096]      |    18.0	   17.6
      [8, 4096, 1]        |    17.0	   16.8
      [8, 4096, 8]        |    24.4	   24.3
      [8, 4096, 64]       |    25.6	   26.2
      [8, 4096, 512]      |    24.5	   25.2
      [8, 4096, 4096]     |    62.2	   62.2
      [64, 1, 1]          |    11.4	   10.8
      [64, 1, 8]          |    11.5	   10.9
      [64, 1, 64]         |    11.6	   10.9
      [64, 1, 512]        |    11.7	   10.9
      [64, 1, 4096]       |    11.6	   10.9
      [64, 8, 1]          |    11.7	   10.9
      [64, 8, 8]          |    11.4	   11.2
      [64, 8, 64]         |    11.6	   11.4
      [64, 8, 512]        |    11.5	   11.3
      [64, 8, 4096]       |    16.5	   16.2
      [64, 64, 1]         |    11.3	   11.0
      [64, 64, 8]         |    12.6	   11.7
      [64, 64, 64]        |    12.6	   11.6
      [64, 64, 512]       |    17.9	   17.6
      [64, 64, 4096]      |    16.2	   16.1
      [64, 512, 1]        |    11.5	   11.0
      [64, 512, 8]        |    12.4	   11.6
      [64, 512, 64]       |    25.2	   24.9
      [64, 512, 512]      |    18.9	   18.6
      [64, 512, 4096]     |    19.4	   19.4
      [64, 4096, 1]       |    17.3	   17.3
      [64, 4096, 8]       |    24.4	   24.7
      [64, 4096, 64]      |    25.5	   25.0
      [64, 4096, 512]     |    25.3	   23.5
      [64, 4096, 4096]    |    63.4	   69.5
      [512, 1, 1]         |    11.7	   10.8
      [512, 1, 8]         |    11.8	   10.8
      [512, 1, 64]        |    11.7	   10.8
      [512, 1, 512]       |    11.7	   10.9
      [512, 1, 4096]      |    12.1	   11.2
      [512, 8, 1]         |    11.5	   11.0
      [512, 8, 8]         |    11.5	   11.5
      [512, 8, 64]        |    11.5	   11.4
      [512, 8, 512]       |    11.6	   11.3
      [512, 8, 4096]      |    16.9	   16.9
      [512, 64, 1]        |    11.5	   11.0
      [512, 64, 8]        |    11.9	   11.5
      [512, 64, 64]       |    17.6	   16.8
      [512, 64, 512]      |    17.5	   17.8
      [512, 64, 4096]     |    17.8	   18.2
      [512, 512, 1]       |    11.5	   11.1
      [512, 512, 8]       |    18.8	   18.3
      [512, 512, 64]      |    18.6	   18.5
      [512, 512, 512]     |    24.6	   17.5
      [512, 512, 4096]    |    35.3	   35.4
      [512, 4096, 1]      |    24.4	   23.6
      [512, 4096, 8]      |    24.4	   23.3
      [512, 4096, 64]     |    25.8	   23.5
      [512, 4096, 512]    |    43.9	   56.1
      [512, 4096, 4096]   |   223.6	  222.2
      [4096, 1, 1]        |    11.8	   10.5
      [4096, 1, 8]        |    11.7	   10.5
      [4096, 1, 64]       |    11.7	   10.5
      [4096, 1, 512]      |    12.0	   10.5
      [4096, 1, 4096]     |    50.8	   50.9
      [4096, 8, 1]        |    11.5	   10.5
      [4096, 8, 8]        |    11.5	   11.0
      [4096, 8, 64]       |    16.5	   15.8
      [4096, 8, 512]      |    17.0	   16.1
      [4096, 8, 4096]     |    51.4	   51.1
      [4096, 64, 1]       |    11.6	   10.8
      [4096, 64, 8]       |    16.2	   15.6
      [4096, 64, 64]      |    16.8	   15.8
      [4096, 64, 512]     |    18.1	   17.4
      [4096, 64, 4096]    |    63.5	   63.5
      [4096, 512, 1]      |    15.6	   14.3
      [4096, 512, 8]      |    18.7	   17.4
      [4096, 512, 64]     |    16.4	   15.7
      [4096, 512, 512]    |    38.3	   38.5
      [4096, 512, 4096]   |   253.8	  251.9
      [4096, 4096, 1]     |    53.2	   51.5
      [4096, 4096, 8]     |    79.6	  105.1
      [4096, 4096, 64]    |    84.1	  110.6
      [4096, 4096, 512]   |   225.4	  226.0
      [4096, 4096, 4096]  |  1621.5	 1607.3
	
Times are in microseconds (us).	s).
	
[------------ bench_baddbmm ------------]	------------]
                              |    old   	  |    new
1 threads: ------------------------------	-------------
      [1, 1, 1, 1]            |      26.1	  |      24.8
      [1, 1, 1, 8]            |      26.4	  |      24.5
      [1, 1, 1, 64]           |      26.4	  |      24.3
      [1, 1, 1, 512]          |      27.0	  |      24.4
      [1, 1, 1, 4096]         |      26.6	  |      25.0
      [1, 1, 8, 1]            |      26.6	  |      25.0
      [1, 1, 8, 8]            |      26.1	  |      24.7
      [1, 1, 8, 64]           |      33.9	  |      33.0
      [1, 1, 8, 512]          |      40.9	  |      39.1
      [1, 1, 8, 4096]         |      27.1	  |      25.2
      [1, 1, 64, 1]           |      26.7	  |      24.8
      [1, 1, 64, 8]           |      26.6	  |      24.9
      [1, 1, 64, 64]          |      34.2	  |      32.6
      [1, 1, 64, 512]         |      40.9	  |      38.9
      [1, 1, 64, 4096]        |      27.2	  |      24.8
      [1, 1, 512, 1]          |      33.2	  |      31.5
      [1, 1, 512, 8]          |      33.7	  |      31.9
      [1, 1, 512, 64]         |      34.3	  |      32.6
      [1, 1, 512, 512]        |      26.4	  |      24.8
      [1, 1, 512, 4096]       |      26.7	  |      24.9
      [1, 1, 4096, 1]         |      33.6	  |      32.0
      [1, 1, 4096, 8]         |      33.0	  |      32.2
      [1, 1, 4096, 64]        |      34.3	  |      32.7
      [1, 1, 4096, 512]       |      41.6	  |      40.1
      [1, 1, 4096, 4096]      |      47.3	  |      48.3
      [1, 8, 1, 1]            |      26.7	  |      24.7
      [1, 8, 1, 8]            |      26.3	  |      24.4
      [1, 8, 1, 64]           |      26.5	  |      24.3
      [1, 8, 1, 512]          |      26.5	  |      24.5
      [1, 8, 1, 4096]         |      26.3	  |      24.7
      [1, 8, 8, 1]            |      33.4	  |      32.1
      [1, 8, 8, 8]            |      25.8	  |      24.7
      [1, 8, 8, 64]           |      25.7	  |      24.6
      [1, 8, 8, 512]          |      26.4	  |      24.8
      [1, 8, 8, 4096]         |      26.8	  |      24.8
      [1, 8, 64, 1]           |      33.8	  |      32.1
      [1, 8, 64, 8]           |      26.6	  |      24.6
      [1, 8, 64, 64]          |      26.2	  |      24.6
      [1, 8, 64, 512]         |      26.5	  |      24.8
      [1, 8, 64, 4096]        |      32.3	  |      31.3
      [1, 8, 512, 1]          |      32.9	  |      32.0
      [1, 8, 512, 8]          |      28.9	  |      27.0
      [1, 8, 512, 64]         |      28.8	  |      26.9
      [1, 8, 512, 512]        |      42.6	  |      42.0
      [1, 8, 512, 4096]       |      34.0	  |      33.1
      [1, 8, 4096, 1]         |      33.0	  |      32.1
      [1, 8, 4096, 8]         |      28.0	  |      27.1
      [1, 8, 4096, 64]        |      41.5	  |      41.9
      [1, 8, 4096, 512]       |      40.6	  |      40.2
      [1, 8, 4096, 4096]      |      64.5	  |      64.7
      [1, 64, 1, 1]           |      26.2	  |      24.4
      [1, 64, 1, 8]           |      26.1	  |      24.4
      [1, 64, 1, 64]          |      26.1	  |      24.5
      [1, 64, 1, 512]         |      26.1	  |      24.7
      [1, 64, 1, 4096]        |      26.1	  |      24.6
      [1, 64, 8, 1]           |      26.1	  |      25.0
      [1, 64, 8, 8]           |      26.0	  |      24.4
      [1, 64, 8, 64]          |      26.1	  |      24.8
      [1, 64, 8, 512]         |      26.3	  |      25.0
      [1, 64, 8, 4096]        |      31.9	  |      30.5
      [1, 64, 64, 1]          |      26.3	  |      24.7
      [1, 64, 64, 8]          |      27.4	  |      26.5
      [1, 64, 64, 64]         |      27.4	  |      26.5
      [1, 64, 64, 512]        |      33.4	  |      32.8
      [1, 64, 64, 4096]       |      31.2	  |      30.5
      [1, 64, 512, 1]         |      26.0	  |      25.0
      [1, 64, 512, 8]         |      28.0	  |      27.1
      [1, 64, 512, 64]        |      41.3	  |      40.8
      [1, 64, 512, 512]       |      34.0	  |      33.5
      [1, 64, 512, 4096]      |      33.1	  |      33.0
      [1, 64, 4096, 1]        |      33.3	  |      32.9
      [1, 64, 4096, 8]        |      39.8	  |      39.9
      [1, 64, 4096, 64]       |      41.5	  |      42.1
      [1, 64, 4096, 512]      |      41.1	  |      40.1
      [1, 64, 4096, 4096]     |      68.3	  |      67.9
      [1, 512, 1, 1]          |      26.0	  |      24.5
      [1, 512, 1, 8]          |      26.1	  |      24.4
      [1, 512, 1, 64]         |      26.3	  |      24.6
      [1, 512, 1, 512]        |      25.9	  |      24.6
      [1, 512, 1, 4096]       |      26.7	  |      25.8
      [1, 512, 8, 1]          |      26.1	  |      24.9
      [1, 512, 8, 8]          |      26.4	  |      24.7
      [1, 512, 8, 64]         |      26.4	  |      24.5
      [1, 512, 8, 512]        |      26.2	  |      24.8
      [1, 512, 8, 4096]       |      32.0	  |      30.8
      [1, 512, 64, 1]         |      25.7	  |      24.5
      [1, 512, 64, 8]         |      26.8	  |      25.0
      [1, 512, 64, 64]        |      32.3	  |      31.9
      [1, 512, 64, 512]       |      32.3	  |      31.5
      [1, 512, 64, 4096]      |      33.0	  |      32.0
      [1, 512, 512, 1]        |      26.2	  |      25.0
      [1, 512, 512, 8]        |      33.3	  |      33.8
      [1, 512, 512, 64]       |      33.9	  |      33.6
      [1, 512, 512, 512]      |      40.4	  |      39.5
      [1, 512, 512, 4096]     |      49.7	  |      50.0
      [1, 512, 4096, 1]       |      39.3	  |      39.8
      [1, 512, 4096, 8]       |      40.2	  |      39.5
      [1, 512, 4096, 64]      |      41.3	  |      40.4
      [1, 512, 4096, 512]     |      49.0	  |      49.2
      [1, 512, 4096, 4096]    |     235.3	  |     235.9
      [1, 4096, 1, 1]         |      26.1	  |      24.8
      [1, 4096, 1, 8]         |      26.4	  |      24.8
      [1, 4096, 1, 64]        |      26.1	  |      24.6
      [1, 4096, 1, 512]       |      28.0	  |      27.2
      [1, 4096, 1, 4096]      |     186.7	  |     187.7
      [1, 4096, 8, 1]         |      25.3	  |      24.8
      [1, 4096, 8, 8]         |      26.7	  |      24.9
      [1, 4096, 8, 64]        |      31.7	  |      30.4
      [1, 4096, 8, 512]       |      32.3	  |      31.4
      [1, 4096, 8, 4096]      |     178.8	  |     178.9
      [1, 4096, 64, 1]        |      26.2	  |      24.9
      [1, 4096, 64, 8]        |      31.4	  |      29.6
      [1, 4096, 64, 64]       |      31.7	  |      30.4
      [1, 4096, 64, 512]      |      33.5	  |      32.3
      [1, 4096, 64, 4096]     |     191.4	  |     191.5
      [1, 4096, 512, 1]       |      29.7	  |      29.0
      [1, 4096, 512, 8]       |      33.4	  |      33.2
      [1, 4096, 512, 64]      |      31.5	  |      30.4
      [1, 4096, 512, 512]     |      51.9	  |      51.8
      [1, 4096, 512, 4096]    |     365.5	  |     366.0
      [1, 4096, 4096, 1]      |      55.3	  |      53.0
      [1, 4096, 4096, 8]      |      85.0	  |      79.1
      [1, 4096, 4096, 64]     |      86.4	  |      82.5
      [1, 4096, 4096, 512]    |     238.0	  |     236.5
      [1, 4096, 4096, 4096]   |    1690.0	  |    1692.7
      [8, 1, 1, 1]            |      26.2	  |      25.1
      [8, 1, 1, 8]            |      25.8	  |      25.0
      [8, 1, 1, 64]           |      25.8	  |      25.0
      [8, 1, 1, 512]          |      26.0	  |      24.9
      [8, 1, 1, 4096]         |      26.1	  |      24.8
      [8, 1, 8, 1]            |      27.6	  |      25.6
      [8, 1, 8, 8]            |      26.3	  |      25.7
      [8, 1, 8, 64]           |      25.8	  |      25.1
      [8, 1, 8, 512]          |      26.5	  |      25.2
      [8, 1, 8, 4096]         |      26.5	  |      25.0
      [8, 1, 64, 1]           |      27.0	  |      25.4
      [8, 1, 64, 8]           |      26.4	  |      25.4
      [8, 1, 64, 64]          |      26.7	  |      25.3
      [8, 1, 64, 512]         |      26.5	  |      25.0
      [8, 1, 64, 4096]        |      26.3	  |      25.4
      [8, 1, 512, 1]          |      32.9	  |      32.2
      [8, 1, 512, 8]          |      26.3	  |      25.3
      [8, 1, 512, 64]         |      26.5	  |      25.1
      [8, 1, 512, 512]        |      26.5	  |      25.1
      [8, 1, 512, 4096]       |      61.3	  |      61.1
      [8, 1, 4096, 1]         |      33.3	  |      31.9
      [8, 1, 4096, 8]         |      26.8	  |      25.5
      [8, 1, 4096, 64]        |      26.6	  |      25.6
      [8, 1, 4096, 512]       |      93.7	  |      93.0
      [8, 1, 4096, 4096]      |     439.9	  |     446.0
      [8, 8, 1, 1]            |      26.9	  |      24.9
      [8, 8, 1, 8]            |      26.4	  |      24.9
      [8, 8, 1, 64]           |      26.6	  |      25.1
      [8, 8, 1, 512]          |      26.4	  |      24.9
      [8, 8, 1, 4096]         |      26.6	  |      25.0
      [8, 8, 8, 1]            |      26.8	  |      25.6
      [8, 8, 8, 8]            |      26.6	  |      25.9
      [8, 8, 8, 64]           |      27.1	  |      26.4
      [8, 8, 8, 512]          |      26.9	  |      26.0
      [8, 8, 8, 4096]         |      26.6	  |      26.6
      [8, 8, 64, 1]           |      27.3	  |      26.2
      [8, 8, 64, 8]           |      26.3	  |      26.7
      [8, 8, 64, 64]          |      26.8	  |      27.2
      [8, 8, 64, 512]         |      27.1	  |      26.8
      [8, 8, 64, 4096]        |      26.9	  |      26.9
      [8, 8, 512, 1]          |      27.0	  |      26.4
      [8, 8, 512, 8]          |      31.5	  |      31.0
      [8, 8, 512, 64]         |      31.6	  |      31.0
      [8, 8, 512, 512]        |      31.8	  |      30.9
      [8, 8, 512, 4096]       |      57.3	  |      57.5
      [8, 8, 4096, 1]         |      27.3	  |      26.7
      [8, 8, 4096, 8]         |      49.7	  |      49.8
      [8, 8, 4096, 64]        |      52.2	  |      52.2
      [8, 8, 4096, 512]       |      88.6	  |      88.9
      [8, 8, 4096, 4096]      |     353.8	  |     352.4
      [8, 64, 1, 1]           |      26.5	  |      26.0
      [8, 64, 1, 8]           |      26.1	  |      26.1
      [8, 64, 1, 64]          |      26.6	  |      25.5
      [8, 64, 1, 512]         |      26.2	  |      25.6
      [8, 64, 1, 4096]        |      27.9	  |      27.1
      [8, 64, 8, 1]           |      26.9	  |      26.6
      [8, 64, 8, 8]           |      31.7	  |      31.3
      [8, 64, 8, 64]          |      31.5	  |      31.2
      [8, 64, 8, 512]         |      31.6	  |      31.1
      [8, 64, 8, 4096]        |      31.8	  |      31.0
      [8, 64, 64, 1]          |      26.7	  |      26.2
      [8, 64, 64, 8]          |      31.8	  |      31.0
      [8, 64, 64, 64]         |      31.6	  |      31.0
      [8, 64, 64, 512]        |      31.6	  |      30.9
      [8, 64, 64, 4096]       |      33.4	  |      33.6
      [8, 64, 512, 1]         |      26.8	  |      26.2
      [8, 64, 512, 8]         |      31.3	  |      31.2
      [8, 64, 512, 64]        |      31.8	  |      31.2
      [8, 64, 512, 512]       |      31.4	  |      30.7
      [8, 64, 512, 4096]      |      80.3	  |      80.4
      [8, 64, 4096, 1]        |      27.1	  |      26.6
      [8, 64, 4096, 8]        |      58.4	  |      58.3
      [8, 64, 4096, 64]       |     101.3	  |     101.4
      [8, 64, 4096, 512]      |     111.4	  |     111.6
      [8, 64, 4096, 4096]     |     423.3	  |     423.3
      [8, 512, 1, 1]          |      26.2	  |      25.9
      [8, 512, 1, 8]          |      26.4	  |      25.8
      [8, 512, 1, 64]         |      26.2	  |      25.8
      [8, 512, 1, 512]        |      28.0	  |      27.3
      [8, 512, 1, 4096]       |     263.4	  |     258.8
      [8, 512, 8, 1]          |      26.7	  |      26.1
      [8, 512, 8, 8]          |      31.2	  |      30.7
      [8, 512, 8, 64]         |      31.6	  |      30.8
      [8, 512, 8, 512]        |      31.8	  |      31.2
      [8, 512, 8, 4096]       |     181.6	  |     182.3
      [8, 512, 64, 1]         |      28.7	  |      25.8
      [8, 512, 64, 8]         |      31.5	  |      30.8
      [8, 512, 64, 64]        |      31.5	  |      30.5
      [8, 512, 64, 512]       |      31.4	  |      30.9
      [8, 512, 64, 4096]      |     199.6	  |     200.0
      [8, 512, 512, 1]        |      27.1	  |      25.8
      [8, 512, 512, 8]        |      31.6	  |      30.8
      [8, 512, 512, 64]       |      31.4	  |      31.0
      [8, 512, 512, 512]      |      53.8	  |      54.0
      [8, 512, 512, 4096]     |     391.1	  |     391.9
      [8, 512, 4096, 1]       |      53.6	  |      53.3
      [8, 512, 4096, 8]       |     129.3	  |     106.8
      [8, 512, 4096, 64]      |     135.1	  |     140.5
      [8, 512, 4096, 512]     |     255.0	  |     255.3
      [8, 512, 4096, 4096]    |    2170.0	  |    2174.1
      [8, 4096, 1, 1]         |      25.9	  |      25.5
      [8, 4096, 1, 8]         |      26.1	  |      25.6
      [8, 4096, 1, 64]        |      27.1	  |      26.0
      [8, 4096, 1, 512]       |     269.1	  |     264.6
      [8, 4096, 1, 4096]      |    2061.4	  |    2023.3
      [8, 4096, 8, 1]         |      26.7	  |      25.9
      [8, 4096, 8, 8]         |      31.5	  |      30.8
      [8, 4096, 8, 64]        |      31.7	  |      31.4
      [8, 4096, 8, 512]       |     179.1	  |     179.5
      [8, 4096, 8, 4096]      |    1387.6	  |    1387.2
      [8, 4096, 64, 1]        |      26.9	  |      26.1
      [8, 4096, 64, 8]        |      31.5	  |      30.7
      [8, 4096, 64, 64]       |      33.5	  |      33.7
      [8, 4096, 64, 512]      |     197.7	  |     198.0
      [8, 4096, 64, 4096]     |    1440.3	  |    1441.0
      [8, 4096, 512, 1]       |      46.8	  |      46.6
      [8, 4096, 512, 8]       |      87.9	  |      86.6
      [8, 4096, 512, 64]      |     108.8	  |     111.0
      [8, 4096, 512, 512]     |     391.4	  |     391.2
      [8, 4096, 512, 4096]    |    2908.1	  |    2897.5
      [8, 4096, 4096, 1]      |     347.0	  |     344.1
      [8, 4096, 4096, 8]      |     685.9	  |     613.0
      [8, 4096, 4096, 64]     |     691.0	  |     674.5
      [8, 4096, 4096, 512]    |    2132.3	  |    2133.2
      [8, 4096, 4096, 4096]   |   17624.9	  |   17630.7
      [64, 1, 1, 1]           |      26.3	  |      25.6
      [64, 1, 1, 8]           |      25.7	  |      25.7
      [64, 1, 1, 64]          |      26.4	  |      25.6
      [64, 1, 1, 512]         |      26.3	  |      25.5
      [64, 1, 1, 4096]        |      26.2	  |      25.9
      [64, 1, 8, 1]           |      26.9	  |      26.7
      [64, 1, 8, 8]           |      26.2	  |      26.0
      [64, 1, 8, 64]          |      26.1	  |      25.7
      [64, 1, 8, 512]         |      26.2	  |      25.7
      [64, 1, 8, 4096]        |      26.5	  |      26.5
      [64, 1, 64, 1]          |      26.9	  |      26.3
      [64, 1, 64, 8]          |      26.1	  |      26.4
      [64, 1, 64, 64]         |      25.8	  |      25.9
      [64, 1, 64, 512]        |      26.2	  |      25.9
      [64, 1, 64, 4096]       |      62.2	  |      62.2
      [64, 1, 512, 1]         |      32.9	  |      33.7
      [64, 1, 512, 8]         |      26.4	  |      26.2
      [64, 1, 512, 64]        |      25.8	  |      26.1
      [64, 1, 512, 512]       |      54.8	  |      55.4
      [64, 1, 512, 4096]      |     374.6	  |     374.5
      [64, 1, 4096, 1]        |      33.2	  |      33.2
      [64, 1, 4096, 8]        |      33.6	  |      33.4
      [64, 1, 4096, 64]       |      76.2	  |      77.9
      [64, 1, 4096, 512]      |     439.2	  |     423.2
      [64, 1, 4096, 4096]     |    2926.5	  |    2970.2
      [64, 8, 1, 1]           |      26.5	  |      26.0
      [64, 8, 1, 8]           |      26.4	  |      25.9
      [64, 8, 1, 64]          |      26.3	  |      25.9
      [64, 8, 1, 512]         |      26.3	  |      26.0
      [64, 8, 1, 4096]        |      36.9	  |      36.8
      [64, 8, 8, 1]           |      26.9	  |      26.1
      [64, 8, 8, 8]           |      27.2	  |      26.8
      [64, 8, 8, 64]          |      26.8	  |      26.8
      [64, 8, 8, 512]         |      27.0	  |      26.4
      [64, 8, 8, 4096]        |      60.6	  |      64.4
      [64, 8, 64, 1]          |      27.0	  |      26.4
      [64, 8, 64, 8]          |      26.9	  |      27.0
      [64, 8, 64, 64]         |      27.2	  |      26.9
      [64, 8, 64, 512]        |      26.8	  |      26.8
      [64, 8, 64, 4096]       |     111.4	  |     117.3
      [64, 8, 512, 1]         |      26.9	  |      26.4
      [64, 8, 512, 8]         |      31.2	  |      31.3
      [64, 8, 512, 64]        |      31.6	  |      30.9
      [64, 8, 512, 512]       |      61.0	  |      60.8
      [64, 8, 512, 4096]      |     380.9	  |     380.9
      [64, 8, 4096, 1]        |      27.4	  |      26.6
      [64, 8, 4096, 8]        |      78.9	  |      79.3
      [64, 8, 4096, 64]       |      96.1	  |      97.9
      [64, 8, 4096, 512]      |     361.9	  |     362.2
      [64, 8, 4096, 4096]     |    2567.1	  |    2569.0
      [64, 64, 1, 1]          |      26.1	  |      26.1
      [64, 64, 1, 8]          |      26.7	  |      26.0
      [64, 64, 1, 64]         |      26.5	  |      25.7
      [64, 64, 1, 512]        |      28.3	  |      27.7
      [64, 64, 1, 4096]       |     263.2	  |     258.3
      [64, 64, 8, 1]          |      26.8	  |      26.3
      [64, 64, 8, 8]          |      31.5	  |      31.3
      [64, 64, 8, 64]         |      31.8	  |      31.2
      [64, 64, 8, 512]        |      31.7	  |      31.1
      [64, 64, 8, 4096]       |     187.0	  |     186.8
      [64, 64, 64, 1]         |      27.1	  |      26.5
      [64, 64, 64, 8]         |      31.7	  |      31.1
      [64, 64, 64, 64]        |      32.3	  |      31.4
      [64, 64, 64, 512]       |      34.4	  |      34.6
      [64, 64, 64, 4096]      |     228.8	  |     229.2
      [64, 64, 512, 1]        |      26.6	  |      26.5
      [64, 64, 512, 8]        |      31.5	  |      30.7
      [64, 64, 512, 64]       |      31.9	  |      31.1
      [64, 64, 512, 512]      |      88.1	  |      88.4
      [64, 64, 512, 4096]     |     574.7	  |     573.8
      [64, 64, 4096, 1]       |      54.3	  |      53.6
      [64, 64, 4096, 8]       |     131.3	  |     110.8
      [64, 64, 4096, 64]      |     154.2	  |     161.6
      [64, 64, 4096, 512]     |     471.8	  |     471.9
      [64, 64, 4096, 4096]    |    3023.7	  |    3029.6
      [64, 512, 1, 1]         |      25.9	  |      25.4
      [64, 512, 1, 8]         |      26.3	  |      25.7
      [64, 512, 1, 64]        |      27.1	  |      26.2
      [64, 512, 1, 512]       |     268.9	  |     264.4
      [64, 512, 1, 4096]      |    2061.7	  |    2022.8
      [64, 512, 8, 1]         |      26.5	  |      26.0
      [64, 512, 8, 8]         |      31.3	  |      30.7
      [64, 512, 8, 64]        |      31.6	  |      31.0
      [64, 512, 8, 512]       |     179.5	  |     180.2
      [64, 512, 8, 4096]      |    1392.8	  |    1392.7
      [64, 512, 64, 1]        |      27.0	  |      26.6
      [64, 512, 64, 8]        |      32.0	  |      30.7
      [64, 512, 64, 64]       |      35.1	  |      35.0
      [64, 512, 64, 512]      |     200.1	  |     201.7
      [64, 512, 64, 4096]     |    1496.1	  |    1495.1
      [64, 512, 512, 1]       |      48.6	  |      48.8
      [64, 512, 512, 8]       |      86.9	  |      87.9
      [64, 512, 512, 64]      |     113.1	  |     115.6
      [64, 512, 512, 512]     |     407.4	  |     406.8
      [64, 512, 512, 4096]    |    3008.6	  |    3011.1
      [64, 512, 4096, 1]      |     345.1	  |     344.1
      [64, 512, 4096, 8]      |     646.0	  |     624.6
      [64, 512, 4096, 64]     |     768.1	  |     741.7
      [64, 512, 4096, 512]    |    2205.5	  |    2200.1
      [64, 512, 4096, 4096]   |   17672.4	  |   17650.1
      [64, 4096, 1, 1]        |      26.4	  |      25.3
      [64, 4096, 1, 8]        |      53.5	  |      51.5
      [64, 4096, 1, 64]       |     243.1	  |     239.1
      [64, 4096, 1, 512]      |    2110.5	  |    2071.7
      [64, 4096, 1, 4096]     |   16466.1	  |   16161.7
      [64, 4096, 8, 1]        |      60.5	  |      60.6
      [64, 4096, 8, 8]        |      62.6	  |      62.5
      [64, 4096, 8, 64]       |     186.0	  |     186.2
      [64, 4096, 8, 512]      |    1369.4	  |    1369.6
      [64, 4096, 8, 4096]     |   11022.3	  |   11024.8
      [64, 4096, 64, 1]       |      91.9	  |      92.1
      [64, 4096, 64, 8]       |      89.3	  |      89.4
      [64, 4096, 64, 64]      |     223.1	  |     223.5
      [64, 4096, 64, 512]     |    1470.7	  |    1470.6
      [64, 4096, 64, 4096]    |   11435.7	  |   11432.6
      [64, 4096, 512, 1]      |     327.1	  |     327.3
      [64, 4096, 512, 8]      |     539.3	  |     539.8
      [64, 4096, 512, 64]     |     714.9	  |     713.1
      [64, 4096, 512, 512]    |    3056.6	  |    3054.6
      [64, 4096, 512, 4096]   |   23219.5	  |   23314.5
      [64, 4096, 4096, 1]     |    2612.9	  |    2607.8
      [64, 4096, 4096, 8]     |    4898.4	  |    4695.7
      [64, 4096, 4096, 64]    |    5141.5	  |    5099.3
      [64, 4096, 4096, 512]   |   16922.7	  |   17301.9
      [64, 4096, 4096, 4096]  |  141546.0	  |  141532.3
      [512, 1, 1, 1]          |      26.2	  |      25.0
      [512, 1, 1, 8]          |      26.5	  |      25.0
      [512, 1, 1, 64]         |      26.0	  |      25.2
      [512, 1, 1, 512]        |      26.2	  |      25.6
      [512, 1, 1, 4096]       |     161.3	  |     164.6
      [512, 1, 8, 1]          |      26.4	  |      25.5
      [512, 1, 8, 8]          |      26.2	  |      25.1
      [512, 1, 8, 64]         |      26.2	  |      25.3
      [512, 1, 8, 512]        |      34.8	  |      34.8
      [512, 1, 8, 4096]       |     214.0	  |     213.8
      [512, 1, 64, 1]         |      27.3	  |      25.6
      [512, 1, 64, 8]         |      26.3	  |      25.1
      [512, 1, 64, 64]        |      26.2	  |      25.2
      [512, 1, 64, 512]       |      74.0	  |      74.1
      [512, 1, 64, 4096]      |     452.0	  |     451.1
      [512, 1, 512, 1]        |      33.2	  |      32.4
      [512, 1, 512, 8]        |      26.3	  |      25.2
      [512, 1, 512, 64]       |      53.3	  |      53.0
      [512, 1, 512, 512]      |     384.0	  |     383.6
      [512, 1, 512, 4096]     |    2934.2	  |    2931.1
      [512, 1, 4096, 1]       |      32.4	  |      31.9
      [512, 1, 4096, 8]       |     128.4	  |     134.5
      [512, 1, 4096, 64]      |     385.2	  |     384.4
      [512, 1, 4096, 512]     |    3450.5	  |    3411.9
      [512, 8, 1, 1]          |      26.2	  |      25.2
      [512, 8, 1, 8]          |      26.4	  |      24.7
      [512, 8, 1, 64]         |      26.6	  |      24.9
      [512, 8, 1, 512]        |      37.6	  |      37.6
      [512, 8, 1, 4096]       |     314.0	  |     313.0
      [512, 8, 8, 1]          |      27.2	  |      25.2
      [512, 8, 8, 8]          |      27.3	  |      26.8
      [512, 8, 8, 64]         |      27.4	  |      26.0
      [512, 8, 8, 512]        |      62.4	  |      66.4
      [512, 8, 8, 4096]       |     455.1	  |     485.9
      [512, 8, 64, 1]         |      28.4	  |      25.4
      [512, 8, 64, 8]         |      28.9	  |      25.9
      [512, 8, 64, 64]        |      28.9	  |      26.2
      [512, 8, 64, 512]       |     115.4	  |     121.9
      [512, 8, 64, 4096]      |     827.1	  |     877.0
      [512, 8, 512, 1]        |      28.1	  |      25.3
      [512, 8, 512, 8]        |      42.3	  |      41.8
      [512, 8, 512, 64]       |      62.6	  |      62.7
      [512, 8, 512, 512]      |     405.3	  |     405.1
      [512, 8, 512, 4096]     |    2947.5	  |    2949.7
      [512, 8, 4096, 1]       |      57.4	  |      57.4
      [512, 8, 4096, 8]       |     281.5	  |     282.3
      [512, 8, 4096, 64]      |     415.0	  |     417.0
      [512, 8, 4096, 512]     |    2741.3	  |    2732.7
      [512, 64, 1, 1]         |      27.4	  |      25.1
      [512, 64, 1, 8]         |      28.0	  |      25.2
      [512, 64, 1, 64]        |      27.9	  |      26.4
      [512, 64, 1, 512]       |     269.0	  |     264.5
      [512, 64, 1, 4096]      |    2060.3	  |    2019.9
      [512, 64, 8, 1]         |      28.6	  |      25.4
      [512, 64, 8, 8]         |      33.8	  |      30.2
      [512, 64, 8, 64]        |      34.1	  |      30.3
      [512, 64, 8, 512]       |     185.3	  |     185.7
      [512, 64, 8, 4096]      |    1431.9	  |    1432.1
      [512, 64, 64, 1]        |      28.5	  |      25.7
      [512, 64, 64, 8]        |      33.3	  |      30.0
      [512, 64, 64, 64]       |      39.4	  |      39.5
      [512, 64, 64, 512]      |     234.0	  |     235.0
      [512, 64, 64, 4096]     |    1781.3	  |    1781.2
      [512, 64, 512, 1]       |      48.6	  |      49.3
      [512, 64, 512, 8]       |      68.5	  |      68.4
      [512, 64, 512, 64]      |     123.4	  |     119.1
      [512, 64, 512, 512]     |     635.3	  |     630.3
      [512, 64, 512, 4096]    |    4619.5	  |    4618.3
      [512, 64, 4096, 1]      |     351.6	  |     349.5
      [512, 64, 4096, 8]      |     626.0	  |     590.3
      [512, 64, 4096, 64]     |     872.8	  |     879.1
      [512, 64, 4096, 512]    |    3493.7	  |    3492.3
      [512, 512, 1, 1]        |      27.4	  |      25.4
      [512, 512, 1, 8]        |      53.6	  |      51.7
      [512, 512, 1, 64]       |     243.5	  |     239.2
      [512, 512, 1, 512]      |    2110.1	  |    2072.3
      [512, 512, 1, 4096]     |   16433.4	  |   16125.2
      [512, 512, 8, 1]        |      64.5	  |      64.9
      [512, 512, 8, 8]        |      62.6	  |      62.6
      [512, 512, 8, 64]       |     186.8	  |     187.0
      [512, 512, 8, 512]      |    1374.4	  |    1374.2
      [512, 512, 8, 4096]     |   11040.0	  |   11040.2
      [512, 512, 64, 1]       |     102.0	  |     102.1
      [512, 512, 64, 8]       |      89.7	  |      89.8
      [512, 512, 64, 64]      |     229.1	  |     228.8
      [512, 512, 64, 512]     |    1502.7	  |    1500.9
      [512, 512, 64, 4096]    |   11851.2	  |   11843.4
      [512, 512, 512, 1]      |     339.4	  |     340.0
      [512, 512, 512, 8]      |     538.9	  |     541.3
      [512, 512, 512, 64]     |     727.2	  |     721.5
      [512, 512, 512, 512]    |    3172.9	  |    3167.4
      [512, 512, 512, 4096]   |   24132.4	  |   24217.9
      [512, 512, 4096, 1]     |    2631.0	  |    2630.8
      [512, 512, 4096, 8]     |    4987.9	  |    4868.8
      [512, 512, 4096, 64]    |    5644.2	  |    5387.8
      [512, 512, 4096, 512]   |   17513.5	  |   17515.1
      [512, 4096, 1, 1]       |     163.9	  |     166.8
      [512, 4096, 1, 8]       |     438.4	  |     425.3
      [512, 4096, 1, 64]      |    1897.8	  |    1863.1
      [512, 4096, 1, 512]     |   16822.7	  |   16524.6
      [512, 4096, 8, 1]       |     602.6	  |     602.0
      [512, 4096, 8, 8]       |     462.5	  |     461.3
      [512, 4096, 8, 64]      |    1422.4	  |    1422.5
      [512, 4096, 8, 512]     |   10863.0	  |   10876.8
      [512, 4096, 64, 1]      |     700.9	  |     702.2
      [512, 4096, 64, 8]      |     627.6	  |     627.9
      [512, 4096, 64, 64]     |    1719.6	  |    1719.4
      [512, 4096, 64, 512]    |   11623.8	  |   11622.9
      [512, 4096, 512, 1]     |    2568.9	  |    2568.7
      [512, 4096, 512, 8]     |    4189.9	  |    4195.5
      [512, 4096, 512, 64]    |    6102.8	  |    6071.0
      [512, 4096, 512, 512]   |   24647.3	  |   24627.4
	
Times are in microseconds (us).	s).
	

and A100:

[---------- bench_gemm ----------]	-----]
                          |   old	  new
1 threads: -----------------------	------
      [1, 1, 1]           |   20.8	  20.1
      [1, 1, 8]           |    9.5	   9.0
      [1, 1, 64]          |    9.6	   9.1
      [1, 1, 512]         |    9.1	   9.1
      [1, 1, 4096]        |    9.3	   8.9
      [1, 8, 1]           |    8.8	   8.8
      [1, 8, 8]           |    8.6	   8.8
      [1, 8, 64]          |    8.9	   8.6
      [1, 8, 512]         |    9.0	   9.1
      [1, 8, 4096]        |    9.3	   8.8
      [1, 64, 1]          |    9.1	   8.9
      [1, 64, 8]          |    8.8	   9.0
      [1, 64, 64]         |    8.9	   9.1
      [1, 64, 512]        |   29.7	  30.1
      [1, 64, 4096]       |    9.1	   9.1
      [1, 512, 1]         |   14.3	  14.4
      [1, 512, 8]         |   14.9	  15.0
      [1, 512, 64]        |   15.3	  14.9
      [1, 512, 512]       |    8.8	   8.9
      [1, 512, 4096]      |   15.9	  16.2
      [1, 4096, 1]        |   14.5	  14.8
      [1, 4096, 8]        |   14.6	  14.6
      [1, 4096, 64]       |   14.9	  14.1
      [1, 4096, 512]      |   20.5	  19.7
      [1, 4096, 4096]     |   20.2	  21.1
      [8, 1, 1]           |    9.3	   9.2
      [8, 1, 8]           |    8.9	   8.9
      [8, 1, 64]          |    9.4	   8.7
      [8, 1, 512]         |    9.5	   8.8
      [8, 1, 4096]        |    9.2	   9.0
      [8, 8, 1]           |    9.2	   8.6
      [8, 8, 8]           |    9.5	   9.2
      [8, 8, 64]          |    9.5	   9.3
      [8, 8, 512]         |    9.6	   9.3
      [8, 8, 4096]        |    9.1	   8.9
      [8, 64, 1]          |    9.1	   8.6
      [8, 64, 8]          |    9.2	   9.3
      [8, 64, 64]         |   10.4	   9.5
      [8, 64, 512]        |    9.5	   9.5
      [8, 64, 4096]       |    9.4	   9.5
      [8, 512, 1]         |   14.9	  14.8
      [8, 512, 8]         |    9.0	   9.0
      [8, 512, 64]        |    9.1	   9.0
      [8, 512, 512]       |   11.5	  12.1
      [8, 512, 4096]      |    9.1	   9.0
      [8, 4096, 1]        |   14.9	  14.6
      [8, 4096, 8]        |   14.4	  13.9
      [8, 4096, 64]       |   22.7	  23.6
      [8, 4096, 512]      |   16.7	  17.7
      [8, 4096, 4096]     |   20.2	  20.6
      [64, 1, 1]          |    9.1	   8.6
      [64, 1, 8]          |    9.1	   8.7
      [64, 1, 64]         |    9.1	   8.9
      [64, 1, 512]        |    9.3	   9.5
      [64, 1, 4096]       |    9.1	   9.1
      [64, 8, 1]          |    9.0	   8.8
      [64, 8, 8]          |    9.5	   9.5
      [64, 8, 64]         |    9.6	   9.7
      [64, 8, 512]        |    9.6	   9.1
      [64, 8, 4096]       |   13.2	  12.9
      [64, 64, 1]         |    9.0	   8.8
      [64, 64, 8]         |    9.6	   9.6
      [64, 64, 64]        |    9.3	   9.1
      [64, 64, 512]       |   16.2	  16.1
      [64, 64, 4096]      |   13.2	  13.0
      [64, 512, 1]        |   15.1	  15.2
      [64, 512, 8]        |    9.4	   9.1
      [64, 512, 64]       |   10.0	   9.3
      [64, 512, 512]      |   13.7	  13.7
      [64, 512, 4096]     |   11.8	  12.0
      [64, 4096, 1]       |   14.1	  14.1
      [64, 4096, 8]       |   23.9	  23.7
      [64, 4096, 64]      |   23.2	  23.3
      [64, 4096, 512]     |   23.1	  23.8
      [64, 4096, 4096]    |   26.6	  27.1
      [512, 1, 1]         |    9.5	   9.0
      [512, 1, 8]         |    9.4	   9.2
      [512, 1, 64]        |    9.3	   9.4
      [512, 1, 512]       |    9.3	   8.7
      [512, 1, 4096]      |   10.5	  10.4
      [512, 8, 1]         |    9.1	   8.9
      [512, 8, 8]         |    9.4	   9.1
      [512, 8, 64]        |    9.7	   9.3
      [512, 8, 512]       |   10.3	   9.8
      [512, 8, 4096]      |   12.9	  12.9
      [512, 64, 1]        |    9.1	   9.0
      [512, 64, 8]        |    9.4	   9.2
      [512, 64, 64]       |   16.0	  16.6
      [512, 64, 512]      |   14.7	  14.8
      [512, 64, 4096]     |   12.2	  12.2
      [512, 512, 1]       |    8.9	   9.0
      [512, 512, 8]       |   16.5	  16.6
      [512, 512, 64]      |   13.1	  12.9
      [512, 512, 512]     |   15.9	  15.8
      [512, 512, 4096]    |   23.1	  23.0
      [512, 4096, 1]      |   15.1	  15.3
      [512, 4096, 8]      |   22.1	  22.7
      [512, 4096, 64]     |   22.3	  22.9
      [512, 4096, 512]    |   24.0	  24.0
      [512, 4096, 4096]   |  111.9	 120.7
      [4096, 1, 1]        |    9.4	   8.8
      [4096, 1, 8]        |    9.3	   8.9
      [4096, 1, 64]       |    9.3	   8.9
      [4096, 1, 512]      |   12.1	  12.0
      [4096, 1, 4096]     |   41.3	  41.2
      [4096, 8, 1]        |    8.9	   8.9
      [4096, 8, 8]        |    9.6	  15.0
      [4096, 8, 64]       |   15.9	  16.1
      [4096, 8, 512]      |   15.6	  16.0
      [4096, 8, 4096]     |   31.1	  31.2
      [4096, 64, 1]       |    9.2	   9.2
      [4096, 64, 8]       |   20.3	  21.0
      [4096, 64, 64]      |   15.9	  16.1
      [4096, 64, 512]     |   12.8	  13.0
      [4096, 64, 4096]    |   39.6	  39.6
      [4096, 512, 1]      |   14.6	  16.0
      [4096, 512, 8]      |   20.3	  20.9
      [4096, 512, 64]     |   15.5	  16.5
      [4096, 512, 512]    |   19.5	  19.4
      [4096, 512, 4096]   |  100.5	 100.2
      [4096, 4096, 1]     |   24.3	  24.9
      [4096, 4096, 8]     |   25.0	  24.5
      [4096, 4096, 64]    |   29.7	  30.6
      [4096, 4096, 512]   |  107.9	  94.5
      [4096, 4096, 4096]  |  687.9	 699.1
	
Times are in microseconds (us).	s).
	
[----------- bench_baddbmm ------------]	-----------]
                              |    old            |     new
1 threads: -----------------------------	------------
      [1, 1, 1, 1]            |     18.7	  |     17.4
      [1, 1, 1, 8]            |     18.3	  |     17.4
      [1, 1, 1, 64]           |     18.4	  |     17.0
      [1, 1, 1, 512]          |     18.6	  |     17.4
      [1, 1, 1, 4096]         |     18.6	  |     17.3
      [1, 1, 8, 1]            |     18.5	  |     16.9
      [1, 1, 8, 8]            |     18.4	  |     16.6
      [1, 1, 8, 64]           |     18.4	  |     17.1
      [1, 1, 8, 512]          |     18.5	  |     17.4
      [1, 1, 8, 4096]         |     18.7	  |     17.9
      [1, 1, 64, 1]           |     18.3	  |     16.9
      [1, 1, 64, 8]           |     18.3	  |     16.9
      [1, 1, 64, 64]          |     18.1	  |     16.7
      [1, 1, 64, 512]         |     38.9	  |     38.5
      [1, 1, 64, 4096]        |     18.6	  |     17.6
      [1, 1, 512, 1]          |     24.3	  |     24.6
      [1, 1, 512, 8]          |     25.1	  |     24.4
      [1, 1, 512, 64]         |     25.8	  |     23.7
      [1, 1, 512, 512]        |     18.5	  |     16.9
      [1, 1, 512, 4096]       |     26.6	  |     26.5
      [1, 1, 4096, 1]         |     26.2	  |     23.7
      [1, 1, 4096, 8]         |     24.4	  |     23.5
      [1, 1, 4096, 64]        |     25.7	  |     23.9
      [1, 1, 4096, 512]       |     30.4	  |     30.7
      [1, 1, 4096, 4096]      |     26.0	  |     26.0
      [1, 8, 1, 1]            |     18.6	  |     17.4
      [1, 8, 1, 8]            |     18.4	  |     16.8
      [1, 8, 1, 64]           |     18.6	  |     17.2
      [1, 8, 1, 512]          |     18.6	  |     17.0
      [1, 8, 1, 4096]         |     18.8	  |     17.1
      [1, 8, 8, 1]            |     18.4	  |     16.8
      [1, 8, 8, 8]            |     18.6	  |     17.3
      [1, 8, 8, 64]           |     18.5	  |     17.2
      [1, 8, 8, 512]          |     18.9	  |     17.5
      [1, 8, 8, 4096]         |     19.1	  |     17.6
      [1, 8, 64, 1]           |     19.2	  |     17.5
      [1, 8, 64, 8]           |     18.8	  |     17.4
      [1, 8, 64, 64]          |     18.7	  |     17.6
      [1, 8, 64, 512]         |     19.0	  |     17.7
      [1, 8, 64, 4096]        |     18.7	  |     17.6
      [1, 8, 512, 1]          |     24.6	  |     23.6
      [1, 8, 512, 8]          |     19.0	  |     17.7
      [1, 8, 512, 64]         |     18.9	  |     17.8
      [1, 8, 512, 512]        |     21.8	  |     21.9
      [1, 8, 512, 4096]       |     19.2	  |     17.7
      [1, 8, 4096, 1]         |     24.7	  |     23.8
      [1, 8, 4096, 8]         |     26.2	  |     24.6
      [1, 8, 4096, 64]        |     33.8	  |     34.7
      [1, 8, 4096, 512]       |     29.3	  |     29.4
      [1, 8, 4096, 4096]      |     28.5	  |     26.2
      [1, 64, 1, 1]           |     18.6	  |     17.2
      [1, 64, 1, 8]           |     20.1	  |     17.2
      [1, 64, 1, 64]          |     18.9	  |     17.2
      [1, 64, 1, 512]         |     18.6	  |     17.3
      [1, 64, 1, 4096]        |     19.0	  |     17.2
      [1, 64, 8, 1]           |     18.7	  |     17.2
      [1, 64, 8, 8]           |     19.0	  |     17.2
      [1, 64, 8, 64]          |     19.6	  |     17.5
      [1, 64, 8, 512]         |     20.2	  |     17.7
      [1, 64, 8, 4096]        |     23.4	  |     22.6
      [1, 64, 64, 1]          |     18.5	  |     17.1
      [1, 64, 64, 8]          |     18.6	  |     17.5
      [1, 64, 64, 64]         |     20.4	  |     17.4
      [1, 64, 64, 512]        |     26.7	  |     26.4
      [1, 64, 64, 4096]       |     23.1	  |     23.0
      [1, 64, 512, 1]         |     25.1	  |     23.7
      [1, 64, 512, 8]         |     19.1	  |     17.8
      [1, 64, 512, 64]        |     18.9	  |     18.0
      [1, 64, 512, 512]       |     24.0	  |     23.2
      [1, 64, 512, 4096]      |     22.1	  |     21.3
      [1, 64, 4096, 1]        |     25.7	  |     24.3
      [1, 64, 4096, 8]        |     34.4	  |     33.8
      [1, 64, 4096, 64]       |     34.2	  |     34.1
      [1, 64, 4096, 512]      |     34.4	  |     34.1
      [1, 64, 4096, 4096]     |     31.2	  |     30.4
      [1, 512, 1, 1]          |     18.4	  |     16.8
      [1, 512, 1, 8]          |     18.4	  |     17.0
      [1, 512, 1, 64]         |     18.8	  |     17.2
      [1, 512, 1, 512]        |     18.7	  |     17.1
      [1, 512, 1, 4096]       |     18.7	  |     18.3
      [1, 512, 8, 1]          |     18.4	  |     17.1
      [1, 512, 8, 8]          |     18.9	  |     17.4
      [1, 512, 8, 64]         |     20.4	  |     17.7
      [1, 512, 8, 512]        |     19.7	  |     18.8
      [1, 512, 8, 4096]       |     23.5	  |     22.4
      [1, 512, 64, 1]         |     18.5	  |     17.1
      [1, 512, 64, 8]         |     19.4	  |     17.5
      [1, 512, 64, 64]        |     27.8	  |     26.5
      [1, 512, 64, 512]       |     24.6	  |     24.8
      [1, 512, 64, 4096]      |     22.0	  |     21.2
      [1, 512, 512, 1]        |     18.5	  |     17.2
      [1, 512, 512, 8]        |     27.4	  |     27.3
      [1, 512, 512, 64]       |     23.0	  |     22.5
      [1, 512, 512, 512]      |     26.1	  |     26.2
      [1, 512, 512, 4096]     |     31.6	  |     31.3
      [1, 512, 4096, 1]       |     25.3	  |     24.1
      [1, 512, 4096, 8]       |     33.4	  |     33.9
      [1, 512, 4096, 64]      |     33.0	  |     33.3
      [1, 512, 4096, 512]     |     33.6	  |     31.3
      [1, 512, 4096, 4096]    |    114.8	  |    113.3
      [1, 4096, 1, 1]         |     18.7	  |     16.9
      [1, 4096, 1, 8]         |     18.7	  |     17.2
      [1, 4096, 1, 64]        |     18.7	  |     16.9
      [1, 4096, 1, 512]       |     20.0	  |     20.0
      [1, 4096, 1, 4096]      |    105.8	  |    105.4
      [1, 4096, 8, 1]         |     18.4	  |     16.6
      [1, 4096, 8, 8]         |     19.2	  |     18.0
      [1, 4096, 8, 64]        |     26.4	  |     25.6
      [1, 4096, 8, 512]       |     26.0	  |     25.2
      [1, 4096, 8, 4096]      |     93.7	  |     93.6
      [1, 4096, 64, 1]        |     18.9	  |     17.0
      [1, 4096, 64, 8]        |     32.0	  |     31.5
      [1, 4096, 64, 64]       |     26.6	  |     25.9
      [1, 4096, 64, 512]      |     23.2	  |     22.4
      [1, 4096, 64, 4096]     |    100.7	  |    100.5
      [1, 4096, 512, 1]       |     26.2	  |     24.3
      [1, 4096, 512, 8]       |     32.4	  |     31.5
      [1, 4096, 512, 64]      |     27.4	  |     25.5
      [1, 4096, 512, 512]     |     27.2	  |     26.9
      [1, 4096, 512, 4096]    |    156.7	  |    155.8
      [1, 4096, 4096, 1]      |     28.2	  |     27.6
      [1, 4096, 4096, 8]      |     28.0	  |     29.5
      [1, 4096, 4096, 64]     |     33.8	  |     32.8
      [1, 4096, 4096, 512]    |    115.6	  |    115.5
      [1, 4096, 4096, 4096]   |    727.2	  |    729.8
      [8, 1, 1, 1]            |     19.0	  |     17.6
      [8, 1, 1, 8]            |     19.1	  |     17.9
      [8, 1, 1, 64]           |     21.1	  |     17.1
      [8, 1, 1, 512]          |     18.5	  |     17.1
      [8, 1, 1, 4096]         |     18.3	  |     17.4
      [8, 1, 8, 1]            |     19.4	  |     17.7
      [8, 1, 8, 8]            |     18.6	  |     17.1
      [8, 1, 8, 64]           |     18.4	  |     16.9
      [8, 1, 8, 512]          |     18.8	  |     17.4
      [8, 1, 8, 4096]         |     18.6	  |     17.1
      [8, 1, 64, 1]           |     19.1	  |     17.5
      [8, 1, 64, 8]           |     18.3	  |     17.0
      [8, 1, 64, 64]          |     19.1	  |     17.6
      [8, 1, 64, 512]         |     18.6	  |     17.9
      [8, 1, 64, 4096]        |     18.9	  |     18.0
      [8, 1, 512, 1]          |     32.7	  |     32.7
      [8, 1, 512, 8]          |     18.9	  |     17.3
      [8, 1, 512, 64]         |     18.6	  |     17.1
      [8, 1, 512, 512]        |     18.6	  |     17.3
      [8, 1, 512, 4096]       |     36.6	  |     36.7
      [8, 1, 4096, 1]         |    155.4	  |    155.3
      [8, 1, 4096, 8]         |     18.4	  |     16.8
      [8, 1, 4096, 64]        |     33.2	  |     33.5
      [8, 1, 4096, 512]       |     94.2	  |     94.5
      [8, 1, 4096, 4096]      |    259.7	  |    261.1
      [8, 8, 1, 1]            |     19.3	  |     17.8
      [8, 8, 1, 8]            |     19.2	  |     18.5
      [8, 8, 1, 64]           |     19.4	  |     17.7
      [8, 8, 1, 512]          |     18.4	  |     17.4
      [8, 8, 1, 4096]         |     18.7	  |     17.7
      [8, 8, 8, 1]            |     19.2	  |     17.9
      [8, 8, 8, 8]            |     18.5	  |     17.1
      [8, 8, 8, 64]           |     18.7	  |     17.2
      [8, 8, 8, 512]          |     19.1	  |     17.6
      [8, 8, 8, 4096]         |     19.2	  |     17.6
      [8, 8, 64, 1]           |     19.4	  |     17.7
      [8, 8, 64, 8]           |     18.5	  |     17.3
      [8, 8, 64, 64]          |     19.1	  |     17.6
      [8, 8, 64, 512]         |     19.5	  |     17.8
      [8, 8, 64, 4096]        |     19.6	  |     18.0
      [8, 8, 512, 1]          |     32.7	  |     32.9
      [8, 8, 512, 8]          |     19.0	  |     17.5
      [8, 8, 512, 64]         |     18.8	  |     17.2
      [8, 8, 512, 512]        |     24.3	  |     22.6
      [8, 8, 512, 4096]       |     30.9	  |     29.6
      [8, 8, 4096, 1]         |    156.6	  |    156.7
      [8, 8, 4096, 8]         |     33.2	  |     33.4
      [8, 8, 4096, 64]        |     22.3	  |     22.4
      [8, 8, 4096, 512]       |     27.0	  |     26.6
      [8, 8, 4096, 4096]      |    183.0	  |    182.9
      [8, 64, 1, 1]           |     19.2	  |     17.7
      [8, 64, 1, 8]           |     19.3	  |     17.8
      [8, 64, 1, 64]          |     19.2	  |     17.9
      [8, 64, 1, 512]         |     18.8	  |     17.2
      [8, 64, 1, 4096]        |     18.5	  |     17.3
      [8, 64, 8, 1]           |     21.0	  |     17.7
      [8, 64, 8, 8]           |     18.9	  |     17.2
      [8, 64, 8, 64]          |     18.8	  |     17.4
      [8, 64, 8, 512]         |     24.0	  |     22.5
      [8, 64, 8, 4096]        |     24.3	  |     22.5
      [8, 64, 64, 1]          |     19.5	  |     17.7
      [8, 64, 64, 8]          |     18.9	  |     17.5
      [8, 64, 64, 64]         |     18.9	  |     17.4
      [8, 64, 64, 512]        |     24.3	  |     22.5
      [8, 64, 64, 4096]       |     25.8	  |     22.6
      [8, 64, 512, 1]         |     38.1	  |     38.2
      [8, 64, 512, 8]         |     19.3	  |     17.7
      [8, 64, 512, 64]        |     18.9	  |     17.4
      [8, 64, 512, 512]       |     24.4	  |     22.5
      [8, 64, 512, 4096]      |     45.0	  |     45.2
      [8, 64, 4096, 1]        |    201.1	  |    201.2
      [8, 64, 4096, 8]        |     29.6	  |     29.9
      [8, 64, 4096, 64]       |     22.6	  |     22.5
      [8, 64, 4096, 512]      |     33.8	  |     33.8
      [8, 64, 4096, 4096]     |    192.2	  |    193.5
      [8, 512, 1, 1]          |     18.2	  |     17.5
      [8, 512, 1, 8]          |     19.0	  |     17.4
      [8, 512, 1, 64]         |     18.9	  |     17.5
      [8, 512, 1, 512]        |     18.7	  |     17.3
      [8, 512, 1, 4096]       |    183.9	  |    180.9
      [8, 512, 8, 1]          |     18.8	  |     17.6
      [8, 512, 8, 8]          |     24.1	  |     23.0
      [8, 512, 8, 64]         |     24.0	  |     22.8
      [8, 512, 8, 512]        |     24.2	  |     22.9
      [8, 512, 8, 4096]       |     99.0	  |     99.0
      [8, 512, 64, 1]         |     18.6	  |     17.0
      [8, 512, 64, 8]         |     23.8	  |     22.4
      [8, 512, 64, 64]        |     24.1	  |     22.5
      [8, 512, 64, 512]       |     24.5	  |     22.6
      [8, 512, 64, 4096]      |    105.6	  |    105.4
      [8, 512, 512, 1]        |     19.6	  |     18.2
      [8, 512, 512, 8]        |     24.6	  |     22.5
      [8, 512, 512, 64]       |     24.2	  |     22.5
      [8, 512, 512, 512]      |     28.7	  |     28.7
      [8, 512, 512, 4096]     |    161.1	  |    160.9
      [8, 512, 4096, 1]       |     37.4	  |     36.2
      [8, 512, 4096, 8]       |     27.8	  |     27.5
      [8, 512, 4096, 64]      |     34.2	  |     34.2
      [8, 512, 4096, 512]     |    134.3	  |    134.7
      [8, 512, 4096, 4096]    |    906.8	  |    900.6
      [8, 4096, 1, 1]         |     18.5	  |     17.4
      [8, 4096, 1, 8]         |     18.7	  |     17.4
      [8, 4096, 1, 64]        |     18.6	  |     17.2
      [8, 4096, 1, 512]       |    177.1	  |    173.4
      [8, 4096, 1, 4096]      |   1407.9	  |   1382.8
      [8, 4096, 8, 1]         |     19.0	  |     17.7
      [8, 4096, 8, 8]         |     24.6	  |     22.7
      [8, 4096, 8, 64]        |     24.0	  |     22.5
      [8, 4096, 8, 512]       |     97.8	  |     97.8
      [8, 4096, 8, 4096]      |    756.3	  |    757.9
      [8, 4096, 64, 1]        |     18.9	  |     16.9
      [8, 4096, 64, 8]        |     24.4	  |     22.4
      [8, 4096, 64, 64]       |     23.8	  |     22.3
      [8, 4096, 64, 512]      |    104.6	  |    104.6
      [8, 4096, 64, 4096]     |    776.5	  |    777.2
      [8, 4096, 512, 1]       |     31.3	  |     30.7
      [8, 4096, 512, 8]       |     33.4	  |     31.9
      [8, 4096, 512, 64]      |     48.4	  |     48.2
      [8, 4096, 512, 512]     |    161.4	  |    163.0
      [8, 4096, 512, 4096]    |   1155.2	  |   1203.3
      [8, 4096, 4096, 1]      |    194.3	  |    194.7
      [8, 4096, 4096, 8]      |    178.9	  |    178.1
      [8, 4096, 4096, 64]     |    199.6	  |    204.2
      [8, 4096, 4096, 512]    |    846.6	  |    843.4
      [8, 4096, 4096, 4096]   |   6927.5	  |   6793.9
      [64, 1, 1, 1]           |     18.3	  |     17.0
      [64, 1, 1, 8]           |     18.1	  |     16.8
      [64, 1, 1, 64]          |     18.5	  |     17.1
      [64, 1, 1, 512]         |     18.9	  |     17.5
      [64, 1, 1, 4096]        |     18.4	  |     18.5
      [64, 1, 8, 1]           |     18.0	  |     16.8
      [64, 1, 8, 8]           |     18.5	  |     17.0
      [64, 1, 8, 64]          |     18.4	  |     17.0
      [64, 1, 8, 512]         |     19.5	  |     17.4
      [64, 1, 8, 4096]        |     20.7	  |     20.5
      [64, 1, 64, 1]          |     18.1	  |     16.9
      [64, 1, 64, 8]          |     18.3	  |     17.1
      [64, 1, 64, 64]         |     18.3	  |     17.2
      [64, 1, 64, 512]        |     19.0	  |     17.5
      [64, 1, 64, 4096]       |     47.1	  |     46.5
      [64, 1, 512, 1]         |     18.3	  |     17.3
      [64, 1, 512, 8]         |     18.3	  |     17.3
      [64, 1, 512, 64]        |     18.5	  |     17.0
      [64, 1, 512, 512]       |     37.6	  |     37.4
      [64, 1, 512, 4096]      |    250.4	  |    251.1
      [64, 1, 4096, 1]        |     18.6	  |     17.0
      [64, 1, 4096, 8]        |     18.4	  |     18.2
      [64, 1, 4096, 64]       |     34.7	  |     34.3
      [64, 1, 4096, 512]      |    260.4	  |    259.6
      [64, 1, 4096, 4096]     |   1986.2	  |   1985.9
      [64, 8, 1, 1]           |     19.0	  |     17.7
      [64, 8, 1, 8]           |     19.1	  |     17.7
      [64, 8, 1, 64]          |     19.2	  |     17.6
      [64, 8, 1, 512]         |     18.7	  |     18.4
      [64, 8, 1, 4096]        |     23.5	  |     23.9
      [64, 8, 8, 1]           |     18.6	  |     17.1
      [64, 8, 8, 8]           |     19.1	  |     17.5
      [64, 8, 8, 64]          |     19.2	  |     17.7
      [64, 8, 8, 512]         |     19.7	  |     17.7
      [64, 8, 8, 4096]        |     35.7	  |     36.0
      [64, 8, 64, 1]          |     18.5	  |     16.9
      [64, 8, 64, 8]          |     19.0	  |     17.7
      [64, 8, 64, 64]         |     18.8	  |     17.4
      [64, 8, 64, 512]        |     19.3	  |     17.9
      [64, 8, 64, 4096]       |     82.4	  |     82.6
      [64, 8, 512, 1]         |     18.3	  |     17.1
      [64, 8, 512, 8]         |     18.7	  |     17.5
      [64, 8, 512, 64]        |     19.0	  |     17.9
      [64, 8, 512, 512]       |     32.0	  |     32.9
      [64, 8, 512, 4096]      |    196.5	  |    196.5
      [64, 8, 4096, 1]        |     18.1	  |     17.2
      [64, 8, 4096, 8]        |     22.4	  |     22.6
      [64, 8, 4096, 64]       |     68.0	  |     67.8
      [64, 8, 4096, 512]      |    183.7	  |    184.2
      [64, 8, 4096, 4096]     |   1300.7	  |   1300.5
      [64, 64, 1, 1]          |     18.8	  |     17.4
      [64, 64, 1, 8]          |     19.2	  |     17.5
      [64, 64, 1, 64]         |     19.5	  |     17.6
      [64, 64, 1, 512]        |     20.3	  |     17.3
      [64, 64, 1, 4096]       |    185.1	  |    180.8
      [64, 64, 8, 1]          |     18.3	  |     17.0
      [64, 64, 8, 8]          |     18.8	  |     17.5
      [64, 64, 8, 64]         |     19.2	  |     17.9
      [64, 64, 8, 512]        |     24.1	  |     22.5
      [64, 64, 8, 4096]       |    103.2	  |    102.9
      [64, 64, 64, 1]         |     18.5	  |     17.0
      [64, 64, 64, 8]         |     19.0	  |     17.4
      [64, 64, 64, 64]        |     19.2	  |     17.6
      [64, 64, 64, 512]       |     24.4	  |     22.6
      [64, 64, 64, 4096]      |    124.0	  |    124.0
      [64, 64, 512, 1]        |     18.6	  |     16.9
      [64, 64, 512, 8]        |     18.9	  |     17.4
      [64, 64, 512, 64]       |     19.2	  |     17.9
      [64, 64, 512, 512]      |     47.5	  |     47.5
      [64, 64, 512, 4096]     |    279.5	  |    279.3
      [64, 64, 4096, 1]       |     31.6	  |     30.6
      [64, 64, 4096, 8]       |     32.3	  |     31.7
      [64, 64, 4096, 64]      |     87.8	  |     87.7
      [64, 64, 4096, 512]     |    209.0	  |    208.5
      [64, 64, 4096, 4096]    |   1492.2	  |   1486.9
      [64, 512, 1, 1]         |     18.6	  |     17.5
      [64, 512, 1, 8]         |     18.7	  |     17.2
      [64, 512, 1, 64]        |     18.6	  |     17.2
      [64, 512, 1, 512]       |    177.3	  |    173.4
      [64, 512, 1, 4096]      |   1404.7	  |   1379.4
      [64, 512, 8, 1]         |     18.9	  |     17.3
      [64, 512, 8, 8]         |     24.5	  |     22.9
      [64, 512, 8, 64]        |     23.9	  |     22.9
      [64, 512, 8, 512]       |     98.3	  |     98.4
      [64, 512, 8, 4096]      |    758.2	  |    758.1
      [64, 512, 64, 1]        |     19.0	  |     17.0
      [64, 512, 64, 8]        |     24.5	  |     22.4
      [64, 512, 64, 64]       |     24.1	  |     22.4
      [64, 512, 64, 512]      |    108.2	  |    108.0
      [64, 512, 64, 4096]     |    785.5	  |    786.1
      [64, 512, 512, 1]       |     32.3	  |     31.5
      [64, 512, 512, 8]       |     34.5	  |     35.4
      [64, 512, 512, 64]      |     50.4	  |     50.5
      [64, 512, 512, 512]     |    170.7	  |    167.9
      [64, 512, 512, 4096]    |   1228.5	  |   1225.8
      [64, 512, 4096, 1]      |    201.8	  |    202.4
      [64, 512, 4096, 8]      |    180.4	  |    180.0
      [64, 512, 4096, 64]     |    218.4	  |    216.3
      [64, 512, 4096, 512]    |    921.0	  |    920.9
      [64, 512, 4096, 4096]   |   7440.4	  |   7470.5
      [64, 4096, 1, 1]        |     18.6	  |     18.5
      [64, 4096, 1, 8]        |     38.6	  |     35.2
      [64, 4096, 1, 64]       |    156.5	  |    151.5
      [64, 4096, 1, 512]      |   1357.6	  |   1324.9
      [64, 4096, 1, 4096]     |  11298.9	  |  11102.3
      [64, 4096, 8, 1]        |     54.1	  |     54.3
      [64, 4096, 8, 8]        |     41.6	  |     41.5
      [64, 4096, 8, 64]       |    101.0	  |    101.0
      [64, 4096, 8, 512]      |    748.2	  |    747.7
      [64, 4096, 8, 4096]     |   5952.2	  |   5981.3
      [64, 4096, 64, 1]       |     79.1	  |     78.9
      [64, 4096, 64, 8]       |     65.4	  |     65.4
      [64, 4096, 64, 64]      |    125.1	  |    125.0
      [64, 4096, 64, 512]     |    778.7	  |    777.7
      [64, 4096, 64, 4096]    |   6133.5	  |   6139.5
      [64, 4096, 512, 1]      |    212.0	  |    212.7
      [64, 4096, 512, 8]      |    223.5	  |    224.6
      [64, 4096, 512, 64]     |    291.1	  |    290.2
      [64, 4096, 512, 512]    |   1234.5	  |   1269.3
      [64, 4096, 512, 4096]   |   9472.6	  |   9478.0
      [64, 4096, 4096, 1]     |   1434.1	  |   1441.5
      [64, 4096, 4096, 8]     |   1411.6	  |   1411.9
      [64, 4096, 4096, 64]    |   1647.0	  |   1657.0
      [64, 4096, 4096, 512]   |   6705.9	  |   6536.0
      [64, 4096, 4096, 4096]  |  59374.6	  |  58433.4
      [512, 1, 1, 1]          |     18.2	  |     17.1
      [512, 1, 1, 8]          |     18.5	  |     17.0
      [512, 1, 1, 64]         |     19.4	  |     17.2
      [512, 1, 1, 512]        |     18.7	  |     18.7
      [512, 1, 1, 4096]       |     95.2	  |    103.8
      [512, 1, 8, 1]          |     18.5	  |     16.7
      [512, 1, 8, 8]          |     18.5	  |     16.9
      [512, 1, 8, 64]         |     18.3	  |     17.1
      [512, 1, 8, 512]        |     30.8	  |     30.7
      [512, 1, 8, 4096]       |    172.7	  |    173.1
      [512, 1, 64, 1]         |     18.5	  |     17.0
      [512, 1, 64, 8]         |     19.1	  |     16.9
      [512, 1, 64, 64]        |     19.3	  |     16.8
      [512, 1, 64, 512]       |     56.2	  |     56.1
      [512, 1, 64, 4096]      |    331.8	  |    331.4
      [512, 1, 512, 1]        |     18.5	  |     16.9
      [512, 1, 512, 8]        |     18.6	  |     17.7
      [512, 1, 512, 64]       |     28.1	  |     32.9
      [512, 1, 512, 512]      |    255.8	  |    256.5
      [512, 1, 512, 4096]     |   1939.5	  |   1941.0
      [512, 1, 4096, 1]       |     18.5	  |     16.6
      [512, 1, 4096, 8]       |     32.7	  |     35.1
      [512, 1, 4096, 64]      |    176.9	  |    176.9
      [512, 1, 4096, 512]     |   2012.7	  |   2003.9
      [512, 8, 1, 1]          |     19.4	  |     17.6
      [512, 8, 1, 8]          |     19.1	  |     17.4
      [512, 8, 1, 64]         |     19.3	  |     18.4
      [512, 8, 1, 512]        |     23.5	  |     23.6
      [512, 8, 1, 4096]       |    225.5	  |    226.2
      [512, 8, 8, 1]          |     18.4	  |     18.0
      [512, 8, 8, 8]          |     19.8	  |     17.9
      [512, 8, 8, 64]         |     19.2	  |     18.4
      [512, 8, 8, 512]        |     35.1	  |     35.4
      [512, 8, 8, 4096]       |    345.6	  |    346.9
      [512, 8, 64, 1]         |     18.4	  |     17.0
      [512, 8, 64, 8]         |     20.3	  |     18.1
      [512, 8, 64, 64]        |     19.6	  |     17.7
      [512, 8, 64, 512]       |     81.1	  |     81.5
      [512, 8, 64, 4096]      |    632.8	  |    639.8
      [512, 8, 512, 1]        |     18.4	  |     17.4
      [512, 8, 512, 8]        |     19.1	  |     17.6
      [512, 8, 512, 64]       |     36.3	  |     36.2
      [512, 8, 512, 512]      |    200.7	  |    201.0
      [512, 8, 512, 4096]     |   1425.9	  |   1425.7
      [512, 8, 4096, 1]       |     34.6	  |     33.3
      [512, 8, 4096, 8]       |    106.5	  |    106.2
      [512, 8, 4096, 64]      |    239.2	  |    237.3
      [512, 8, 4096, 512]     |   1343.7	  |   1343.0
      [512, 64, 1, 1]         |     19.1	  |     17.8
      [512, 64, 1, 8]         |     19.5	  |     18.3
      [512, 64, 1, 64]        |     19.2	  |     19.0
      [512, 64, 1, 512]       |    177.0	  |    173.3
      [512, 64, 1, 4096]      |   1403.0	  |   1376.7
      [512, 64, 8, 1]         |     18.6	  |     17.1
      [512, 64, 8, 8]         |     18.6	  |     17.4
      [512, 64, 8, 64]        |     18.4	  |     17.2
      [512, 64, 8, 512]       |    102.0	  |    101.8
      [512, 64, 8, 4096]      |    773.3	  |    772.6
      [512, 64, 64, 1]        |     18.5	  |     17.0
      [512, 64, 64, 8]        |     18.8	  |     17.3
      [512, 64, 64, 64]       |     18.4	  |     17.5
      [512, 64, 64, 512]      |    125.5	  |    125.2
      [512, 64, 64, 4096]     |    890.0	  |    890.7
      [512, 64, 512, 1]       |     29.0	  |     28.8
      [512, 64, 512, 8]       |     40.3	  |     40.5
      [512, 64, 512, 64]      |     62.9	  |     62.7
      [512, 64, 512, 512]     |    299.4	  |    299.6
      [512, 64, 512, 4096]    |   2118.3	  |   2117.9
      [512, 64, 4096, 1]      |    187.5	  |    187.6
      [512, 64, 4096, 8]      |    220.5	  |    220.3
      [512, 64, 4096, 64]     |    346.7	  |    346.7
      [512, 64, 4096, 512]    |   1599.2	  |   1596.2
      [512, 512, 1, 1]        |     20.1	  |     18.7
      [512, 512, 1, 8]        |     38.6	  |     35.3
      [512, 512, 1, 64]       |    156.5	  |    151.5
      [512, 512, 1, 512]      |   1357.0	  |   1325.1
      [512, 512, 1, 4096]     |  11278.8	  |  11072.9
      [512, 512, 8, 1]        |     57.5	  |     57.4
      [512, 512, 8, 8]        |     40.6	  |     40.6
      [512, 512, 8, 64]       |    101.4	  |    101.2
      [512, 512, 8, 512]      |    749.5	  |    749.2
      [512, 512, 8, 4096]     |   5960.5	  |   5958.5
      [512, 512, 64, 1]       |     83.2	  |     83.0
      [512, 512, 64, 8]       |     65.9	  |     65.3
      [512, 512, 64, 64]      |    126.1	  |    126.0
      [512, 512, 64, 512]     |    791.5	  |    791.5
      [512, 512, 64, 4096]    |   6206.3	  |   6211.8
      [512, 512, 512, 1]      |    222.7	  |    223.5
      [512, 512, 512, 8]      |    223.0	  |    222.3
      [512, 512, 512, 64]     |    304.6	  |    305.2
      [512, 512, 512, 512]    |   1309.7	  |   1321.7
      [512, 512, 512, 4096]   |  10268.2	  |  10121.3
      [512, 512, 4096, 1]     |   1504.0	  |   1504.3
      [512, 512, 4096, 8]     |   1426.6	  |   1424.1
      [512, 512, 4096, 64]    |   1745.4	  |   1744.1
      [512, 512, 4096, 512]   |   7469.2	  |   7473.3
      [512, 4096, 1, 1]       |     95.9	  |    105.0
      [512, 4096, 1, 8]       |    342.9	  |    325.5
      [512, 4096, 1, 64]      |   1181.4	  |   1140.8
      [512, 4096, 1, 512]     |  10907.3	  |  10645.3
      [512, 4096, 8, 1]       |    512.7	  |    514.0
      [512, 4096, 8, 8]       |    369.8	  |    370.4
      [512, 4096, 8, 64]      |    752.5	  |    752.4
      [512, 4096, 8, 512]     |   5867.5	  |   5867.0
      [512, 4096, 64, 1]      |    601.2	  |    601.1
      [512, 4096, 64, 8]      |    441.0	  |    440.6
      [512, 4096, 64, 64]     |    872.8	  |    872.6
      [512, 4096, 64, 512]    |   6140.8	  |   6140.6
      [512, 4096, 512, 1]     |   1655.5	  |   1657.2
      [512, 4096, 512, 8]     |   1612.1	  |   1611.7
      [512, 4096, 512, 64]    |   2212.1	  |   2213.7
      [512, 4096, 512, 512]   |  10142.6	  |   9967.6
	
Times are in microseconds (us).	s).

facebook-github-bot pushed a commit that referenced this pull request Nov 10, 2021
…67946)

Summary:
#67578 disabled reduced precision reductions for FP16 GEMMs. After benchmarking, we've found that this has substantial performance impacts for common GEMM shapes (e.g., those found in popular instantiations of multiheaded-attention) on architectures such as Volta. As these performance regressions may come as a surprise to current users, this PR adds a toggle to disable reduced precision reductions
`torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = `
rather than making it the default behavior.

CC ngimel ptrblck
stas00 Note that the behavior after the previous PR can be replicated with
`torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False`

Pull Request resolved: #67946

Reviewed By: zou3519

Differential Revision: D32289896

Pulled By: ngimel

fbshipit-source-id: a1ea2918b77e27a7d9b391e030417802a0174abe
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants