Description
🐛 Describe the bug
(conda_env) lcw@cr1-p548xlarge-19:~$ PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer --print-limit=1 --num-callers-host=10 ipython
========= COMPUTE-SANITIZER
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.22.1 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import torch
In [2]: a = torch.randn((8352, 4096), device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
In [3]: b = torch.randn((1536, 4096), device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
In [4]: scales_a = torch.randn(8352, device="cuda", dtype=torch.float)
In [5]: scales_b = torch.randn(1536, device="cuda", dtype=torch.float)
In [6]: torch._scaled_mm(a, b.t(), scale_a=scales_a[:,None], scale_b=scales_b[None,:], out_dtype=torch.bfloat16, use_fast_accum=True)
Out[6]: ========= Invalid __global__ read of size 4 bytes
========= at void cutlass::device_kernel<cutlass::gemm::kernel::GemmUniversal<cute::tuple<int, int, int>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<(int)6, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum>, cute::tuple<cute::C<(int)128>, cute::C<(int)128>, cute::C<(int)128>>, cutlass::float_e4m3_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::float_e4m3_t, cute::tuple<long, cute::C<(int)1>, long>, cute::TiledMMA<cute::MMA_Atom<cute::SM90_64x128x32_F32E4M3E4M3_SS_TN<(cute::GMMA::ScaleIn)1, (cute::GMMA::ScaleIn)1>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)8>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)128>>, cute::tuple<cute::C<(int)128>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)8>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)128>>, cute::tuple<cute::C<(int)128>, cute::C<(int)1>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<(int)8, (int)2, (int)16, (bool)1>, cute::tuple<cute::C<(int)128>, cute::C<(int)128>, cute::C<(int)128>>, cute::tuple<cute::C<(int)64>, cute::C<(int)32>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, long>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90ColBroadcast<(int)0, cute::tuple<cute::C<(int)128>, 
cute::C<(int)128>, cute::C<(int)128>>, float, cute::tuple<cute::C<(int)1>, cute::C<(int)0>, cute::C<(int)0>>, (int)4, (bool)1>, cutlass::epilogue::fusion::Sm90TreeVisitor<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, float, float, (cutlass::FloatRoundStyle)2, void>, cutlass::epilogue::fusion::Sm90RowBroadcast<(int)2, cute::tuple<cute::C<(int)128>, cute::C<(int)128>, cute::C<(int)128>>, float, cute::tuple<cute::C<(int)0>, cute::C<(int)1>, cute::C<(int)0>>, (int)4, (bool)1>, cutlass::epilogue::fusion::Sm90AccFetch>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM75_U32x4_LDSM_N, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)32>>, cute::tuple<cute::C<(int)32>, cute::C<(int)1>>>>, cute::SM90_U32x4_STSM_N>, void, void>>(T1::Params)+0x2a50
========= by thread (320,0,0) in block (1,54,0)
========= Address 0x7fcabf608280 is out of bounds
========= and is 1 bytes after the nearest allocation at 0x7fcabf600000 of size 33408 bytes
========= Saved host backtrace up to driver entry point at kernel launch time
========= Host Frame: [0x332833]
========= in /lib/x86_64-linux-gnu/libcuda.so.1
========= Host Frame: [0x14cb8]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/../../../../libcudart.so.12
========= Host Frame:cudaLaunchKernelExC [0x6c163]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/../../../../libcudart.so.12
========= Host Frame:void (anonymous namespace)::f8f8bf16_rowwise_impl<128, 128, 128, 2, 1, 1, true, true, false, cutlass::float_e4m3_t, float>(at::Tensor, at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>, at::Tensor) [0x272926b]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
========= Host Frame:void (anonymous namespace)::dispatch_fp8_rowwise_kernel<cutlass::float_e4m3_t, true, false, float>(at::Tensor, at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>, at::Tensor) [0x2729adc]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
========= Host Frame:at::cuda::detail::f8f8bf16_rowwise(at::Tensor, at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>, bool, at::Tensor&) [0x26ec119]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
========= Host Frame:at::native::_scaled_mm_out_cuda(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool, at::Tensor&) [0x37819d8]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
========= Host Frame:at::native::_scaled_mm_cuda(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool) [0x3782e73]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
========= Host Frame:at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___scaled_mm(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool) [0x3451f0e]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
========= Host Frame:c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___scaled_mm>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<c10::ScalarType>, bool> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x3598a51]
========= in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
=========
terminate called after throwing an instance of 'c10::Error'
what(): CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /opt/conda/conda-bld/pytorch_1721288503779/work/c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fcdd038c7b6 in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fcdd033a504 in /home/lcw/conda_env/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: ...
========= Error: process didn't terminate successfully
========= Target application returned an error
========= ERROR SUMMARY: 644 errors
========= ERROR SUMMARY: 643 errors were not printed. Use --print-limit option to adjust the number of printed errors
The address at which the illegal memory access (IMA) occurs belongs to the scale_a tensor: the nearest allocation is 33408 bytes, which is exactly 8352 float32 values. Looking at all the attempted reads, the kernel appears to read up to the next 256-element boundary of that tensor. Indeed, when M is rounded up to the next multiple of 256, the issue doesn't seem to occur.
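For reference, the arithmetic behind this observation and the padding workaround can be sketched as below. The constants are copied from the sanitizer report; `round_up` and `scaled_mm_padded` are hypothetical helper names, and the padding function is an untested sketch based only on the observation that a multiple-of-256 M avoids the fault. It assumes `a` is still available in bfloat16 so it can be padded before the cast to FP8, and it is not a verified fix.

```python
# Values copied from the compute-sanitizer report.
FAULT_ADDR = 0x7FCABF608280
ALLOC_BASE = 0x7FCABF600000
ALLOC_SIZE = 33408   # bytes, size of the nearest allocation
M = 8352             # rows of `a`, which is also the length of scales_a
FLOAT32_BYTES = 4


def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple


# The allocation holds exactly M float32 values, consistent with scales_a.
assert ALLOC_SIZE == M * FLOAT32_BYTES
# The first reported fault is the first byte past the end of that allocation.
assert FAULT_ADDR - ALLOC_BASE == ALLOC_SIZE

# Next 256-element boundary after M.
M_PADDED = round_up(M, 256)


def scaled_mm_padded(a_bf16, b_fp8, scales_a, scales_b, multiple=256):
    """Hypothetical workaround: pad M to a multiple of 256 before the call.

    Untested sketch; needs a CUDA build of PyTorch with FP8 support.
    """
    # Imported lazily so the arithmetic above runs without PyTorch installed.
    import torch
    import torch.nn.functional as F

    m = a_bf16.shape[0]
    pad = round_up(m, multiple) - m
    # Pad rows of `a` with zeros while still in bfloat16, then cast to FP8.
    a_fp8 = F.pad(a_bf16, (0, 0, 0, pad)).to(torch.float8_e4m3fn)
    # Pad the per-row scales with ones so the buffer covers the padded rows.
    scales_a = F.pad(scales_a, (0, pad), value=1.0)
    out = torch._scaled_mm(
        a_fp8,
        b_fp8.t(),
        scale_a=scales_a[:, None],
        scale_b=scales_b[None, :],
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )
    # Drop the padded rows so the result has the original shape.
    return out[:m]
```

With the shapes from the repro, this would pad M from 8352 to 8448 rows and slice the output back to 8352 rows.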
By default the issue appears non-deterministically in PyTorch because the caching allocator usually provides larger allocations than the tensor needs, so out-of-bounds accesses do not always trigger an IMA. Disabling the caching allocator with PYTORCH_NO_CUDA_MEMORY_CACHING=1 makes the repro more reliable.
If I recompile that kernel and pass the -lineinfo flag to nvcc the issue seems to disappear.
I found a reference in PyTorch that might hint at this bug having been hit in the past:
pytorch/torch/_inductor/codegen/cuda/gemm_template.py
Lines 548 to 551 in d2e9a8b
Note: This method is a workaround for CUDA Errors that seemingly non-deterministically
occurred in practice in some CUTLASS GEMM Kernels with Linear epilogues that have a bias term.
it might make sense to check on newer Cutlass releases whether it makes sense to keep
returning True in certain cases or whether it becomes unneccessary.
Versions
PyTorch version: 2.5.0.dev20240718
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.31
Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1064-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 192
On-line CPU(s) list: 0-191
Thread(s) per core: 2
Core(s) per socket: 48
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7R13 Processor
Stepping: 1
CPU MHz: 2650.000
BogoMIPS: 5300.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3 MiB
L1i cache: 3 MiB
L2 cache: 48 MiB
L3 cache: 384 MiB
NUMA node0 CPU(s): 0-47,96-143
NUMA node1 CPU(s): 48-95,144-191
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] lovely-numpy==0.2.11
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.5.0.dev20240718
[pip3] torchmetrics==0.10.3
[pip3] torchvision==0.20.0.dev20240718
[pip3] triton==3.0.0
[conda] Could not collect
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @yanbing-j @vkuzo @albanD @penguinwu