Fatal Python error: Python memory allocator called without holding the GIL (with debug build of python) #1624

Description

@gchanan

I built a debug version of Python 3.6.1 (via "./configure --with-pydebug") and ran TestAutograd, and the following error came up:

[~/local/pytorch6] python3 test/test_autograd.py TestAutograd.test_grad
Fatal Python error: Python memory allocator called without holding the GIL

Thread 0x00007fc8f4f8f700 (most recent call first):

Thread 0x00007fc8f5790700 (most recent call first):

Current thread 0x00007fc8f5f91700 (most recent call first):

Thread 0x00007fc92d4ae740 (most recent call first):
  File "/data/users/gchanan/pytorch6/torch/autograd/__init__.py", line 153 in grad
  File "test/test_autograd.py", line 184 in test_grad
  File "/data/users/gchanan/python3/lib/python3.6/unittest/case.py", line 601 in run
  File "/data/users/gchanan/python3/lib/python3.6/unittest/case.py", line 649 in __call__
  File "/data/users/gchanan/python3/lib/python3.6/unittest/suite.py", line 122 in run
  File "/data/users/gchanan/python3/lib/python3.6/unittest/suite.py", line 84 in __call__
  File "/data/users/gchanan/python3/lib/python3.6/unittest/suite.py", line 122 in run
  File "/data/users/gchanan/python3/lib/python3.6/unittest/suite.py", line 84 in __call__
  File "/data/users/gchanan/python3/lib/python3.6/unittest/runner.py", line 176 in run
  File "/data/users/gchanan/python3/lib/python3.6/unittest/main.py", line 255 in runTests
  File "/data/users/gchanan/python3/lib/python3.6/unittest/main.py", line 94 in __init__
  File "/data/users/gchanan/pytorch6/test/common.py", line 30 in run_tests
  File "test/test_autograd.py", line 1717 in <module>
Aborted (core dumped)

Running this through gdb, I get the following backtrace:

(gdb) bt
#0  0x00007ffff712c1d7 in raise () at /lib64/libc.so.6
#1  0x00007ffff712d8c8 in abort () at /lib64/libc.so.6
#2  0x0000000000420e2f in Py_FatalError (msg=msg@entry=0x5c0b70 "Python memory allocator called without holding the GIL") at Python/pylifecycle.c:1457
#3  0x000000000041d5b1 in _PyMem_DebugCheckGIL () at Objects/obmalloc.c:1972
#4  0x000000000041daa2 in _PyMem_DebugFree (ctx=0x880fd0 <_PyMem_Debug+48>, ptr=0x936420) at Objects/obmalloc.c:1994
#5  0x000000000041e78d in PyMem_Free (ptr=<optimized out>) at Objects/obmalloc.c:442
#6  0x000000000043a9b4 in _PyFaulthandler_Fini () at ./Modules/faulthandler.c:1369
#7  0x0000000000420e19 in Py_FatalError (msg=msg@entry=0x5c0b70 "Python memory allocator called without holding the GIL") at Python/pylifecycle.c:1431
#8  0x000000000041d5b1 in _PyMem_DebugCheckGIL () at Objects/obmalloc.c:1972
#9  0x000000000041d5ec in _PyMem_DebugMalloc (ctx=0x881000 <_PyMem_Debug+96>, nbytes=104) at Objects/obmalloc.c:1980
#10 0x000000000041e83b in PyObject_Malloc (size=size@entry=104) at Objects/obmalloc.c:479
#11 0x0000000000436e4d in _PyObject_GC_Alloc (use_calloc=use_calloc@entry=0, basicsize=basicsize@entry=80) at Modules/gcmodule.c:1714
#12 0x0000000000437382 in _PyObject_GC_Malloc (basicsize=basicsize@entry=80) at Modules/gcmodule.c:1736
#13 0x00000000004a8443 in PyType_GenericAlloc (type=0x1145908, nitems=0) at Objects/typeobject.c:936
#14 0x00007fffef4c7e92 in THPVariable_NewWithVar(PyTypeObject*, std::shared_ptr<torch::autograd::Variable>) (type=<optimized out>, var=...) at torch/csrc/autograd/python_variable.cpp:23
#15 0x00007fffef4c8f7b in THPVariable_Wrap(std::shared_ptr<torch::autograd::Variable> const&) (var=...) at torch/csrc/autograd/python_variable.cpp:38
#16 0x00007fffef4cab2a in __lambda1::operator() (grads=..., _unused=<optimized out>, __closure=0x1312900) at torch/csrc/autograd/python_engine.cpp:187
#17 0x00007fffef4cab2a in std::_Function_handler<bool(torch::autograd::Function*, std::vector<std::shared_ptr<torch::autograd::Variable>, std::allocator<std::shared_ptr<torch::autograd::Variable> > >&), THPEngine_run_backward(THPEngine*, PyObject*, PyObject*)::__lambda1>::_M_invoke(const std::_Any_data &, torch::autograd::Function *, std::vector<std::shared_ptr<torch::autograd::Variable>, std::allocator<std::shared_ptr<torch::autograd::Variable> > > &) (__functor=..., __args#0=<optimized out>, __args#1=...) at /usr/include/c++/4.8.2/functional:2057
#18 0x00007fffef4b5706 in std::function<bool (torch::autograd::Function*, std::vector<std::shared_ptr<torch::autograd::Variable>, std::allocator<std::shared_ptr<torch::autograd::Variable> > >&)>::operator()(torch::autograd::Function*, std::vector<std::shared_ptr<torch::autograd::Variable>, std::allocator<std::shared_ptr<torch::autograd::Variable> > >&) const (this=<optimized out>, __args#0=__args#0@entry=0x12fe868, __args#1=...) at /usr/include/c++/4.8.2/functional:2471
#19 0x00007fffef4b1f83 in torch::autograd::call_function(torch::autograd::FunctionTask&) (task=...) at torch/csrc/autograd/engine.cpp:152
#20 0x00007fffef4b3345 in torch::autograd::Engine::evaluate_function(torch::autograd::FunctionTask&) (this=this@entry=0x7ffff01671a0 <engine>, task=...) at torch/csrc/autograd/engine.cpp:160
#21 0x00007fffef4b41e3 in torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::ReadyQueue>) (this=this@entry=0x7ffff01671a0 <engine>, queue=...) at torch/csrc/autograd/engine.cpp:110
#22 0x00007fffef4cd7f7 in PythonEngine::thread_main(std::shared_ptr<torch::autograd::ReadyQueue>) (this=0x7ffff01671a0 <engine>, queue=...) at torch/csrc/autograd/python_engine.cpp:23
#23 0x00007fffef4b57ca in std::_Mem_fn<void (torch::autograd::Engine::*)(std::shared_ptr<torch::autograd::ReadyQueue>)>::operator()<std::shared_ptr<torch::autograd::ReadyQueue>, void> (__object=<optimized out>, this=<optimized out>) at /usr/include/c++/4.8.2/functional:601
#24 0x00007fffef4b57ca in std::_Bind_simple<std::_Mem_fn<void (torch::autograd::Engine::*)(std::shared_ptr<torch::autograd::ReadyQueue>)>(torch::autograd::Engine*, std::shared_ptr<torch::autograd::ReadyQueue>)>::_M_invoke<0ul, 1ul> (this=<optimized out>) at /usr/include/c++/4.8.2/functional:1732
#25 0x00007fffef4b57ca in std::_Bind_simple<std::_Mem_fn<void (torch::autograd::Engine::*)(std::shared_ptr<torch::autograd::ReadyQueue>)> (torch::autograd::Engine*, std::shared_ptr<torch::autograd::ReadyQueue>)>::operator()() (this=<optimized out>) at /usr/include/c++/4.8.2/functional:1720
#26 0x00007fffef4b57ca in std::thread::_Impl<std::_Bind_simple<std::_Mem_fn<void (torch::autograd::Engine::*)(std::shared_ptr<torch::autograd::ReadyQueue>)> (torch::autograd::Engine*, std::shared_ptr<torch::autograd::ReadyQueue>)> >::_M_run() (this=<optimized out>) at /usr/include/c++/4.8.2/thread:115
#27 0x00007fffd3a58230 in  () at /lib64/libstdc++.so.6
#28 0x00007ffff7bc8dc5 in start_thread () at /lib64/libpthread.so.0
#29 0x00007ffff71ee73d in clone () at /lib64/libc.so.6

The following commit seems to fix the issue: gchanan@eef93b7, although I don't know enough about the autograd engine at this point to say whether that's the correct fix.
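
For context, the backtrace shows the crash happening when THPVariable_Wrap allocates a new Python object (via PyType_GenericAlloc) from an autograd engine worker thread, i.e. a thread the interpreter did not create and that does not hold the GIL. I can't say whether the commit above does exactly this, but the usual pattern is to acquire the GIL around any Python C API call made from such a thread, e.g. with PyGILState_Ensure / PyGILState_Release. A minimal sketch with illustrative names (GILGuard and worker_callback_example are hypothetical, not from the PyTorch source):

#include <Python.h>

// RAII helper: hold the GIL for the duration of a scope in a thread that
// Python did not create, so Python C API calls (allocation, object creation)
// are legal there.
struct GILGuard {
  PyGILState_STATE state;
  GILGuard() : state(PyGILState_Ensure()) {}
  ~GILGuard() { PyGILState_Release(state); }
};

void worker_callback_example() {
  GILGuard gil;                          // take the GIL before touching the Python allocator
  PyObject* obj = PyLong_FromLong(42);   // stand-in for THPVariable_Wrap(...)
  Py_XDECREF(obj);
}                                        // GIL released when gil goes out of scope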

Even with the above change, I can't get through the test suite with a debug build of Python 3.6.1; I'll file more issues as I find them (perhaps we should have a Jenkins job running this?).
