
torch._scaled_mm row-wise hits CUDA invalid memory access when M % 256 != 0 #133334

@lw


🐛 Describe the bug

(conda_env) lcw@cr1-p548xlarge-19:~$ PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer --print-limit=1 --num-callers-host=10 ipython
========= COMPUTE-SANITIZER
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.22.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

In [2]: a = torch.randn((8352, 4096), device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)

In [3]: b = torch.randn((1536, 4096), device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)

In [4]: scales_a = torch.randn(8352, device="cuda", dtype=torch.float)

In [5]: scales_b = torch.randn(1536, device="cuda", dtype=torch.float)

In [6]: torch._scaled_mm(a, b.t(), scale_a=scales_a[:,None], scale_b=scales_b[None,:], out_dtype=torch.bfloat16, use_fast_accum=True)
Out[6]: ========= Invalid __global__ read of size 4 bytes
=========     at void cutlass::device_kernel<cutlass::gemm::kernel::GemmUniversal<cute::tuple<int, int, int>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<(int)6, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum>, cute::tuple<cute::C<(int)128>, cute::C<(int)128>, cute::C<(int)128>>, cutlass::float_e4m3_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::float_e4m3_t, cute::tuple<long, cute::C<(int)1>, long>, cute::TiledMMA<cute::MMA_Atom<cute::SM90_64x128x32_F32E4M3E4M3_SS_TN<(cute::GMMA::ScaleIn)1, (cute::GMMA::ScaleIn)1>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)8>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)128>>, cute::tuple<cute::C<(int)128>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)8>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)128>>, cute::tuple<cute::C<(int)128>, cute::C<(int)1>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<(int)8, (int)2, (int)16, (bool)1>, cute::tuple<cute::C<(int)128>, cute::C<(int)128>, cute::C<(int)128>>, cute::tuple<cute::C<(int)64>, cute::C<(int)32>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90ColBroadcast<(int)0, cute::tuple<cute::C<(int)128>, 
cute::C<(int)128>, cute::C<(int)128>>, float, cute::tuple<cute::C<(int)1>, cute::C<(int)0>, cute::C<(int)0>>, (int)4, (bool)1>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, float, float, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90RowBroadcast<(int)2, cute::tuple<cute::C<(int)128>, cute::C<(int)128>, cute::C<(int)128>>, float, cute::tuple<cute::C<(int)0>, cute::C<(int)1>, cute::C<(int)0>>, (int)4, (bool)1>, cutlass::epilogue::fusion::Sm90AccFetch>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM75_U32x4_LDSM_N, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM90_U32x4_STSM_N>, void, void>>(T1::Params)+0x2a50
=========     by thread (320,0,0) in block (1,54,0)
=========     Address 0x7fcabf608280 is out of bounds
=========     and is 1 bytes after the nearest allocation at 0x7fcabf600000 of size 33408 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x332833]
=========                in /lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0x14cb8]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/../../../../libcudart.so.12
=========     Host Frame:cudaLaunchKernelExC [0x6c163]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/../../../../libcudart.so.12
=========     Host Frame:void (anonymous namespace)::f8f8bf16_rowwise_impl<128, 128, 128, 2, 1, 1, true, true, false, cutlass::float_e4m3_t, float>(at::Tensor, at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>, at::Tensor) [0x272926b]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:void (anonymous namespace)::dispatch_fp8_rowwise_kernel<cutlass::float_e4m3_t, true, false, float>(at::Tensor, at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>, at::Tensor) [0x2729adc]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::cuda::detail::f8f8bf16_rowwise(at::Tensor, at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>, bool, at::Tensor&) [0x26ec119]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::_scaled_mm_out_cuda(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool, at::Tensor&) [0x37819d8]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::_scaled_mm_cuda(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool) [0x3782e73]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___scaled_mm(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool) [0x3451f0e]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___scaled_mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x3598a51]
=========                in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at /opt/conda/conda-bld/pytorch_1721288503779/work/c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fcdd038c7b6 in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fcdd033a504 in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: ...

========= Error: process didn't terminate successfully
========= Target application returned an error
========= ERROR SUMMARY: 644 errors
========= ERROR SUMMARY: 643 errors were not printed. Use --print-limit option to adjust the number of printed errors

The address at which the IMA (illegal memory access) occurs belongs to the scale_a tensor: it lies just past the end of that tensor's 33408-byte allocation (8352 float32 values). Looking at all the attempted reads, the kernel seems to read up to the next 256-element boundary of that tensor. Indeed, when M is rounded up to the next multiple of 256, the issue doesn't seem to occur.
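The arithmetic behind that observation can be checked by hand. The sketch below only uses numbers copied from the sanitizer report and the repro above; the 256-element boundary is the hypothesis stated in this issue, not something confirmed from the kernel source, and `round_up` is a helper introduced here for illustration.

```python
# Values copied from the compute-sanitizer report and the repro in this issue.
FAULT_ADDRESS = 0x7FCABF608280     # "Address ... is out of bounds"
ALLOC_BASE = 0x7FCABF600000        # "nearest allocation at ..."
ALLOC_SIZE_BYTES = 33408           # "of size 33408 bytes"
M = 8352                           # rows of a, and length of scales_a
FLOAT32_BYTES = 4
BOUNDARY = 256                     # hypothesis: kernel reads up to this element boundary


def round_up(value: int, multiple: int) -> int:
    """Round value up to the next multiple (illustrative helper)."""
    return -(-value // multiple) * multiple


# The allocation size matches a float32 tensor with M elements, i.e. scales_a.
assert ALLOC_SIZE_BYTES == M * FLOAT32_BYTES

fault_offset = FAULT_ADDRESS - ALLOC_BASE        # byte offset of the bad read
fault_index = fault_offset // FLOAT32_BYTES      # element index of the bad read
padded_m = round_up(M, BOUNDARY)                 # M rounded up to 256

print(fault_offset, fault_index, padded_m, padded_m - M)
# prints: 33408 8352 8448 96
```

So the reported read targets element index 8352, the first float past the end of scales_a, and rounding M up to 256 would mean 96 extra rows for this shape.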

By default the issue appears non-deterministically in PyTorch because the caching allocator usually hands out larger allocations than the tensor needs, so out-of-bounds accesses do not necessarily trigger an IMA. Disabling the caching allocator with PYTORCH_NO_CUDA_MEMORY_CACHING=1 makes the repro more reliable.
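As a rough illustration of how allocator slack can hide the over-read, the sketch below assumes the caching allocator rounds small requests up to a multiple of 512 bytes. That granularity is an assumption made for this example (check the allocator source for the build in question), and it is only one of the ways the allocator can leave readable memory after a tensor.

```python
# ASSUMPTION: the caching allocator rounds requests up to 512-byte blocks.
# This constant is illustrative and not taken from the issue itself.
ASSUMED_BLOCK_BYTES = 512

M = 8352            # length of scales_a in the repro
FLOAT32_BYTES = 4
BOUNDARY = 256      # hypothesis from the issue: reads go up to this element boundary


def round_up(value: int, multiple: int) -> int:
    """Round value up to the next multiple (illustrative helper)."""
    return -(-value // multiple) * multiple


requested_bytes = M * FLOAT32_BYTES                            # what the tensor needs
cached_bytes = round_up(requested_bytes, ASSUMED_BLOCK_BYTES)  # what the allocator reserves
kernel_reach_bytes = round_up(M, BOUNDARY) * FLOAT32_BYTES     # how far the kernel may read

print(requested_bytes, cached_bytes, kernel_reach_bytes)
# prints: 33408 33792 33792

# Under the assumed rounding, the over-read stays inside the reserved block,
# so nothing faults; with caching disabled, the allocation seen by
# compute-sanitizer is exactly requested_bytes and the same read is flagged.
hidden_by_cache = kernel_reach_bytes <= cached_bytes
```

For this particular shape the two rounded sizes coincide, which is consistent with the over-read going unnoticed when the caching allocator is active.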

If I recompile that kernel and pass the -lineinfo flag to nvcc, the issue seems to disappear.

I found a comment in PyTorch that could hint at this bug having been hit in the past:

Note: This method is a workaround for CUDA Errors that seemingly non-deterministically
occurred in practice in some CUTLASS GEMM Kernels with Linear epilogues that have a bias term.
it might make sense to check on newer Cutlass releases whether it makes sense to keep
returning True in certain cases or whether it becomes unneccessary.

Versions

PyTorch version: 2.5.0.dev20240718
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.31

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1064-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             192
On-line CPU(s) list:                0-191
Thread(s) per core:                 2
Core(s) per socket:                 48
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              1
Model name:                         AMD EPYC 7R13 Processor
Stepping:                           1
CPU MHz:                            2650.000
BogoMIPS:                           5300.00
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          3 MiB
L1i cache:                          3 MiB
L2 cache:                           48 MiB
L3 cache:                           384 MiB
NUMA node0 CPU(s):                  0-47,96-143
NUMA node1 CPU(s):                  48-95,144-191
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] lovely-numpy==0.2.11
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.5.0.dev20240718
[pip3] torchmetrics==0.10.3
[pip3] torchvision==0.20.0.dev20240718
[pip3] triton==3.0.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @yanbing-j @vkuzo @albanD @penguinwu

Metadata

    Labels

    high priority
    module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError)
    module: floatx (formerly float8) (For torch.float8_e5m2 and torch.float8_e4m3 and other sub 8-bit float types)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
