
Importing torchvision.models.detection.faster_rcnn on CPU results in a GIL deadlock when a torch tensor created from a numpy.array is deallocated #83101

@Chekov2k

Description


🐛 Describe the bug

We are using pybind11 and PyTorch to build a C++ binary with an embedded Python interpreter. When we use PyTorch with a CPU-only device in a Docker container and also import torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn, we can trigger a deadlock during the resource cleanup of a torch tensor. Creating a tensor from a numpy.array with torch.as_tensor(array).float().to(device) seems to trigger the deadlock when the tensor created from the array is deallocated.

Is that expected behaviour? Are we missing an initialisation step or something similar?

To reproduce the problem, build and run the code below as follows:

mkdir build
cd build
ln -s ../demo.py .
cmake ../
make
./demo

demo.py (test script renamed to avoid a clash with a built-in module)

# -*- coding: utf-8 -*-

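import torch
# Unused below, but importing it is what triggers the deadlock
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
import numpy as np

def test():
    device = torch.device("cpu")
    array = np.random.randint(50, 75, (2, 2), dtype='int64')
    for _ in range(5):
        # torch.as_tensor shares memory with the NumPy array; .float() makes a float32 copy
        print(torch.as_tensor(array).float().to(device))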

main.cpp

#include <pybind11/embed.h>
#include <pybind11/pybind11.h>

#include <iostream>
#include <thread>

void Do() {
  std::cout << "Do start" << std::endl;
  // Acquire the GIL for this worker thread for the duration of the call
  pybind11::gil_scoped_acquire acquire;
  auto testModule = pybind11::module::import("demo");
  auto testMethod = testModule.attr("test");
  testMethod();
  std::cout << "Do stop" << std::endl;
}

void run() {
  // Call into Python five times from the same worker thread
  for (std::size_t index = 0, end = 5; index < end; ++index) Do();
}

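int main(int pArgc, const char *const *pArgv) {
  // Start the embedded interpreter without registering Python's signal handlers
  pybind11::initialize_interpreter(false);
  // https://github.com/pybind/pybind11/issues/2197
  pybind11::module::import("threading");

  auto thread = std::thread(&run);

  {
    // Release the GIL held by the main thread so the worker thread can acquire it
    pybind11::gil_scoped_release release;
    thread.join();
    std::cout << "joined" << std::endl;
  }
}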

CMakeLists.txt

cmake_minimum_required(VERSION 3.11...3.18)
project(demo LANGUAGES CXX)

# threads
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)

find_package(pybind11 REQUIRED)
include_directories(SYSTEM ${pybind11_INCLUDE_DIRS})

add_executable(demo main.cpp)
target_link_libraries(demo PUBLIC Threads::Threads)
target_link_libraries(demo PRIVATE pybind11::embed)

The output we would expect is this:

0b312b75efe7# ./demo
Do start
tensor([[62., 72.],
        [58., 73.]])
tensor([[62., 72.],
        [58., 73.]])
tensor([[62., 72.],
        [58., 73.]])
tensor([[62., 72.],
        [58., 73.]])
tensor([[62., 72.],
        [58., 73.]])
Do stop
Do start
tensor([[60., 65.],
        [71., 72.]])
tensor([[60., 65.],
        [71., 72.]])
tensor([[60., 65.],
        [71., 72.]])
tensor([[60., 65.],
        [71., 72.]])
tensor([[60., 65.],
        [71., 72.]])
Do stop
Do start
tensor([[65., 51.],
        [58., 57.]])
tensor([[65., 51.],
        [58., 57.]])
tensor([[65., 51.],
        [58., 57.]])
tensor([[65., 51.],
        [58., 57.]])
tensor([[65., 51.],
        [58., 57.]])
Do stop
Do start
tensor([[63., 74.],
        [52., 57.]])
tensor([[63., 74.],
        [52., 57.]])
tensor([[63., 74.],
        [52., 57.]])
tensor([[63., 74.],
        [52., 57.]])
tensor([[63., 74.],
        [52., 57.]])
Do stop
Do start
tensor([[56., 64.],
        [54., 61.]])
tensor([[56., 64.],
        [54., 61.]])
tensor([[56., 64.],
        [54., 61.]])
tensor([[56., 64.],
        [54., 61.]])
tensor([[56., 64.],
        [54., 61.]])
Do stop
joined

However, it hangs at the start of the second Do() call:

0b312b75efe7# ./demo
Do start
/usr/local/lib/python3.8/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension:
  warn(f"Failed to load image Python extension: {e}")
tensor([[58., 55.],
        [50., 68.]])
tensor([[58., 55.],
        [50., 68.]])
tensor([[58., 55.],
        [50., 68.]])
tensor([[58., 55.],
        [50., 68.]])
tensor([[58., 55.],
        [50., 68.]])
Do stop
Do start

The gdb backtrace shows the worker thread (Thread 2) deadlocked while acquiring the GIL during the resource release of the torch tensor:

(gdb) thread apply all bt

Thread 5 (Thread 0xffff943151b0 (LWP 352)):
#0  0x0000ffffa4e8541c in pthread_cond_wait@@GLIBC_2.17 () from /lib/aarch64-linux-gnu/libpthread.so.0
#1  0x0000ffff9770a488 in blas_thread_server () from /usr/local/lib/python3.8/dist-packages/numpy/core/../../numpy.libs/libopenblas64_p-r0-9c1f2efe.3.20.so
#2  0x0000ffffa4e7e624 in start_thread () from /lib/aarch64-linux-gnu/libpthread.so.0
#3  0x0000ffffa4dd549c in ?? () from /lib/aarch64-linux-gnu/libc.so.6

Thread 4 (Thread 0xffff96b161b0 (LWP 351)):
#0  0x0000ffffa4e8541c in pthread_cond_wait@@GLIBC_2.17 () from /lib/aarch64-linux-gnu/libpthread.so.0
#1  0x0000ffff9770a488 in blas_thread_server () from /usr/local/lib/python3.8/dist-packages/numpy/core/../../numpy.libs/libopenblas64_p-r0-9c1f2efe.3.20.so
#2  0x0000ffffa4e7e624 in start_thread () from /lib/aarch64-linux-gnu/libpthread.so.0
#3  0x0000ffffa4dd549c in ?? () from /lib/aarch64-linux-gnu/libc.so.6

Thread 3 (Thread 0xffff973171b0 (LWP 350)):
#0  0x0000ffffa4e8541c in pthread_cond_wait@@GLIBC_2.17 () from /lib/aarch64-linux-gnu/libpthread.so.0
#1  0x0000ffff9770a488 in blas_thread_server () from /usr/local/lib/python3.8/dist-packages/numpy/core/../../numpy.libs/libopenblas64_p-r0-9c1f2efe.3.20.so
#2  0x0000ffffa4e7e624 in start_thread () from /lib/aarch64-linux-gnu/libpthread.so.0
#3  0x0000ffffa4dd549c in ?? () from /lib/aarch64-linux-gnu/libc.so.6

Thread 2 (Thread 0xffffa4a5a1b0 (LWP 348)):
#0  0x0000ffffa4e85788 in pthread_cond_timedwait@@GLIBC_2.17 () from /lib/aarch64-linux-gnu/libpthread.so.0
#1  0x0000ffffa535a030 in PyCOND_TIMEDWAIT (us=<optimized out>, mut=0xffffa56d1810 <_PyRuntime+1232>, cond=0xffffa56d17e0 <_PyRuntime+1184>) at ../Python/condvar.h:73
#2  take_gil (ceval=0xffffa56d1588 <_PyRuntime+584>, tstate=0xffffa0000f70) at ../Python/ceval_gil.h:206
#3  0x0000ffffa53598b8 in PyEval_AcquireThread (tstate=0xffffa0000f70) at ../Python/ceval.c:316
#4  0x0000ffff9f3b410c in pybind11::gil_scoped_acquire::gil_scoped_acquire() () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#5  0x0000ffff9fa907fc in std::_Function_handler<void (void*), torch::utils::tensor_from_numpy(_object*, bool)::{lambda(void*)#1}>::_M_invoke(std::_Any_data const&, void*&&) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#6  0x0000ffff9922afa4 in c10::deleteInefficientStdFunctionContext(void*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libc10.so
#7  0x0000ffff9f503a34 in c10::StorageImpl::release_resources() () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#8  0x0000ffff9923629c in c10::TensorImpl::release_resources() () from /usr/local/lib/python3.8/dist-packages/torch/lib/libc10.so
#9  0x0000ffff9f3b09a4 in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#10 0x0000ffff9f75a6ac in THPVariable_clear(THPVariable*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#11 0x0000ffff9f75a90c in THPVariable_subclass_dealloc(_object*) () from /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so
#12 0x0000ffffa520e6fc in _Py_DECREF (filename=<synthetic pointer>, lineno=4971, op=<optimized out>) at ../Include/object.h:478
#13 call_function (tstate=0xffffa29c8890, pp_stack=0xffffa4a596e0, oparg=<optimized out>, kwnames=0x0) at ../Python/ceval.c:4971
#14 0x0000ffffa520f630 in _PyEval_EvalFrameDefault (f=0xffffa4b72040, throwflag=<optimized out>) at ../Python/ceval.c:3486
#15 0x0000ffffa5218724 in function_code_fastcall (co=<optimized out>, args=<optimized out>, nargs=0, globals=<optimized out>) at ../Objects/call.c:284
#16 0x0000ffffa5431fe0 in PyVectorcall_Call (callable=0xffffa4a68550, tuple=<optimized out>, kwargs=<optimized out>) at ../Objects/call.c:200
#17 0x000000000041aea8 in pybind11::detail::simple_collector<(pybind11::return_value_policy)1>::call(_object*) const ()
#18 0x00000000004171b4 in pybind11::object pybind11::detail::object_api<pybind11::detail::accessor<pybind11::detail::accessor_policies::str_attr> >::operator()<(pybind11::return_value_policy)1>() const ()
#19 0x0000000000406768 in Do() ()
#20 0x0000000000406810 in run() ()
#21 0x000000000042a6ac in void std::__invoke_impl<void, void (*)()>(std::__invoke_other, void (*&&)()) ()
#22 0x000000000042a660 in std::__invoke_result<void (*)()>::type std::__invoke<void (*)()>(void (*&&)()) ()
#23 0x000000000042a5fc in void std::thread::_Invoker<std::tuple<void (*)()> >::_M_invoke<0ul>(std::_Index_tuple<0ul>) ()
#24 0x000000000042a5a8 in std::thread::_Invoker<std::tuple<void (*)()> >::operator()() ()
#25 0x000000000042a53c in std::thread::_State_impl<std::thread::_Invoker<std::tuple<void (*)()> > >::_M_run() ()
#26 0x0000ffffa504ebac in ?? () from /usr/local/lib/libstdc++.so.6
#27 0x0000ffffa4e7e624 in start_thread () from /lib/aarch64-linux-gnu/libpthread.so.0
#28 0x0000ffffa4dd549c in ?? () from /lib/aarch64-linux-gnu/libc.so.6

Thread 1 (Thread 0xffffa5716010 (LWP 345)):
#0  0x0000ffffa4e7f944 in __pthread_clockjoin_ex () from /lib/aarch64-linux-gnu/libpthread.so.0
#1  0x0000ffffa504ee40 in std::thread::join() () from /usr/local/lib/libstdc++.so.6
#2  0x00000000004068ac in main ()
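To make the backtrace easier to read, here is a minimal pure-Python model of what Thread 2 appears to be doing. It is an analogy, not PyTorch or pybind11 code: a plain `threading.Lock` stands in for the GIL, and `NumpyBackedStorage` and `run_bytecode_then_free` are hypothetical names. Frames #12 to #14 execute Python bytecode (frame #13 under tstate=0xffffa29c8890), so the thread already holds the GIL when the DECREF frees the tensor. The deleter registered by `tensor_from_numpy` (frame #5) then tries to take the GIL again via `PyEval_AcquireThread` with a different thread state (tstate=0xffffa0000f70, frames #2 to #4). A non-reentrant lock cannot be re-acquired by its holder, so the wait never ends; the model adds a timeout so it returns instead of hanging.

```python
import threading

# Toy model only: a plain, non-reentrant lock stands in for the GIL.
gil = threading.Lock()


class NumpyBackedStorage:
    """Hypothetical stand-in for tensor storage that borrows a NumPy array's memory."""

    def __init__(self, array_ref):
        self.array_ref = array_ref

    def release(self, timeout):
        """Mimic the deleter: take the GIL, then drop the reference to the array.

        Returns False if the GIL could not be taken within `timeout` seconds;
        in the real process there is no timeout, so the thread hangs instead.
        """
        if not gil.acquire(timeout=timeout):
            return False
        try:
            self.array_ref = None  # the real deleter decrefs the NumPy array here
            return True
        finally:
            gil.release()


def run_bytecode_then_free(storage, timeout=0.2):
    """Mimic frames #12 to #14: bytecode runs with the GIL held, then a DECREF frees the tensor."""
    with gil:
        return storage.release(timeout)


if __name__ == "__main__":
    # Freed while the same thread already holds the "GIL": the deleter cannot proceed.
    print(run_bytecode_then_free(NumpyBackedStorage(object())))
    # Freed by a thread that does not hold the "GIL": the deleter succeeds.
    print(NumpyBackedStorage(object()).release(timeout=0.2))
```

In the model, the second call succeeds only because nothing holds the lock. An acquire path that first checks whether the calling thread already owns the GIL (as `PyGILState_Ensure()` does) would not block in the first case either; why the deleter ends up with a thread state different from the running frame is not established by this trace.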

In case it is related to running in a Docker container: building a container from the Dockerfile below and running the code in it also shows the deadlock.

FROM ubuntu:22.04

RUN apt update && \
    apt install -y \
        python3-pip \
        build-essential \
        cmake \
        pkg-config

RUN pip3 install \
        pybind11 \
        numpy \
        torch \
        torchvision

COPY CMakeLists.txt demo.py main.cpp /test/

WORKDIR /build
RUN ln -s /test/demo.py .

RUN cmake /test/ -Dpybind11_DIR=/usr/local/lib/python3.10/dist-packages/pybind11/share/cmake/pybind11

RUN make

ENTRYPOINT ["/build/demo"]

Versions

0b312b75efe7# python3 collect_env.py
Collecting environment information...
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (aarch64)
GCC version: (GCC) 12.1.0
Clang version: Could not collect
CMake version: version 3.23.2
Libc version: glibc-2.31

Python version: 3.8.10 (default, Jun 22 2022, 20:18:18) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.104-linuxkit-aarch64-with-glibc2.29
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.971
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.1
[pip3] torch==1.12.1
[pip3] torchvision==0.13.1
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @jbschlosser @mruberry @rgommers

Metadata


Labels: module: cpp, module: deadlock, module: numpy, triaged
