Skip to content

Commit c0deec1

Browse files
albanD authored and pytorchmergebot committed
Fix resurrection logic to trigger early enough (#137267)
Fixes #136358

The bug here is that the Tensor object is actually 2 classes: `Tensor` from `_tensor.py` and `TensorBase` from c++. Before this PR, they have the following gc methods:

Tensor:
- tp_clear subtype_clear
- tp_traverse THPVariable_subclass_traverse
- tp_dealloc THPVariable_subclass_dealloc

TensorBase:
- tp_clear THPVariable_clear
- tp_traverse THPFunction_traverse (fake function that is just an error)
- tp_dealloc object_dealloc

The problem is that when clear is called on the Tensor, subtype_clear is going to clear the things owned by the "Tensor" type, in particular, its `__dict__` attribute, before delegating to the TensorBase clear where we detect that resurrection needs to happen and skip it.

Pull Request resolved: #137267
Approved by: https://github.com/ezyang, https://github.com/kshitij12345
1 parent bd48933 commit c0deec1

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

test/test_torch.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10036,6 +10036,30 @@ def __del__(self):
1003610036
self.assertEqual(MyStorage.finalized_count, 1)
1003710037
self.assertTrue(m[0])
1003810038

10039+
def test_tensor_ressurecting_clear(self):
10040+
# Regression test for https://github.com/pytorch/pytorch/issues/136358
10041+
# A Tensor with custom __dict__
10042+
# Autograd here is for the c++ reference later
10043+
t = torch.rand(2, requires_grad=True).clone()
10044+
t.foo = 2
10045+
10046+
# that is part of a cycle
10047+
l = []
10048+
l.append(l)
10049+
l.append(t)
10050+
10051+
# Keep the Tensor alive from c++
10052+
# Using the autograd graph here (any other means would work)
10053+
t2 = t ** 2
10054+
self.assertIs(t2.grad_fn._saved_self, t)
10055+
10056+
# Clear all python references and trigger the gc
10057+
del t, l
10058+
gc.collect()
10059+
10060+
# We used to lose the dict!
10061+
self.assertTrue(hasattr(t2.grad_fn._saved_self, "foo"))
10062+
1003910063
def test_tensor_slot_dealloc(self):
1004010064

1004110065
class SlotTensor1(torch.Tensor):

torch/csrc/autograd/python_variable.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
409409
return true;
410410
}
411411

412-
static int THPVariable_clear(THPVariable* self) {
412+
static int THPVariable_subclass_clear(THPVariable* self) {
413413
// Is it OK for an object to still be live after running
414414
// tp_clear? Yes. When Python is breaking reference cycles, it can't assume
415415
// that an object will dealloc after it's cleared. The source code explicitly
@@ -465,7 +465,7 @@ static int THPVariable_clear(THPVariable* self) {
465465
// !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
466466
// ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
467467
// please report a bug to PyTorch. Exception raised from
468-
// THPVariable_clear at
468+
// THPVariable_subclass_clear at
469469
// ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
470470
// first): frame #0: c10::Error::Error(c10::SourceLocation,
471471
// std::__1::basic_string<char, std::__1::char_traits<char>,
@@ -475,7 +475,7 @@ static int THPVariable_clear(THPVariable* self) {
475475
// c10::detail::torchInternalAssertFail(char const*, char const*,
476476
// unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
477477
// (0x1141e3f89 in libtorch_python.dylib) frame #3:
478-
// THPVariable_clear(THPVariable*) + 412 (0x1148a547c in
478+
// THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in
479479
// libtorch_python.dylib) frame #4:
480480
// THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
481481
// libtorch_python.dylib) frame #5: (anonymous
@@ -507,9 +507,15 @@ static int THPVariable_clear(THPVariable* self) {
507507
return 0;
508508
}
509509

510-
int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
510+
int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) {
511511
TORCH_INTERNAL_ASSERT(
512-
false, "Tensor tp_traverse function was not overriden properly");
512+
false, "TensorBase tp_traverse function was not overriden properly");
513+
return 0;
514+
}
515+
516+
int THPFake_clear(THPVariable* self) {
517+
TORCH_INTERNAL_ASSERT(
518+
false, "TensorBase tp_clear function was not overriden properly");
513519
return 0;
514520
}
515521

@@ -1850,8 +1856,8 @@ PyTypeObject THPVariableType = {
18501856
Py_TPFLAGS_HAVE_GC, /* tp_flags */
18511857
nullptr, /* tp_doc */
18521858
// Also set by metaclass
1853-
(traverseproc)THPFunction_traverse, /* tp_traverse */
1854-
(inquiry)THPVariable_clear, /* tp_clear */
1859+
(traverseproc)THPFake_traverse, /* tp_traverse */
1860+
(inquiry)THPFake_clear, /* tp_clear */
18551861
nullptr, /* tp_richcompare */
18561862
0, /* tp_weaklistoffset */
18571863
nullptr, /* tp_iter */
@@ -1984,7 +1990,7 @@ void THPVariable_subclass_dealloc(PyObject* self) {
19841990
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
19851991

19861992
// Finally clear out the base THPVariable
1987-
THPVariable_clear((THPVariable*)self);
1993+
THPVariable_subclass_clear((THPVariable*)self);
19881994
((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
19891995
Py_TYPE(self)->tp_free(self);
19901996

@@ -2277,9 +2283,17 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
22772283
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
22782284
return -1;
22792285
}
2286+
// It is important for all three of these to be overridden correctly for the
2287+
// resurrection checks to properly happen. In particular, an older version
2288+
// was not overriding tp_clear here. This led to the default subtype_clear
2289+
// running on the Tensor object (as only TensorBase tp_clear was custom),
2290+
// clearing the __dict__ field, before the TensorBase custom clear was called
2291+
// and would properly detect the resurrect.
2292+
// See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior
22802293
((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
22812294
((PyTypeObject*)cls)->tp_traverse =
22822295
(traverseproc)THPVariable_subclass_traverse;
2296+
((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear;
22832297

22842298
// Don't do anything for the base Tensor class
22852299
if (!THPVariableClass) {

0 commit comments

Comments
 (0)