Conversation


@ezyang ezyang commented Jul 17, 2019

Stack from ghstack:

In the original iteration of the patch, I used lock() everywhere to minimize
the amount of code I have to modify. In this patch, I now eliminate as many
lock()s as I can, when the caller is known to have a strong reference to the
PyFunction, and pass that directly.

Along the way, I also bulk up our error messages for checking the result
of the weak pointer dereference. Some of these cases can be triggered
by zany use of legacy autograd function API; might as well let people know
what they've done wrong.

Signed-off-by: Edward Z. Yang [email protected]
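The pattern the patch describes can be sketched in plain Python with the `weakref` module standing in for C++ `std::weak_ptr` (the `PyFunction` class and function names here are hypothetical illustrations, not the real PyTorch code):

```python
import weakref

class PyFunction:
    """Hypothetical stand-in for the C++ autograd graph node."""
    def name(self):
        return "MyFunction"

# Before: every helper re-dereferences the weak reference
# (the analogue of calling weak_ptr::lock() everywhere).
def name_via_lock(weak_fn):
    fn = weak_fn()  # analogous to weak_ptr::lock()
    if fn is None:
        # A descriptive, user-facing error instead of a bare assert.
        raise RuntimeError(
            "Attempted to use an autograd function whose underlying "
            "graph node has already been freed")
    return fn.name()

# After: callers known to hold a strong reference pass it directly,
# so no dereference (and no failure path) is needed here.
def name_direct(fn):
    return fn.name()

fn = PyFunction()
weak_fn = weakref.ref(fn)
assert name_via_lock(weak_fn) == "MyFunction"
assert name_direct(fn) == "MyFunction"

del fn  # last strong reference gone; the weak reference expires
try:
    name_via_lock(weak_fn)
except RuntimeError as e:
    print("raised:", e)
```

Passing the strong reference down the call chain both avoids the repeated lock-and-check and documents, in the signature, that the callee relies on the caller keeping the object alive.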

@pytorchbot added labels Jul 17, 2019: module: autograd (Related to torch.autograd, and the autograd engine in general), module: pybind (Related to our Python bindings / interactions with other Python libraries)
ezyang added a commit that referenced this pull request Jul 17, 2019

ghstack-source-id: 66b1aa2
Pull Request resolved: #22998
@ezyang ezyang requested a review from colesbury July 18, 2019 13:36
ezyang added a commit that referenced this pull request Jul 18, 2019

ghstack-source-id: 296f2e6
Pull Request resolved: #22998
@ezyang ezyang requested a review from apaszke July 18, 2019 19:43

namespace pybind11 { namespace detail {

// handle Python <-> torch::autograd::Function conversions
@ezyang (Contributor Author) commented on this diff:
I'm not actually sure if this is used anywhere, but it doesn't seem to cause anything to fail when I delete it.

@colesbury (Member) left a comment:

The accessors from Python need to properly handle the case where cdata is expired. For example, THPFunction_metadata.

Note that this isn't limited to "legacy" autograd functions. You can grab a grad_fn attribute off a variable and have it live longer than the variable and the rest of the autograd graph.

I find this difficult to review separately from #22983, since the earlier PR introduces undesirable behavior which is mostly fixed up here. I would find it easier to review if the two were combined.


ezyang commented Jul 18, 2019

The accessors from Python need to properly handle the case where cdata is expired. For example, THPFunction_metadata.

OK. The best way I could think to do this is to move metadata to live on THPFunction rather than PyFunction. Does this sound reasonable to you? I'll try this change tomorrow.

I find this difficult to review separately from #22983, since the earlier PR introduces undesirable behavior which is mostly fixed up here. I would find it easier to review if the two were combined.

I'm happy to squash. I'll do that tomorrow.

@colesbury (Member) commented:

OK. The best way I could think to do this is to move metadata to live on THPFunction rather than PyFunction. Does this sound reasonable to you? I'll try this change tomorrow.

That seems fine. I think it would also OK to raise an exception (but not an internal assertion), return an empty value, or return an empty value and warn. Some tests for the behavior would be good too.
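The kind of test suggested here could be sketched in plain Python, with a hypothetical `GradFn`/`Node` pair standing in for THPFunction and its weakly-referenced C++ node (these names and the error message are illustrative assumptions, not the real API):

```python
import weakref

class Node:
    """Hypothetical stand-in for the C++ graph node."""
    def __init__(self):
        self.metadata = {}

class GradFn:
    """Hypothetical stand-in for THPFunction: weak reference to the node."""
    def __init__(self, node):
        self._cdata = weakref.ref(node)

    @property
    def metadata(self):
        node = self._cdata()
        if node is None:
            # User-facing exception rather than an internal assertion.
            raise RuntimeError(
                "autograd node has been freed; keep a reference to the "
                "output variable if you need its metadata")
        return node.metadata

def test_metadata_raises_after_node_freed():
    node = Node()
    fn = GradFn(node)
    assert fn.metadata == {}   # node alive: accessor returns the dict
    del node                   # graph freed; the weak reference expires
    try:
        fn.metadata
    except RuntimeError:
        return                 # the desired, catchable behavior
    raise AssertionError("expected RuntimeError")

test_metadata_raises_after_node_freed()
print("ok")
```

The key property being tested: an expired `cdata` surfaces as a catchable exception with an actionable message, never as a process-aborting internal assert.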


ezyang commented Jul 19, 2019

Oh, well, raising an error is a lot easier to do, imma do that first :)

@ezyang ezyang requested a review from albanD July 19, 2019 15:35

ezyang commented Jul 19, 2019

cc'ing @albanD as you may have a better idea what to do about anomaly metadata.


ezyang commented Jul 19, 2019

OK, having done some testing, I feel a lot better about not "fixing" this properly. Take a look at this test program:

import torch
from torch.autograd import Function

class MyFunction(Function):
    @staticmethod
    def forward(ctx, x):
        return x 

    @staticmethod
    def backward(ctx, g):
        return g 

x = torch.zeros(1, requires_grad=True)
y = MyFunction.apply(x)
y.backward()
print(y.grad_fn.metadata)
g = y.grad_fn
del y 
print(g.metadata)

On my branch, you get:

{}
terminate called after throwing an instance of 'c10::Error'
  what():  cdata INTERNAL ASSERT FAILED at ../torch/csrc/autograd/python_function.cpp:982, please report a bug to PyTorch.  (THPFunction_metadata at ../torch/csrc/autograd/python_function.cpp:982)                                                         
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7fbe7788ac65 in /data/users/ezyang/pytorch-tmp/torch/lib/libc10.so)
frame #1: THPFunction_metadata(THPFunction*, void*) + 0x126 (0x7fbe8febacb6 in /data/users/ezyang/pytorch-tmp/torch/lib/libtorch_python.so)
<omitting python frames>
frame #13: __libc_start_main + 0xf5 (0x7fbea48063d5 in /lib64/libc.so.6)                                                      

Aborted (core dumped)

The reason the first call is OK but the second is not is that y keeps the PyFunction alive, whereas g only keeps the THPFunction alive. So if you're a normal person and don't delete your variables while holding onto their grad_fn, you won't run into this bug. And the bug only happens for user-defined functions; built-in functions are handled correctly (because we apparently bind the Function to Python directly in that case).

The correct way to fix this is to make grad_fn be an owning reference to PyFunction. But I am lazy and don't want to fix that now.
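The ownership fix described here can be sketched with plain-Python stand-ins (hypothetical `Node`, `WeakGradFn`, `OwningGradFn` classes, not the real THPFunction/PyFunction): if the Python-visible grad_fn holds a strong reference, the node cannot expire while the grad_fn is reachable.

```python
import weakref

class Node:
    """Hypothetical stand-in for the C++ PyFunction graph node."""
    def __init__(self):
        self.metadata = {}

class WeakGradFn:
    """Current behavior: only a weak reference, so the node can expire."""
    def __init__(self, node):
        self._cdata = weakref.ref(node)

    @property
    def metadata(self):
        node = self._cdata()
        if node is None:
            raise RuntimeError("underlying graph node has been freed")
        return node.metadata

class OwningGradFn:
    """Proposed fix: an owning (strong) reference keeps the node alive."""
    def __init__(self, node):
        self._cdata = node

    @property
    def metadata(self):
        return self._cdata.metadata

node_a = Node()
weak_fn = WeakGradFn(node_a)
del node_a               # as with `del y` above: the node is freed
try:
    weak_fn.metadata     # expired weak reference
except RuntimeError as e:
    print("raised:", e)

node_b = Node()
own_fn = OwningGradFn(node_b)
del node_b               # own_fn still keeps the node alive
print(own_fn.metadata)   # accessor keeps working
```

The trade-off is that an owning grad_fn extends the lifetime of the graph node (and whatever it captures) for as long as the Python object lives, which is presumably why the weak-reference design existed in the first place.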


ezyang commented Jul 19, 2019

As requested by Sam, this PR has been squashed into #22983.

@ezyang ezyang closed this Jul 19, 2019
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/240/head branch October 28, 2019 22:10
