Issue description
I found several errors in autograd:
- Segfault during `.backward()`/`.grad()` when using a hook (e.g. through `.retain_grad()`) on a non-reachable tensor whose grad is implicitly calculated (= 0) because it is an output of a function in the gradient graph but is independent of the backprop root tensor. (See code and traceback below.)
- No hook is called if this non-reachable tensor is an output of an index operation (e.g. `a[0].register_hook()` while the root only depends on `a[1]` and the complete `a` has `requires_grad`). (See code below.) This issue is not related to the others, but I encountered it in the same run, so I want to mention it here, too.
- If such a hook is called (e.g. `a0.register_hook()` after `a0, a1 = a.unbind()`), the `grad` argument is `None` but should be a tensor of zeros, because that is the value actually used for the required `a.grad` and should therefore be modifiable via the hook. (See code below.)
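For contrast, a hook on a tensor that *is* reachable from the backward root fires as expected. A minimal sketch (the `seen` list is just illustrative bookkeeping, not part of the report above):

```python
import torch

a = torch.tensor([3., 5.], requires_grad=True)
a0, a1 = a.unbind()

seen = []
# `a1` IS reachable from the backward root below, so its hook fires
# with the incoming gradient.
a1.register_hook(lambda g: seen.append(g.clone()))

(2 * a1).backward()
print(seen)    # [tensor(2.)]
print(a.grad)  # tensor([0., 2.])
```

Note that the unused slot of `a.grad` is materialized as an explicit `0.` here; the bugs above are about that same zero not being surfaced to hooks on the unused tensor itself.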
Below is the traceback for the segfault. (Notice the frames up to #6. I think the actual source of the problem is that in this line the inputs are not, but should be, initialized as variables with 0-values, as fallbacks if no function overrides them. Or am I wrong?)
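The same None-instead-of-zeros convention for unused inputs is also visible through the public `torch.autograd.grad` API when `allow_unused=True` is passed (a minimal sketch):

```python
import torch

a = torch.tensor([3., 5.], requires_grad=True)
a0, a1 = a.unbind()

# `a0` is not used in computing `a1`; autograd reports its gradient
# as None rather than materializing a zero tensor.
g, = torch.autograd.grad([a1], [a0], allow_unused=True)
print(g)  # None
```

Without `allow_unused=True`, the same call raises a RuntimeError because `a0` does not appear in the graph of `a1`.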
I would really like to fix this, but to be honest I'm not sure I have enough expertise to do it right. So I am opening this issue for others. I hope it helps.
Code example
```python
import torch

a = torch.tensor([3., 5.], requires_grad=True)

# (A) With the following line, the callback for `a0` is not called.
# But it should be called (with `grad = 0`, implicitly calculated,
# because `a1` is independent of `a0`), because it is used
# for the required `a.grad` and therefore should be modifiable
# via hook. (Also see (C).)
# a0, a1 = a[0], a[1]

# (B) With the following line instead, the callback for `a0` is
# called, but with two errors:
# (B.a) The callback is called with `grad = None`, but it should be `= 0`
#       (implicitly calculated, because `a1` is independent of `a0`).
#       (Also see (C).)
# (B.b) After the callback, it results in a `Segmentation fault`.
#       (Not if no callback was registered.)
a0, a1 = a.unbind()

@a0.register_hook
def hook(grad):
    print(grad)
    # (C) When `grad` is `None`, returning a non-None replacement throws
    # the RuntimeError "can't replace a None gradient with a non-None value".
    # Therefore the current behaviour allows no modification at all for
    # implicitly calculated `0`-gradients.
    # return torch.tensor(1.)

a1.backward()
# The above errors occur no matter whether `.backward()` or `.grad()` is used.
# torch.autograd.grad([a1], [a])
print(a.grad)
```
Traceback
#0 0x00007fffeda1bfa7 in std::__atomic_base<unsigned long>::operator++ (this=0x8) at /usr/include/c++/5/bits/atomic_base.h:296
#1 0x00007fffeda2978b in c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::retain_ (this=0x7ffff39356f0)
at /home/jk/workspace/projects/ml/www/pytorch/torch/lib/tmp_install/include/ATen/core/intrusive_ptr.h:163
#2 0x00007fffeda28cc0 in c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::intrusive_ptr (this=0x7ffff39356f0, rhs=...)
at /home/jk/workspace/projects/ml/www/pytorch/torch/lib/tmp_install/include/ATen/core/intrusive_ptr.h:211
#3 0x00007fffedae1bd1 in c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::operator=<at::TensorImpl, at::UndefinedTensorImpl>(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl> const&) & (this=0x7fffe00012a0, rhs=...) at /home/jk/workspace/projects/ml/www/pytorch/torch/lib/tmp_install/include/ATen/core/intrusive_ptr.h:252
#4 0x00007fffedae0d1b in c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::operator=(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl> const&) & (this=0x7fffe00012a0, rhs=...)
at /home/jk/workspace/projects/ml/www/pytorch/torch/lib/tmp_install/include/ATen/core/intrusive_ptr.h:244
#5 0x00007fffedad8bbf in at::Tensor::operator=(at::Tensor const&) & (this=0x7fffe00012a0, x=...) at /home/jk/workspace/projects/ml/www/pytorch/torch/lib/tmp_install/include/ATen/core/Tensor.h:105
#6 0x00007fffedc9c469 in torch::autograd::Variable::operator= (this=0x7fffe00012a0) at /home/jk/workspace/projects/ml/www/pytorch/torch/csrc/autograd/variable.h:83
#7 0x00007fffedcbf973 in torch::autograd::PyFunctionPreHook::operator() (this=0x1450cb0, values=std::vector of length 2, capacity 2 = {...}) at torch/csrc/autograd/python_hook.cpp:54
#8 0x00007fffe8c6ca99 in torch::autograd::call_pre_hooks (fn=..., inputs=std::vector of length 2, capacity 2 = {...})
at /home/jk/workspace/projects/ml/www/pytorch/torch/csrc/autograd/engine.cpp:280
#9 0x00007fffe8c6cec6 in torch::autograd::call_function (task=...) at /home/jk/workspace/projects/ml/www/pytorch/torch/csrc/autograd/engine.cpp:350
#10 0x00007fffe8c6d395 in torch::autograd::Engine::evaluate_function (this=0x7fffee58b900 <engine>, task=...) at /home/jk/workspace/projects/ml/www/pytorch/torch/csrc/autograd/engine.cpp:394
#11 0x00007fffe8c6c666 in torch::autograd::Engine::thread_main (this=0x7fffee58b900 <engine>, graph_task=0x0) at /home/jk/workspace/projects/ml/www/pytorch/torch/csrc/autograd/engine.cpp:232
#12 0x00007fffe8c6c4f7 in torch::autograd::Engine::thread_init (this=0x7fffee58b900 <engine>, device=-1) at /home/jk/workspace/projects/ml/www/pytorch/torch/csrc/autograd/engine.cpp:206
#13 0x00007fffedca0630 in torch::autograd::python::PythonEngine::thread_init (this=0x7fffee58b900 <engine>, device=-1) at torch/csrc/autograd/python_engine.cpp:39
#14 0x00007fffe8c8cc02 in std::_Mem_fn_base<void (torch::autograd::Engine::*)(int), true>::operator()<int, void>(torch::autograd::Engine*, int&&) const (this=0x145fa58,
__object=0x7fffee58b900 <engine>) at /usr/include/c++/5/functional:600
#15 0x00007fffe8c8cb7f in std::_Bind_simple<std::_Mem_fn<void (torch::autograd::Engine::*)(int)> (torch::autograd::Engine*, int)>::_M_invoke<0ul, 1ul>(std::_Index_tuple<0ul, 1ul>) (this=0x145fa48)
at /usr/include/c++/5/functional:1531
#16 0x00007fffe8c8ca02 in std::_Bind_simple<std::_Mem_fn<void (torch::autograd::Engine::*)(int)> (torch::autograd::Engine*, int)>::operator()() (this=0x145fa48)
at /usr/include/c++/5/functional:1520
#17 0x00007fffe8c8c952 in std::thread::_Impl<std::_Bind_simple<std::_Mem_fn<void (torch::autograd::Engine::*)(int)> (torch::autograd::Engine*, int)> >::_M_run() (this=0x145fa30)
at /usr/include/c++/5/thread:115
#18 0x00007fffe79a9c80 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#19 0x00007ffff7bc16ba in start_thread (arg=0x7ffff3936700) at pthread_create.c:333
#20 0x00007ffff6da441d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:109
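Until this is fixed, a possible workaround (a sketch, not an official fix) is to make the unused tensor reachable from the backward root with zero weight, so autograd propagates an explicit zero gradient through the hook instead of skipping it or crashing:

```python
import torch

a = torch.tensor([3., 5.], requires_grad=True)
a0, a1 = a.unbind()

seen = []
a0.register_hook(lambda g: seen.append(g.clone()))

# Adding `0 * a0` makes `a0` reachable from the root. Its hook now
# receives a real zero tensor (and may modify it) rather than None.
(a1 + 0 * a0).backward()
print(seen)    # [tensor(0.)]
print(a.grad)  # tensor([0., 1.])
```

This does not change the numerical result, since the extra term contributes `0 * grad` to `a.grad`.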
System Info
- PyTorch or Caffe2: PyTorch
- How you installed PyTorch (conda, pip, source): I tested two versions (same errors for both): current master from source, v0.4.1 via pip
- Build command you used (if compiling from source): `NO_CUDA=1 DEBUG=1 python setup.py build develop`
- OS: Linux Mint 18.3 Sylvia
- PyTorch version: I tested two versions (same errors for both): current master, v0.4.1
- Python version: 3.6.6
- CUDA/cuDNN version: None
- GPU models and configuration: No CUDA
- GCC version (if compiling from source): (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
- CMake version: version 3.12.0
- Versions of any other relevant libraries:
[pip] 18.0