Closed
Labels
ezyang's list, tensor subclass, triaged
Description
🐛 Describe the bug
NOTE: Commenting out the gc.collect() call in the repro (or disabling the GC entirely) makes everything work. Also, the freed tensor has to be part of a reference cycle for this to fail.
import torch

class IdentityFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        class Node():
            pass

        a = Node()
        b = Node()

        # Induce a reference cycle
        a.b = b
        b.a = a

        s = torch.zeros(1,)
        s._attrs = {"key": "value"}

        # If the tensor is not part of ref cycle, then it is ok
        a.s = s  # Comment this line and it works fine.
        ctx.save_for_backward(s)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

t = torch.ones(1, device='cpu', requires_grad=True)
y = IdentityFunction.apply(t)
print(y.grad_fn.saved_tensors[0]._attrs)

import gc
print(gc.collect())

print(y.grad_fn.saved_tensors[0])  # This is ok
print(y.grad_fn.saved_tensors[0]._attrs)  # This fails

Output
{'key': 'value'}
11
tensor([0.])
Traceback (most recent call last):
File "scratchpad/test.py", line 36, in <module>
print(y.grad_fn.saved_tensors[0]._attrs) # This fails
AttributeError: 'Tensor' object has no attribute '_attrs'
Possibly related: #47117
cc: @ezyang @albanD (potentially related to PyObject preservation)
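For context on why the cycle matters, here is a minimal pure-Python sketch (no torch involved) of the CPython behavior the repro relies on: an object reachable only through a reference cycle stays alive until the cyclic collector runs, while an object outside any cycle is freed by reference counting as soon as its last reference goes away. `Payload` and `build` are hypothetical names used only for illustration, and this does not reproduce the bug itself; it only shows why `gc.collect()` is the point at which the cycle-held object goes away.

```python
import gc
import weakref


class Node:
    pass


class Payload:
    """Hypothetical stand-in for an object carrying Python-side attributes."""

    def __init__(self):
        self._attrs = {"key": "value"}


def build(in_cycle):
    # Same shape as the repro: two nodes that reference each other.
    a, b = Node(), Node()
    a.b, b.a = b, a
    s = Payload()
    if in_cycle:
        a.s = s  # s is now reachable only through the a <-> b cycle
    # Return a weak reference so we can observe s without keeping it alive.
    return weakref.ref(s)


gc.disable()  # make sure no automatic collection runs in between
ref_in_cycle = build(in_cycle=True)
ref_no_cycle = build(in_cycle=False)

alive_before = ref_in_cycle() is not None      # cycle keeps the object alive
no_cycle_alive = ref_no_cycle() is not None    # already freed by refcounting

collected = gc.collect()                       # cyclic collector frees a, b, s
alive_after = ref_in_cycle() is not None
gc.enable()

print(alive_before, no_cycle_alive, collected > 0, alive_after)
```

In the repro above, the tensor is additionally held by `ctx.save_for_backward`, which is presumably why the tensor data survives `gc.collect()` while `_attrs` does not; that part is specific to PyTorch and is not modeled here.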
Versions
main 462b727