
CTC compute-sanitizer error in ctc_loss_backward_log_beta_gpu_kernel #140777

@tjysdsg

Description

🐛 Describe the bug

compute-sanitizer reports many errors like the one below when running the CTC backward pass. The input to CTC (linked below) is not strictly log probabilities, and the target lengths are generally very short, with many zero-length elements (a minimal synthetic sketch of such inputs follows the log).

========= Invalid __global__ read of size 8 bytes
=========     at void at::native::<unnamed>::ctc_loss_backward_log_beta_gpu_kernel<float, long>(T1 *, const T1 *, const long *, long, const T2 *, const long *, long, long, long, long, long, long, long, const long *, long, long, long)+0x70
=========     by thread (0,60,0) in block (0,8,0)
=========     Address 0x7f0605dfbde0 is out of bounds
=========     and is 32 bytes before the nearest allocation at 0x7f0605dfbe00 of size 8,600 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x334640]
=========                in /lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0x1498c]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
=========     Host Frame:cudaLaunchKernel [0x6bedb]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
=========     Host Frame:at::Tensor at::native::(anonymous namespace)::ctc_loss_backward_gpu_template<float, (c10::ScalarType)4>(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, long, bool) [0x228b02a]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::native::ctc_loss_backward_gpu(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, long, bool) [0x2285835]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___ctc_loss_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, long, bool) [0x33bec61]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, long, bool), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___ctc_loss_backward>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, long, bool> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x357c913]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:c10::OperatorHandle::redispatchBoxed(c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const [0x52d0c1b]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::autogradNotImplementedFallbackImpl(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x52cfc0f]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::_ops::_ctc_loss_backward::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, long, bool) [0x28d56bb]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::generated::CtcLossBackward0::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) [0x45aaa22]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::Node::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) [0x52e992b]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) [0x52e39e6]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) [0x52e4658]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) [0x52db5df]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) [0x870cac]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/lib/python3.11/site-packages/torch/lib/libtorch_python.so
=========     Host Frame:execute_native_thread_routine in /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/src/c++11/thread.cc:84 [0xdbbf4]
=========                in /share5/users/mark.tang/miniconda3/envs/uat/bin/../lib/libstdc++.so.6
=========     Host Frame:start_thread [0x8609]
=========                in /lib/x86_64-linux-gnu/libpthread.so.0
=========     Host Frame:clone [0x11f133]
=========                in /lib/x86_64-linux-gnu/libc.so.6
========= 
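
A minimal synthetic sketch of inputs with the characteristics described above (values perturbed so they are not strictly log probabilities, mostly zero-length targets). The shapes and values here are made up for illustration and are not guaranteed to trigger the sanitizer error; the attached ctc-sanitizer.pt below is the actual reproducer.

import torch
import torch.nn as nn

T, N, C = 50, 8, 32  # input length, batch size, number of classes (arbitrary)

# log_softmax output shifted by a constant, so frames are not strictly log probabilities
logp = torch.randn(T, N, C, device="cuda:0").log_softmax(-1) + 0.1
logp = logp.detach().requires_grad_(True)

# Target lengths that are very short, with many zero-length elements
target_lens = torch.tensor([0, 1, 0, 2, 0, 1, 0, 3], device="cuda:0")
targets = torch.randint(1, C, (int(target_lens.sum()),), device="cuda:0")
input_lens = torch.full((N,), T, dtype=torch.long, device="cuda:0")

ctc = nn.CTCLoss(blank=0, reduction='none', zero_infinity=True)
ctc(logp, targets, input_lens, target_lens).sum().backward()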

To reproduce (PYTORCH_NO_CUDA_MEMORY_CACHING=1 disables PyTorch's CUDA caching allocator so compute-sanitizer checks accesses against real allocation boundaries):

export PYTORCH_NO_CUDA_MEMORY_CACHING=1
compute-sanitizer --tool memcheck --require-cuda-init=no --launch-timeout=0 python reproduce.py

Download https://github.com/tjysdsg/tjysdsg/blob/main/ctc-sanitizer.pt and run reproduce.py:

import torch
import torch.nn as nn


ctc = nn.CTCLoss(blank=0, reduction='none', zero_infinity=True)

logp, targets, input_lens, target_lens = torch.load("ctc-sanitizer.pt")
logp.requires_grad_(True)

logp, targets, input_lens, target_lens = (
    logp.to(device="cuda:0"), targets.to(device="cuda:0"),
    input_lens.to(device="cuda:0"), target_lens.to(device="cuda:0"),
)
loss = ctc(logp, targets, input_lens, target_lens).sum()
loss.backward()
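
As a quick sanity check on the claims above, here is a short sketch (assuming ctc-sanitizer.pt unpacks into the same four tensors as in reproduce.py) that reports how far logp is from normalized log probabilities and how many target lengths are zero:

import torch

logp, targets, input_lens, target_lens = torch.load("ctc-sanitizer.pt")

# 0.0 would mean each frame is a strictly normalized log-probability distribution
print("max |logsumexp| over classes:", logp.logsumexp(dim=-1).abs().max().item())

# Many entries are expected to be zero
print("zero-length targets:", (target_lens == 0).sum().item(), "of", target_lens.numel())
print("target_lens min/max:", target_lens.min().item(), target_lens.max().item())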

Versions

I tried combinations of two machines and two PyTorch package versions:

  1. Driver Version: 545.23.06, CUDA Version: 12.3, NVIDIA A100
  2. Driver Version: 535.129.03, CUDA Version: 12.2, NVIDIA H100

a. torch==2.4.1 -i https://download.pytorch.org/whl/cu121
b. torch==2.4.1 -i https://download.pytorch.org/whl/cu124
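
For completeness, the version details above can be printed from Python (a sketch; python -m torch.utils.collect_env gives a fuller report including the driver version):

import torch

print("torch:", torch.__version__)
print("CUDA runtime (build):", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0))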

Sanitizer log:

sanitizer.txt

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7 @xmfan

Labels

module: autograd (Related to torch.autograd, and the autograd engine in general)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
