eqy commented Nov 6, 2021

#67578 disabled reduced precision reductions for FP16 GEMMs. After benchmarking, we found that this causes substantial performance regressions for common GEMM shapes (e.g., those found in popular instantiations of multi-headed attention) on architectures such as Volta. As these regressions may come as a surprise to current users, this PR adds a toggle to disable reduced precision reductions,
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction,
rather than making that the default behavior.
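
To make the term "reduced precision reduction" concrete, here is a toy sketch in pure Python of why the precision of the running sum matters. It is deliberately exaggerated and is not a model of the cuBLAS kernels that this flag selects; the helper names (to_fp16, dot_fp16_accumulate, dot_wide_accumulate) are made up for illustration.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def dot_fp16_accumulate(a, b):
    """Dot product whose running sum is rounded to FP16 after every add."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = to_fp16(acc + to_fp16(x) * to_fp16(y))
    return acc


def dot_wide_accumulate(a, b):
    """Dot product of FP16 inputs accumulated in a wider type (Python float)."""
    acc = 0.0
    for x, y in zip(a, b):
        acc += to_fp16(x) * to_fp16(y)
    return acc


ones = [1.0] * 4096
narrow = dot_fp16_accumulate(ones, ones)  # stalls once the FP16 spacing exceeds 1
wide = dot_wide_accumulate(ones, ones)    # exact for this input
print(narrow, wide)
```

In this toy the FP16 accumulator stops growing at 2048, because the spacing between adjacent FP16 values there is 2 and adding 1 rounds back down. Real GEMM kernels are far less extreme, but the same trade-off between speed and accumulation accuracy is what the flag exposes.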

CC @ngimel @ptrblck
@stas00 Note that the behavior after the previous PR can be replicated with
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
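
For code that only needs the stricter behavior around a few matmuls, one option is to scope the flag instead of setting it globally. The following is a minimal sketch, assuming a PyTorch build that includes this PR's flag; the context manager fp16_reduced_precision_reduction and its matmul_backend parameter are hypothetical helpers, not part of PyTorch.

```python
import contextlib


@contextlib.contextmanager
def fp16_reduced_precision_reduction(enabled, matmul_backend=None):
    """Temporarily set allow_fp16_reduced_precision_reduction, then restore it.

    matmul_backend defaults to torch.backends.cuda.matmul. The parameter is
    hypothetical and exists only so the helper can be exercised without a
    PyTorch install.
    """
    if matmul_backend is None:
        # Imported lazily; assumes a build that exposes the flag from this PR.
        import torch
        matmul_backend = torch.backends.cuda.matmul
    previous = matmul_backend.allow_fp16_reduced_precision_reduction
    matmul_backend.allow_fp16_reduced_precision_reduction = enabled
    try:
        yield
    finally:
        # Restore the caller's setting even if the body raised.
        matmul_backend.allow_fp16_reduced_precision_reduction = previous


# Possible usage with FP16 CUDA tensors a and b (not executed here):
#     with fp16_reduced_precision_reduction(False):
#         out = a @ b
```

Restoring the previous value in the finally block keeps the change local, so library code can opt out of reduced precision reductions without altering the setting its caller chose.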

facebook-github-bot commented Nov 6, 2021

💊 CI failures summary and remediations

As of commit b852f29 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-scanned failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Nov 09 17:58:00 ERROR [0.047s]: test_advancedin..._cpu_devices_xla (__main__.TestDevicePrecisionXLA)
Nov 09 17:58:00   test_from_sequence_xla_int16 (__main__.TestDevicePrecisionXLA) ... ok (0.014s)
Nov 09 17:58:00   test_from_sequence_xla_int32 (__main__.TestDevicePrecisionXLA) ... ok (0.013s)
Nov 09 17:58:00   test_from_sequence_xla_int64 (__main__.TestDevicePrecisionXLA) ... ok (0.013s)
Nov 09 17:58:00   test_from_sequence_xla_int8 (__main__.TestDevicePrecisionXLA) ... ok (0.013s)
Nov 09 17:58:00   test_from_sequence_xla_uint8 (__main__.TestDevicePrecisionXLA) ... ok (0.013s)
Nov 09 17:58:00   test_index_add_bfloat16_xla (__main__.TestDevicePrecisionXLA) ... skip (0.002s)
Nov 09 17:58:00   test_multidevice_serialization_xla (__main__.TestDevicePrecisionXLA) ... skip (0.001s)
Nov 09 17:58:00   test_type_conversions_same_device_xla (__main__.TestDevicePrecisionXLA) ... skip (0.001s)
Nov 09 17:58:00 
Nov 09 17:58:00 ======================================================================
Nov 09 17:58:00 ERROR [0.047s]: test_advancedindex_mixed_cpu_devices_xla (__main__.TestDevicePrecisionXLA)
Nov 09 17:58:00 ----------------------------------------------------------------------
Nov 09 17:58:00 Traceback (most recent call last):
Nov 09 17:58:00   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 376, in instantiated_test
Nov 09 17:58:00     raise rte
Nov 09 17:58:00   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 371, in instantiated_test
Nov 09 17:58:00     result = test(self, **param_kwargs)
Nov 09 17:58:00   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 915, in multi_fn
Nov 09 17:58:00     return fn(slf, devices, *args, **kwargs)
Nov 09 17:58:00   File "/var/lib/jenkins/workspace/xla/test/../../test/test_torch.py", line 8203, in test_advancedindex_mixed_cpu_devices
Nov 09 17:58:00     test(x, ia, ib)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI (expand for details).

eqy commented Nov 6, 2021

Some GEMM shapes benchmarked on V100:

loaded 574 shapes
[--------------------------- bench_gemm_transformer --------------------------]
                            |  allow_fp16_reduc=True  |  allow_fp16_reduc=False
1 threads: --------------------------------------------------------------------
      [1024, 8, 1024]       |             37.2        |             13.3       
      [1024, 12, 1024]      |             13.5        |             12.8       
      [1024, 16, 1024]      |             18.4        |             17.1       
      [1024, 28, 1024]      |             17.7        |             16.7       
      [1024, 36, 1024]      |             18.0        |             16.8       
      [1024, 48, 1024]      |             18.0        |             17.0       
      [1024, 52, 1024]      |             17.8        |             16.7       
      [1024, 56, 1024]      |             17.9        |             17.3       
      [1024, 60, 1024]      |             17.6        |             16.6       
      [1024, 64, 1024]      |             18.2        |             17.2       
      [1024, 68, 1024]      |             17.4        |             16.6       
      [1024, 72, 1024]      |             17.1        |             16.1       
      [1024, 76, 1024]      |             17.4        |             16.7       
      [1024, 84, 1024]      |             17.4        |             16.7       
      [1024, 88, 1024]      |             17.9        |             16.6       
      [1024, 92, 1024]      |             17.6        |             16.6       
      [1024, 100, 1024]     |             17.5        |             16.6       
      [1024, 112, 1024]     |             17.8        |             16.6       
      [1024, 120, 1024]     |             17.6        |             16.6       
      [1024, 128, 1024]     |             18.5        |             17.5       
      [1024, 136, 1024]     |             18.3        |             17.7       
      [1024, 140, 1024]     |             17.5        |             16.6       
      [1024, 156, 1024]     |             17.6        |             16.6       
      [1024, 160, 1024]     |             18.5        |             17.7       
      [1024, 164, 1024]     |             17.5        |             16.6       
      [1024, 168, 1024]     |             18.7        |             17.5       
      [1024, 172, 1024]     |             17.9        |             17.0       
      [1024, 180, 1024]     |             17.9        |             16.9       
      [1024, 184, 1024]     |             18.5        |             17.6       
      [1024, 188, 1024]     |             18.0        |             17.1       
      [1024, 192, 1024]     |             17.3        |             16.2       
      [1024, 200, 1024]     |             17.4        |             16.3       
      [1024, 204, 1024]     |             18.2        |             17.3       
      [1024, 208, 1024]     |             16.7        |             15.8       
      [1024, 212, 1024]     |             18.3        |             17.1       
      [1024, 224, 1024]     |             18.7        |             18.0       
      [1024, 228, 1024]     |             18.3        |             17.2       
      [1024, 232, 1024]     |             20.3        |             19.3       
      [1024, 236, 1024]     |             18.5        |             17.2       
      [1024, 244, 1024]     |             18.4        |             17.3       
      [1024, 252, 1024]     |             18.3        |             17.6       
      [1024, 260, 1024]     |             18.4        |             17.7       
      [1024, 264, 1024]     |             20.5        |             19.1       
      [1024, 268, 1024]     |             18.6        |             18.0       
      [1024, 272, 1024]     |             20.1        |             19.4       
      [1024, 280, 1024]     |             20.1        |             19.2       
      [1024, 292, 1024]     |             19.4        |             19.2       
      [1024, 296, 1024]     |             19.9        |             18.9       
      [1024, 300, 1024]     |             19.3        |             19.3       
      [1024, 304, 1024]     |             20.3        |             19.4       
      [1024, 316, 1024]     |             19.8        |             19.8       
      [1024, 332, 1024]     |             20.6        |             20.6       
      [1024, 336, 1024]     |             20.0        |             18.9       
      [1024, 356, 1024]     |             22.0        |             22.1       
      [1024, 376, 1024]     |             19.8        |             18.6       
      [1024, 384, 1024]     |             19.8        |             18.8       
      [1024, 388, 1024]     |             23.2        |             23.3       
      [1024, 400, 1024]     |             20.2        |             18.6       
      [1024, 408, 1024]     |             20.1        |             19.0       
      [1024, 428, 1024]     |             24.5        |             24.5       
      [1024, 432, 1024]     |             20.1        |             19.5       
      [1024, 436, 1024]     |             24.5        |             24.6       
      [1024, 440, 1024]     |             19.9        |             19.8       
      [1024, 456, 1024]     |             20.6        |             20.8       
      [1024, 464, 1024]     |             20.4        |             20.5       
      [1024, 476, 1024]     |             26.0        |             26.0       
      [1024, 484, 1024]     |             26.9        |             27.3       
      [1024, 488, 1024]     |             21.7        |             22.0       
      [1024, 496, 1024]     |             21.4        |             21.6       
      [1024, 500, 1024]     |             26.9        |             27.0       
      [1024, 504, 1024]     |             21.9        |             22.1       
      [1024, 508, 1024]     |             27.0        |             27.2       
      [1024, 512, 1024]     |             21.1        |             21.2       
      [1024, 560, 1024]     |             23.6        |             23.5       
      [1024, 720, 1024]     |             29.0        |             29.0       
      [1024, 896, 1024]     |             34.0        |             34.3       
      [1024, 960, 1024]     |             36.4        |             36.3       
      [1024, 1024, 1024]    |             38.7        |             38.7       
      [1024, 1040, 1024]    |             39.8        |             40.1       
      [1024, 1152, 1024]    |             43.1        |             43.1       
      [1024, 1176, 1024]    |             45.3        |             45.3       
      [1024, 1248, 1024]    |             46.1        |             46.6       
      [1024, 1344, 1024]    |             49.6        |             49.9       
      [1024, 1680, 1024]    |             61.7        |             62.0       
      [1024, 1728, 1024]    |             62.2        |             62.7       
      [1024, 1768, 1024]    |             66.2        |             66.4       
      [1024, 1840, 1024]    |             67.1        |             67.3       
      [1024, 1848, 1024]    |             68.4        |             68.8       
      [1024, 1920, 1024]    |             68.9        |             69.2       
      [1024, 1936, 1024]    |             70.3        |             70.8       
      [1024, 1984, 1024]    |             70.9        |             71.1       
      [1024, 2072, 1024]    |             71.1        |             98.0       
      [1024, 2128, 1024]    |             72.7        |             97.9       
      [1024, 2176, 1024]    |             72.5        |             97.7       
      [1024, 2232, 1024]    |             74.3        |            104.6       
      [1024, 2240, 1024]    |             74.4        |            100.8       
      [1024, 2304, 1024]    |             75.8        |            103.6       
      [1024, 2496, 1024]    |             86.6        |            111.3       
      [1024, 2512, 1024]    |             87.3        |            114.6       
      [1024, 2520, 1024]    |             87.4        |            118.1       
      [1024, 2560, 1024]    |             86.7        |            114.2       
      [1024, 2640, 1024]    |             90.5        |            120.3       
      [1024, 2688, 1024]    |             89.7        |            119.3       
      [1024, 2800, 1024]    |             94.5        |            127.3       
      [1024, 2816, 1024]    |             94.2        |            125.6       
      [1024, 2880, 1024]    |             94.6        |            128.2       
      [1024, 2960, 1024]    |             98.5        |            133.6       
      [1024, 2968, 1024]    |             98.0        |            137.5       
      [1024, 3000, 1024]    |             98.7        |            139.0       
      [1024, 3040, 1024]    |             98.3        |            133.0       
      [1024, 3072, 1024]    |            101.4        |            135.4       
      [1024, 3136, 1024]    |            102.0        |            139.8       
      [1024, 3192, 1024]    |            103.6        |            146.2       
      [1024, 3200, 1024]    |            102.8        |            141.2       
      [1024, 3256, 1024]    |            106.2        |            150.0       
      [1024, 3312, 1024]    |            107.0        |            148.9       
      [1024, 3328, 1024]    |            105.7        |            146.1       
      [1024, 3344, 1024]    |            107.0        |            150.3       
      [1024, 3392, 1024]    |            110.1        |            149.0       
      [1024, 3456, 1024]    |            109.6        |            151.5       
      [1024, 3520, 1024]    |            110.0        |            154.3       
      [1024, 3536, 1024]    |            113.0        |            158.2       
      [1024, 3584, 1024]    |            113.2        |            156.6       
      [1024, 3600, 1024]    |            114.1        |            160.5       
      [1024, 3632, 1024]    |            114.2        |            161.9       
      [1024, 3680, 1024]    |            113.6        |            159.3       
      [1024, 3696, 1024]    |            117.8        |            165.4       
      [1024, 3712, 1024]    |            117.4        |            162.7       
      [1024, 3720, 1024]    |            117.9        |            170.0       
      [1024, 3744, 1024]    |            116.7        |            161.2       
      [1024, 3792, 1024]    |            118.2        |            169.2       
      [1024, 3816, 1024]    |            118.1        |            174.6       
      [1024, 3848, 1024]    |            120.8        |            177.4       
      [1024, 3904, 1024]    |            120.7        |            171.3       
      [1024, 3920, 1024]    |            121.3        |            173.9       
      [1024, 3952, 1024]    |            121.5        |            175.9       
      [1024, 3960, 1024]    |            121.6        |            180.3       
      [1024, 4000, 1024]    |            120.6        |            172.1       
      [1024, 4032, 1024]    |            124.6        |            175.7       
      [1024, 4048, 1024]    |            124.4        |            180.2       
      [1024, 4056, 1024]    |            124.9        |            184.7       
      [1024, 4080, 1024]    |            125.4        |            180.7       
      [1024, 4096, 1024]    |            124.7        |            177.8       
      [1024, 4104, 1024]    |            126.0        |            186.2       
      [1024, 4128, 1024]    |            124.1        |            175.9       
      [1024, 4144, 1024]    |            143.5        |            183.5       
      [1024, 4160, 1024]    |            142.4        |            181.9       
      [1024, 4200, 1024]    |            145.9        |            191.6       
      [1024, 4224, 1024]    |            142.6        |            183.7       
      [1024, 4240, 1024]    |            145.4        |            187.4       
      [1024, 4264, 1024]    |            147.8        |            193.8       
      [1024, 4280, 1024]    |            148.0        |            194.4       
      [1024, 4320, 1024]    |            143.7        |            185.2       
      [1024, 4352, 1024]    |            146.6        |            189.1       
      [1024, 4360, 1024]    |            150.6        |            198.6       
      [1024, 4368, 1024]    |            148.4        |            194.1       
      [1024, 4392, 1024]    |            151.3        |            200.3       
      [1024, 4400, 1024]    |            149.1        |            195.6       
      [1024, 4416, 1024]    |            147.9        |            191.2       
      [1024, 4424, 1024]    |            158.2        |            201.6       
      [1024, 4440, 1024]    |            158.6        |            201.9       
      [1024, 4464, 1024]    |            158.3        |            198.5       
      [1024, 4480, 1024]    |            158.2        |            195.2       
      [1024, 4488, 1024]    |            159.3        |            204.6       
      [1024, 4512, 1024]    |            158.9        |            193.4       
      [1024, 4536, 1024]    |            161.4        |            205.4       
      [1024, 4560, 1024]    |            158.3        |            202.8       
      [1024, 4592, 1024]    |            159.8        |            203.5       
      [1024, 4600, 1024]    |            161.1        |            207.0       
      [1024, 4608, 1024]    |            158.4        |            200.7       
      [1024, 4648, 1024]    |            165.3        |            210.6       
      [1024, 4656, 1024]    |            165.0        |            205.0       
      [1024, 4664, 1024]    |            166.4        |            211.9       
      [1024, 4672, 1024]    |            166.5        |            202.7       
      [1024, 4680, 1024]    |            167.4        |            212.0       
      [1024, 4704, 1024]    |            163.4        |            200.8       
      [1024, 4720, 1024]    |            165.1        |            209.8       
      [1024, 4736, 1024]    |            165.5        |            205.5       
      [1024, 4752, 1024]    |            166.2        |            212.0       
      [1024, 4760, 1024]    |            165.1        |            215.9       
      [1024, 4784, 1024]    |            166.7        |            212.6       
      [1024, 4800, 1024]    |            167.8        |            208.8       
      [1024, 4816, 1024]    |            168.2        |            213.4       
      [1024, 4824, 1024]    |            169.0        |            218.1       
      [1024, 4840, 1024]    |            167.8        |            219.2       
      [1024, 4864, 1024]    |            164.6        |            211.0       
      [1024, 4872, 1024]    |            168.8        |            221.1       
      [1024, 4880, 1024]    |            167.6        |            215.6       
      [1024, 4888, 1024]    |            169.1        |            221.8       
      [1024, 4896, 1024]    |            166.1        |            209.1       
      [1024, 4920, 1024]    |            172.7        |            223.6       
      [1024, 4928, 1024]    |            171.8        |            215.2       
      [1024, 4960, 1024]    |            173.3        |            212.3       
      [1024, 4968, 1024]    |            175.6        |            226.2       
      [1024, 4984, 1024]    |            173.0        |            225.0       
      [1024, 4992, 1024]    |            171.5        |            215.8       
      [1024, 5000, 1024]    |            174.2        |            227.2       
      [1024, 5016, 1024]    |            174.7        |            226.9       
      [1024, 5032, 1024]    |            174.6        |            228.3       
      [1024, 5040, 1024]    |            173.3        |            223.5       
      [1024, 5056, 1024]    |            173.4        |            219.5       
      [1024, 5088, 1024]    |            174.3        |            217.8       
      [1024, 5096, 1024]    |            175.7        |            230.5       
      [1024, 5104, 1024]    |            176.1        |            225.6       
      [1024, 5112, 1024]    |            176.0        |            231.3       
      [1024, 5120, 1024]    |            170.8        |            222.8       
      [1024, 9728, 1024]    |            284.0        |            422.3       
      [1024, 16384, 1024]   |            449.7        |            710.0       
      [1024, 33712, 1024]   |            832.0        |           1491.6       
      [4096, 8, 4096]       |             52.5        |             52.6       
      [4096, 12, 4096]      |             67.3        |             67.3       
      [4096, 16, 4096]      |             53.0        |             53.4       
      [4096, 28, 4096]      |             70.7        |             71.0       
      [4096, 36, 4096]      |             87.8        |             87.9       
      [4096, 48, 4096]      |             64.2        |             64.4       
      [4096, 52, 4096]      |             91.9        |             92.1       
      [4096, 56, 4096]      |             64.6        |             64.7       
      [4096, 60, 4096]      |             88.9        |             89.0       
      [4096, 64, 4096]      |             64.4        |             64.4       
      [4096, 68, 4096]      |             98.7        |             98.4       
      [4096, 72, 4096]      |             74.4        |             74.5       
      [4096, 76, 4096]      |             99.7        |             99.7       
      [4096, 84, 4096]      |            101.2        |            101.0       
      [4096, 88, 4096]      |             75.4        |             75.3       
      [4096, 92, 4096]      |            103.6        |            103.5       
      [4096, 100, 4096]     |            113.0        |            112.3       
      [4096, 112, 4096]     |             85.9        |             86.1       
      [4096, 120, 4096]     |             87.6        |             87.5       
      [4096, 128, 4096]     |             86.5        |             86.6       
      [4096, 136, 4096]     |             97.9        |             98.0       
      [4096, 140, 4096]     |            127.6        |            127.4       
      [4096, 156, 4096]     |            132.0        |            131.9       
      [4096, 160, 4096]     |            103.1        |            103.3       
      [4096, 164, 4096]     |            142.3        |            142.6       
      [4096, 168, 4096]     |            111.0        |            111.1       
      [4096, 172, 4096]     |            142.5        |            143.8       
      [4096, 180, 4096]     |            144.6        |            144.2       
      [4096, 184, 4096]     |            112.5        |            112.8       
      [4096, 188, 4096]     |            147.8        |            146.9       
      [4096, 192, 4096]     |            110.5        |            110.7       
      [4096, 200, 4096]     |            123.1        |            123.2       
      [4096, 204, 4096]     |            157.6        |            158.0       
      [4096, 208, 4096]     |            122.6        |            122.7       
      [4096, 212, 4096]     |            160.2        |            159.3       
      [4096, 224, 4096]     |            122.4        |            122.7       
      [4096, 228, 4096]     |            169.4        |            170.7       
      [4096, 232, 4096]     |            136.0        |            136.6       
      [4096, 236, 4096]     |            171.6        |            171.4       
      [4096, 244, 4096]     |            171.8        |            170.8       
      [4096, 252, 4096]     |            173.3        |            173.8       
      [4096, 260, 4096]     |            182.2        |            182.3       
      [4096, 264, 4096]     |            155.2        |            155.3       
      [4096, 268, 4096]     |            184.9        |            184.0       
      [4096, 272, 4096]     |            154.7        |            154.8       
      [4096, 280, 4096]     |            150.8        |            151.1       
      [4096, 292, 4096]     |            197.4        |            200.8       
      [4096, 296, 4096]     |            163.0        |            162.8       
      [4096, 300, 4096]     |            201.5        |            200.0       
      [4096, 304, 4096]     |            160.6        |            161.0       
      [4096, 316, 4096]     |            205.0        |            206.7       
      [4096, 332, 4096]     |            214.6        |            216.2       
      [4096, 336, 4096]     |            176.7        |            176.5       
      [4096, 356, 4096]     |            228.7        |            229.0       
      [4096, 376, 4096]     |            193.3        |            193.5       
      [4096, 384, 4096]     |            189.3        |            189.9       
      [4096, 388, 4096]     |            243.2        |            243.4       
      [4096, 400, 4096]     |            218.4        |            217.8       
      [4096, 408, 4096]     |            221.3        |            221.3       
      [4096, 428, 4096]     |            258.6        |            258.5       
      [4096, 432, 4096]     |            231.7        |            231.7       
      [4096, 436, 4096]     |            259.5        |            260.1       
      [4096, 440, 4096]     |            233.5        |            233.7       
      [4096, 456, 4096]     |            246.0        |            245.6       
      [4096, 464, 4096]     |            243.9        |            244.8       
      [4096, 476, 4096]     |            274.0        |            275.6       
      [4096, 484, 4096]     |            283.5        |            283.7       
      [4096, 488, 4096]     |            259.6        |            260.4       
      [4096, 496, 4096]     |            258.2        |            257.3       
      [4096, 500, 4096]     |            284.9        |            284.6       
      [4096, 504, 4096]     |            261.2        |            259.7       
      [4096, 508, 4096]     |            287.9        |            287.0       
      [4096, 512, 4096]     |            258.0        |            255.6       
      [4096, 560, 4096]     |            282.7        |            279.8       
      [4096, 720, 4096]     |            341.9        |            340.5       
      [4096, 896, 4096]     |            404.3        |            403.9       
      [4096, 960, 4096]     |            433.8        |            433.4       
      [4096, 1024, 4096]    |            450.4        |            450.9       
      [4096, 1040, 4096]    |            467.1        |            466.2       
      [4096, 1152, 4096]    |            504.5        |            503.2       
      [4096, 1176, 4096]    |            518.1        |            517.3       
      [4096, 1248, 4096]    |            538.0        |            539.4       
      [4096, 1344, 4096]    |            578.0        |            578.7       
      [4096, 1680, 4096]    |            717.2        |            719.1       
      [4096, 1728, 4096]    |            727.8        |            728.9       
      [4096, 1768, 4096]    |            766.2        |            766.5       
      [4096, 1840, 4096]    |            784.0        |            784.7       
      [4096, 1848, 4096]    |            793.3        |            794.0       
      [4096, 1920, 4096]    |            806.7        |            810.2       
      [4096, 1936, 4096]    |            819.1        |            821.3       
      [4096, 1984, 4096]    |            833.8        |            836.8       
      [4096, 2072, 4096]    |            882.2        |            882.5       
      [4096, 2128, 4096]    |            894.7        |            894.1       
      [4096, 2176, 4096]    |            903.8        |            902.3       
      [4096, 2232, 4096]    |            945.0        |            945.1       
      [4096, 2240, 4096]    |            928.7        |            929.1       
      [4096, 2304, 4096]    |            951.2        |            952.1       
      [4096, 2496, 4096]    |           1025.6        |           1026.8       
      [4096, 2512, 4096]    |           1043.2        |           1044.4       
      [4096, 2520, 4096]    |           1056.4        |           1061.6       
      [4096, 2560, 4096]    |           1052.8        |           1055.8       
      [4096, 2640, 4096]    |           1091.4        |           1097.2       
      [4096, 2688, 4096]    |           1097.2        |           1100.0       
      [4096, 2800, 4096]    |           1158.9        |           1157.6       
      [4096, 2816, 4096]    |           1149.8        |           1148.1       
      [4096, 2880, 4096]    |           1177.5        |           1179.7       
      [4096, 2960, 4096]    |           1226.3        |           1215.2       
      [4096, 2968, 4096]    |           1241.1        |           1239.8       
      [4096, 3000, 4096]    |           1253.7        |           1253.9       
      [4096, 3040, 4096]    |           1240.9        |           1241.6       
      [4096, 3072, 4096]    |           1250.7        |           1248.2       
      [4096, 3136, 4096]    |           1276.7        |           1273.8       
      [4096, 3192, 4096]    |           1325.4        |           1323.3       
      [4096, 3200, 4096]    |           1297.3        |           1295.9       
      [4096, 3256, 4096]    |           1352.0        |           1351.3       
      [4096, 3312, 4096]    |           1354.3        |           1347.6       
      [4096, 3328, 4096]    |           1346.9        |           1340.4       
      [4096, 3344, 4096]    |           1369.2        |           1361.2       
      [4096, 3392, 4096]    |           1372.5        |           1369.3       
      [4096, 3456, 4096]    |           1391.6        |           1390.5       
      [4096, 3520, 4096]    |           1419.6        |           1414.7       
      [4096, 3536, 4096]    |           1443.5        |           1434.0       
      [4096, 3584, 4096]    |           1444.8        |           1442.7       
      [4096, 3600, 4096]    |           1463.0        |           1465.8       
      [4096, 3632, 4096]    |           1481.0        |           1473.9       
      [4096, 3680, 4096]    |           1477.6        |           1475.2       
      [4096, 3696, 4096]    |           1497.4        |           1502.3       
      [4096, 3712, 4096]    |           1488.8        |           1488.1       
      [4096, 3720, 4096]    |           1539.5        |           1538.7       
      [4096, 3744, 4096]    |           1502.8        |           1504.1       
      [4096, 3792, 4096]    |           1534.3        |           1546.3       
      [4096, 3816, 4096]    |           1571.1        |           1578.0       
      [4096, 3848, 4096]    |           1586.7        |           1578.1       
      [4096, 3904, 4096]    |           1565.8        |           1566.5       
      [4096, 3920, 4096]    |           1587.6        |           1586.0       
      [4096, 3952, 4096]    |           1605.5        |           1598.2       
      [4096, 3960, 4096]    |           1616.1        |           1625.1       
      [4096, 4000, 4096]    |           1593.9        |           1594.4       
      [4096, 4032, 4096]    |           1612.9        |           1621.0       
      [4096, 4048, 4096]    |           1634.6        |           1639.8       
      [4096, 4056, 4096]    |           1670.8        |           1661.9       
      [4096, 4080, 4096]    |           1664.2        |           1658.3       
      [4096, 4096, 4096]    |           1639.4        |           1651.0       
      [4096, 4104, 4096]    |           1677.4        |           1674.9       
      [4096, 4128, 4096]    |           1655.7        |           1646.0       
      [4096, 4144, 4096]    |           1796.8        |           2519.6       
      [4096, 4160, 4096]    |           1784.0        |           2497.7       
      [4096, 4200, 4096]    |           1823.1        |           2623.8       
      [4096, 4224, 4096]    |           1794.3        |           2536.7       
      [4096, 4240, 4096]    |           1835.3        |           2584.2       
      [4096, 4264, 4096]    |           1862.0        |           2670.8       
      [4096, 4280, 4096]    |           1859.8        |           2665.6       
      [4096, 4320, 4096]    |           1828.8        |           2549.5       
      [4096, 4352, 4096]    |           1869.3        |           2613.0       
      [4096, 4360, 4096]    |           1896.4        |           2736.1       
      [4096, 4368, 4096]    |           1885.7        |           2672.6       
      [4096, 4392, 4096]    |           1904.1        |           2758.7       
      [4096, 4400, 4096]    |           1882.9        |           2692.2       
      [4096, 4416, 4096]    |           1868.2        |           2644.9       
      [4096, 4424, 4096]    |           1933.6        |           2774.4       
      [4096, 4440, 4096]    |           1937.2        |           2776.5       
      [4096, 4464, 4096]    |           1916.9        |           2725.6       
      [4096, 4480, 4096]    |           1933.8        |           2693.3       
      [4096, 4488, 4096]    |           1944.5        |           2814.6       
      [4096, 4512, 4096]    |           1908.4        |           2658.7       
      [4096, 4536, 4096]    |           1979.0        |           2836.5       
      [4096, 4560, 4096]    |           1953.8        |           2793.4       
      [4096, 4592, 4096]    |           1911.5        |           2808.5       
      [4096, 4600, 4096]    |           1931.8        |           2867.9       
      [4096, 4608, 4096]    |           1895.7        |           2755.0       
      [4096, 4648, 4096]    |           1950.8        |           2909.9       
      [4096, 4656, 4096]    |           1945.0        |           2846.6       
      [4096, 4664, 4096]    |           1966.8        |           2916.5       
      [4096, 4672, 4096]    |           1952.1        |           2791.9       
      [4096, 4680, 4096]    |           1965.1        |           2936.0       
      [4096, 4704, 4096]    |           1930.9        |           2773.7       
      [4096, 4720, 4096]    |           1975.7        |           2884.7       
      [4096, 4736, 4096]    |           1955.3        |           2838.2       
      [4096, 4752, 4096]    |           1974.9        |           2908.3       
      [4096, 4760, 4096]    |           1994.2        |           2968.9       
      [4096, 4784, 4096]    |           1977.1        |           2926.8       
      [4096, 4800, 4096]    |           1974.1        |           2883.2       
      [4096, 4816, 4096]    |           2003.4        |           2937.5       
      [4096, 4824, 4096]    |           2031.1        |           3018.6       
      [4096, 4840, 4096]    |           2026.5        |           3030.3       
      [4096, 4864, 4096]    |           2011.5        |           2907.2       
      [4096, 4872, 4096]    |           2037.0        |           3044.1       
      [4096, 4880, 4096]    |           2005.8        |           2975.5       
      [4096, 4888, 4096]    |           2017.4        |           3061.8       
      [4096, 4896, 4096]    |           2003.3        |           2888.3       
      [4096, 4920, 4096]    |           2056.8        |           3076.7       
      [4096, 4928, 4096]    |           2028.5        |           2953.8       
      [4096, 4960, 4096]    |           2049.7        |           2923.6       
      [4096, 4968, 4096]    |           2059.4        |           3108.7       
      [4096, 4984, 4096]    |           2056.7        |           3114.6       
      [4096, 4992, 4096]    |           2032.2        |           2981.9       
      [4096, 5000, 4096]    |           2081.5        |           3126.8       
      [4096, 5016, 4096]    |           2095.4        |           3122.9       
      [4096, 5032, 4096]    |           2086.9        |           3146.6       
      [4096, 5040, 4096]    |           2079.1        |           3078.2       
      [4096, 5056, 4096]    |           2087.8        |           3028.3       
      [4096, 5088, 4096]    |           2056.2        |           3000.9       
      [4096, 5096, 4096]    |           2094.6        |           3190.0       
      [4096, 5104, 4096]    |           2144.0        |           2663.5       
      [4096, 5112, 4096]    |           2149.1        |           2766.9       
      [4096, 5120, 4096]    |           2142.8        |           2631.0       
      [4096, 9728, 4096]    |           3875.1        |           5779.8       
      [4096, 16384, 4096]   |           6182.9        |           9656.5       
      [33712, 8, 33712]     |           2866.3        |           2863.5       
      [33712, 12, 33712]    |           4301.2        |           4307.3       
      [33712, 16, 33712]    |           2943.2        |           2940.9       
      [33712, 28, 33712]    |           4452.9        |           4437.7       
      [33712, 36, 33712]    |           5151.5        |           5148.5       
      [33712, 48, 33712]    |           3997.6        |           3964.8       
      [33712, 52, 33712]    |           5363.9        |           5362.7       
      [33712, 56, 33712]    |           4091.1        |           4090.7       
      [33712, 60, 33712]    |           5489.4        |           5492.4       
      [33712, 64, 33712]    |           4168.2        |           4169.6       
      [33712, 68, 33712]    |           6153.4        |           6142.5       
      [33712, 72, 33712]    |           4861.7        |           4861.2       
      [33712, 76, 33712]    |           6249.9        |           6246.2       
      [33712, 84, 33712]    |           6370.2        |           6366.5       
      [33712, 88, 33712]    |           5034.6        |           5041.8       
      [33712, 92, 33712]    |           6511.4        |           6526.9       
      [33712, 100, 33712]   |           7149.4        |           7138.9       
      [33712, 112, 33712]   |           6185.1        |           6198.6       
      [33712, 120, 33712]   |           6296.6        |           6267.8       
      [33712, 128, 33712]   |           6247.6        |           6258.4       
      [33712, 136, 33712]   |           7040.2        |           6993.7       
      [33712, 140, 33712]   |           9922.3        |           9940.0       
      [33712, 156, 33712]   |          10237.9        |          10256.3       
      [33712, 160, 33712]   |           7082.4        |           7064.9       
      [33712, 164, 33712]   |          11094.9        |          11115.8       
      [33712, 168, 33712]   |           7822.4        |           7822.8       
      [33712, 172, 33712]   |          11167.6        |          11145.4       
      [33712, 180, 33712]   |          11364.7        |          11347.1       
      [33712, 184, 33712]   |           7925.3        |           7917.9       
      [33712, 188, 33712]   |          11560.4        |          11557.5       
      [33712, 192, 33712]   |           7917.5        |           7931.2       
      [33712, 200, 33712]   |           8689.0        |           8701.1       
      [33712, 204, 33712]   |          12501.3        |          12505.1       
      [33712, 208, 33712]   |           8726.4        |           8690.2       
      [33712, 212, 33712]   |          12677.6        |          12647.6       
      [33712, 224, 33712]   |           8755.1        |           8737.6       
      [33712, 228, 33712]   |          13656.3        |          13637.1       
      [33712, 232, 33712]   |           9556.6        |           9553.5       
      [33712, 236, 33712]   |          13670.5        |          13657.4       
      [33712, 244, 33712]   |          13676.9        |          13668.2       
      [33712, 252, 33712]   |          13867.6        |          13823.5       
      [33712, 260, 33712]   |          14612.3        |          14687.2       
      [33712, 264, 33712]   |          10368.9        |          10396.5       
      [33712, 268, 33712]   |          14810.6        |          14959.8       
      [33712, 272, 33712]   |          10372.5        |          10327.6       
      [33712, 280, 33712]   |          10452.8        |          10497.8       
      [33712, 292, 33712]   |          16117.2        |          16266.1       
      [33712, 296, 33712]   |          11204.4        |          11235.2       
      [33712, 300, 33712]   |          16402.0        |          16310.5       
      [33712, 304, 33712]   |          11227.4        |          11214.1       
      [33712, 316, 33712]   |          16937.2        |          16768.1       
      [33712, 332, 33712]   |          17428.4        |          17447.0       
      [33712, 336, 33712]   |          12046.1        |          12061.7       
      [33712, 356, 33712]   |          18870.6        |          18891.4       
      [33712, 376, 33712]   |          12934.1        |          12965.4       
      [33712, 384, 33712]   |          12950.2        |          12906.2       
      [33712, 388, 33712]   |          15976.2        |          15938.2       
      [33712, 400, 33712]   |          13677.3        |          13633.6       
      [33712, 408, 33712]   |          13747.2        |          13765.2       
      [33712, 428, 33712]   |          17035.4        |          17016.0       
      [33712, 432, 33712]   |          14486.4        |          14411.2       
      [33712, 436, 33712]   |          17119.7        |          17020.8       
      [33712, 440, 33712]   |          14586.4        |          14516.6       
      [33712, 456, 33712]   |          15278.9        |          15307.9       
      [33712, 464, 33712]   |          15311.8        |          15316.5       
      [33712, 476, 33712]   |          18201.9        |          18247.6       
      [33712, 484, 33712]   |          18922.6        |          18872.3       
      [33712, 488, 33712]   |          16647.0        |          16697.8       
      [33712, 496, 33712]   |          16589.0        |          16555.5       
      [33712, 500, 33712]   |          18945.4        |          18933.5       
      [33712, 504, 33712]   |          16828.1        |          16818.0       
      [33712, 508, 33712]   |          19175.0        |          19054.8       
      [33712, 512, 33712]   |          16536.6        |          16524.6       
      [33712, 1344, 33712]  |          37330.1        |          37344.8       
      [33712, 1728, 33712]  |          47447.8        |          47423.4       
      [33712, 1768, 33712]  |          50014.2        |          49990.6       
      [33712, 1848, 33712]  |          52038.3        |          51998.5       
      [33712, 1920, 33712]  |          52574.4        |          52632.0       
      [33712, 2128, 33712]  |          59376.5        |          58932.5       
      [33712, 2176, 33712]  |          59350.5        |          59383.5       
      [33712, 2232, 33712]  |          62838.1        |          62943.6       
      [33712, 2240, 33712]  |          60938.7        |          60824.0       
      [33712, 2304, 33712]  |          63049.4        |          63112.8       
      [33712, 2560, 33712]  |          70360.1        |          70197.7       
      [33712, 2640, 33712]  |          73268.7        |          73339.2       
      [33712, 2688, 33712]  |          73251.2        |          73442.9       
      [33712, 2816, 33712]  |          77123.7        |          76660.3       
      [33712, 2880, 33712]  |          78635.3        |          78678.5       
      [33712, 2968, 33712]  |          83524.1        |          83539.0       
      [33712, 3000, 33712]  |          84358.6        |          84087.5       
      [33712, 3040, 33712]  |          82465.9        |          82516.3       
      [33712, 3136, 33712]  |          84777.3        |          84606.9       
      [33712, 3328, 33712]  |          89779.2        |          89763.5       
      [33712, 3344, 33712]  |          92409.3        |          92439.3       
      [33712, 3456, 33712]  |          93355.0        |          93644.7       
      [33712, 3584, 33712]  |          97531.3        |          97232.2       
      [33712, 3600, 33712]  |          99637.8        |         100168.8       
      [33712, 3632, 33712]  |         100140.2        |         100962.4       
      [33712, 3744, 33712]  |         101100.4        |         101534.4       
      [33712, 3792, 33712]  |         103149.1        |         104673.6       
      [33712, 3816, 33712]  |         105714.8        |         105583.3       
      [33712, 3848, 33712]  |         106888.6        |         106989.3       
      [33712, 3904, 33712]  |         104329.1        |         105668.2       
      [33712, 3920, 33712]  |         106907.8        |         106922.6       
      [33712, 3952, 33712]  |         107769.5        |         107873.2       
      [33712, 3960, 33712]  |         110442.3        |         109465.4       
      [33712, 4096, 33712]  |         110393.1        |         109513.2       
      [33712, 4160, 33712]  |         111759.8        |         112966.5       
      [33712, 4224, 33712]  |         114123.0        |         113426.6       
      [33712, 4264, 33712]  |         118744.0        |         117847.3       
      [33712, 4280, 33712]  |         119227.4        |         118713.4       
      [33712, 4320, 33712]  |         116240.2        |         117310.2       
      [33712, 4352, 33712]  |         117725.6        |         118087.1       
      [33712, 4360, 33712]  |         122770.5        |         121504.9       
      [33712, 4368, 33712]  |         119800.4        |         120077.4       
      [33712, 4392, 33712]  |         122391.5        |         122174.3       
      [33712, 4400, 33712]  |         121185.4        |         120963.6       
      [33712, 4416, 33712]  |         118937.8        |         119511.6       
      [33712, 4440, 33712]  |         122991.9        |         124053.7       
      [33712, 4464, 33712]  |         121339.0        |         123248.8       
      [33712, 4480, 33712]  |         120723.5        |         121008.5       
      [33712, 4488, 33712]  |         126070.9        |         126477.2       
      [33712, 4536, 33712]  |         126556.6        |         126097.4       
      [33712, 4600, 33712]  |         128584.5        |         129145.3       
      [33712, 4608, 33712]  |         125072.4        |         124754.4       
      [33712, 4656, 33712]  |         129079.8        |         128587.5       
      [33712, 4664, 33712]  |         131088.8        |         130452.9       
      [33712, 4680, 33712]  |         131576.4        |         131362.1       
      [33712, 4704, 33712]  |         127283.6        |         127885.7       
      [33712, 4720, 33712]  |         129786.9        |         128857.6       
      [33712, 4736, 33712]  |         127659.3        |         128210.3       
      [33712, 4752, 33712]  |         130731.9        |         130995.7       
      [33712, 4760, 33712]  |         131974.8        |         132360.3       
      [33712, 4784, 33712]  |         131976.3        |         131774.9       
      [33712, 4800, 33712]  |         131262.4        |         130702.0       
      [33712, 4816, 33712]  |         131858.9        |         132331.3       
      [33712, 4824, 33712]  |         134461.4        |         134444.4       
      [33712, 4840, 33712]  |         136516.1        |         135704.7       
      [33712, 4864, 33712]  |         131120.3        |         132027.1       
      [33712, 4872, 33712]  |         137414.0        |         137195.6       
      [33712, 4880, 33712]  |         133910.4        |         134152.7       
      [33712, 4888, 33712]  |         136850.6        |         136910.0       
      [33712, 4896, 33712]  |         132675.7        |         132945.7       
      [33712, 4920, 33712]  |         137399.5        |         137938.8       
      [33712, 4928, 33712]  |         133400.9        |         133545.8       
      [33712, 4960, 33712]  |         134370.7        |         133955.8       
      [33712, 4968, 33712]  |         139415.6        |         139181.9       
      [33712, 4992, 33712]  |         135894.3        |         134947.4       
      [33712, 5000, 33712]  |         140289.3        |         140181.0       
      [33712, 5016, 33712]  |         140744.3        |         139294.4       
      [33712, 5032, 33712]  |         141529.4        |         141452.7       
      [33712, 5040, 33712]  |         138802.4        |         138069.7       
      [33712, 5056, 33712]  |         136077.1        |         136190.3       
      [33712, 5088, 33712]  |         137578.0        |         137825.3       
      [33712, 5096, 33712]  |         143425.7        |         143211.2       
      [33712, 5104, 33712]  |         140563.2        |         139638.5       
      [33712, 5120, 33712]  |         139875.7        |         139118.4       

Times are in microseconds (us).
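
A comparison of this kind could be reproduced with a sketch along the following lines. It is not the script used to produce the table above. It assumes the bracketed triples are `[m, k, n]` GEMM shapes, and the helper names `fp16_reduced_precision_reduction` and `time_gemm` are hypothetical. The timing part only runs when a CUDA device is present.

```python
import contextlib

import torch
import torch.utils.benchmark as benchmark


@contextlib.contextmanager
def fp16_reduced_precision_reduction(allowed):
    """Temporarily set the fp16 flag, restoring the old value on exit (hypothetical helper)."""
    matmul = torch.backends.cuda.matmul
    previous = matmul.allow_fp16_reduced_precision_reduction
    matmul.allow_fp16_reduced_precision_reduction = allowed
    try:
        yield
    finally:
        matmul.allow_fp16_reduced_precision_reduction = previous


def time_gemm(m, k, n, allowed):
    """Time an (m x k) @ (k x n) fp16 matmul on the GPU under the given flag setting."""
    a = torch.randn(m, k, device="cuda", dtype=torch.half)
    b = torch.randn(k, n, device="cuda", dtype=torch.half)
    with fp16_reduced_precision_reduction(allowed):
        timer = benchmark.Timer(
            stmt="torch.mm(a, b)",
            globals={"torch": torch, "a": a, "b": b},
            label="fp16 GEMM",
            sub_label=f"[{m}, {k}, {n}]",
            description=f"allow_fp16_reduced_precision_reduction={allowed}",
        )
        # The measurement happens inside the context, so the flag is in effect.
        return timer.blocked_autorange(min_run_time=1)


if torch.cuda.is_available():
    # Two shapes picked from the table as an assumption about where timings diverge.
    results = [
        time_gemm(4096, k, 4096, allowed)
        for k in (4096, 4144)
        for allowed in (True, False)
    ]
    benchmark.Compare(results).print()
```

Wrapping the flag in a context manager keeps the process-wide setting from leaking into later measurements.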

stas00 commented Nov 6, 2021

@stas00 Note that the behavior after the previous PR can be replicated with torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

Could we please document all these nuances? Perhaps by adding a new dedicated doc that speaks specifically to precision vs. performance control, or by adding it to the performance doc?

It could then include the tf32 enabling way as well from one of the recent PRs.

Thank you!

stas00 commented Nov 6, 2021

I found the most relevant doc for this change: https://github.com/pytorch/pytorch/blob/master/docs/source/notes/numerical_accuracy.rst . So maybe it belongs there, with an xref added from cuda.rst?

ngimel commented Nov 6, 2021

I agree with @stas00, it makes sense to move the main portion of the docs to numerical_accuracy and expand it to mention that most of the math for gemms is done in fp32 precision, but, if reduced precision reduction is allowed, some intermediate results can be truncated to low precision, and cross-link it from cuda. Does this apply to bf16 also, btw? It's harder to establish because bf16 will only truncate mantissa, there won't be glaring overflows there.
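
The difference between the two dtypes can be illustrated with a toy sketch. This is not how cuBLAS performs its reductions. It only assumes that a few fp32 partial sums get rounded to a lower-precision dtype before the reduction finishes, and the helper `finish_reduction` and the sample values are hypothetical.

```python
import torch


def finish_reduction(partial_sums_fp32, intermediate_dtype):
    """Round fp32 partial sums to intermediate_dtype, then add them up in fp32."""
    rounded = partial_sums_fp32.to(intermediate_dtype)
    return rounded.to(torch.float32).sum()


# Two hypothetical partial sums; the first exceeds the largest finite fp16 value (65504).
partial_sums = torch.tensor([70000.0, -10000.0], dtype=torch.float32)

exact = partial_sums.sum()
via_fp16 = finish_reduction(partial_sums, torch.float16)
via_bf16 = finish_reduction(partial_sums, torch.bfloat16)

print(f"fp32 throughout:   {exact.item()}")
print(f"fp16 intermediate: {via_fp16.item()}")
print(f"bf16 intermediate: {via_bf16.item()}")
```

With fp16 the out-of-range intermediate becomes infinite, while bf16 keeps the fp32 exponent range and only loses mantissa bits, so its error is a small rounding difference rather than an overflow.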

eqy commented Nov 6, 2021

Since the original change was only for at::Half GEMMs I don't believe it would affect bf16. No one has raised any numerical alarms there yet AFAIK...

fp16 GEMMs are potentially done with reduced precision reductions (e.g., in fp16 rather than fp32). This reduction in precision can allow for higher performance on certain workloads (particularly those with a large `k` dimension) and GPU architectures at the cost of numerical precision and potential for overflow.

Some example benchmark data on V100
.. code::

@crcrpar crcrpar Nov 6, 2021

[nit-picking] The table rendered by torch.utils.benchmark seems compatible with Sphinx (see the attached capture), which obviates this .. code:: directive.
(Attached image: Screenshot from 2021-11-06 15-09-06)

Or fix the indentation of this table including "(time ...)".

Fix this please

H-Huang added the "triaged" label (This issue has been looked at a team member, and triaged and prioritized into an appropriate module) on Nov 8, 2021

Reduced Precision Reduction in FP16 GEMMs
-----------------------------------------

fp16 GEMMs are potentially done with reduced precision reductions (e.g., in fp16 rather than fp32). This reduction in precision can allow for higher performance on certain workloads (particularly those with a large `k` dimension) and GPU architectures at the cost of numerical precision and potential for overflow.

Most of the GEMM accumulation is still done in fp32 precision; only a few truncations are performed. Can you please make the wording more accurate so that it does not imply that all the accumulation is done in fp16?

fp16 GEMMs are potentially done with reduced precision reductions (e.g., in fp16 rather than fp32). This reduction in precision can allow for higher performance on certain workloads (particularly those with a large `k` dimension) and GPU architectures at the cost of numerical precision and potential for overflow.

Some example benchmark data on V100
.. code::

Fix this please

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel merged this pull request in 790763b.

pytorchmergebot pushed a commit that referenced this pull request Dec 21, 2022
…16 GEMM (#89172)

Essentially the same change as #67946, except that the default is to disallow reduced precision reductions in `BFloat16` GEMMs (for now). If performance is severely regressed, we can change the default, but this option appears to be necessary to pass some `addmm` `BFloat16` tests on H100.

CC @ptrblck @ngimel
Pull Request resolved: #89172
Approved by: https://github.com/ngimel
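
As a rough sketch of how the two options sit side by side: the fp16 flag is the one named in this PR, while the bf16 flag name `allow_bf16_reduced_precision_reduction` is taken from the current PyTorch API rather than from the commit message above, so it is assumed to be present in the installed version. Both are process-wide settings, so the sketch saves and restores them.

```python
import torch

matmul = torch.backends.cuda.matmul

# Remember the current settings so they can be restored afterwards.
saved_fp16 = matmul.allow_fp16_reduced_precision_reduction
saved_bf16 = matmul.allow_bf16_reduced_precision_reduction

# Disallow reduced precision reductions for both dtypes.
matmul.allow_fp16_reduced_precision_reduction = False
matmul.allow_bf16_reduced_precision_reduction = False
fp16_disallowed = matmul.allow_fp16_reduced_precision_reduction
bf16_disallowed = matmul.allow_bf16_reduced_precision_reduction
print(f"fp16 reduced precision reduction allowed: {fp16_disallowed}")
print(f"bf16 reduced precision reduction allowed: {bf16_disallowed}")

# Restore whatever was configured before.
matmul.allow_fp16_reduced_precision_reduction = saved_fp16
matmul.allow_bf16_reduced_precision_reduction = saved_bf16
```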

Labels

cla signed Merged open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
