
Tensor subclass object's __dict__ is cleared when gc.collect is called #136358

@kshitij12345

Description


🐛 Describe the bug

NOTE - If gc.collect is commented out in the repro (or the gc is disabled entirely), everything works. Also, the tensor has to be part of a reference cycle for this to fail.

import torch

class IdentityFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):

        class Node:
            pass

        a = Node()
        b = Node()
        # Induce a reference cycle
        a.b = b
        b.a = a

        s = torch.zeros(1,)
        s._attrs = {"key": "value"}
        # If the tensor is not part of ref cycle, then it is ok
        a.s = s  # Comment this line and it works fine.
        ctx.save_for_backward(s)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

t = torch.ones(1, device='cpu', requires_grad=True)

y = IdentityFunction.apply(t)

print(y.grad_fn.saved_tensors[0]._attrs)
import gc

print(gc.collect())
print(y.grad_fn.saved_tensors[0])  # This is ok
print(y.grad_fn.saved_tensors[0]._attrs)  # This fails

Output

{'key': 'value'}
11
tensor([0.])
Traceback (most recent call last):
  File "scratchpad/test.py", line 36, in <module>
    print(y.grad_fn.saved_tensors[0]._attrs)  # This fails
AttributeError: 'Tensor' object has no attribute '_attrs'
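The reference-cycle requirement is consistent with CPython's garbage collector: objects reachable only through a cycle are reclaimed by gc.collect(), while refcounting alone never frees them. A minimal pure-Python sketch of that condition (Node and Payload are hypothetical names, no torch involved):

```python
import gc

class Node:
    pass

collected_flags = []

class Payload:
    def __del__(self):
        collected_flags.append(True)

def make_cycle(payload):
    # a <-> b form a reference cycle; payload hangs off the cycle only
    a, b = Node(), Node()
    a.b, b.a = b, a
    a.payload = payload
    # all direct references are dropped on return

make_cycle(Payload())
gc.collect()  # the cycle (and with it the payload) is reclaimed here
print(collected_flags)  # [True]
```

This mirrors the repro: once `s` hangs off the `a`/`b` cycle, gc.collect() reclaims its PyObject even though the C++ tensor stays alive in saved_tensors.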

Maybe related: #47117
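Since this looks like a PyObject-preservation gap, a plausible interim workaround (an assumption, not a verified fix) is to hold an extra non-cycle Python reference to the attributed tensor, so its PyObject is never reclaimed by the collector. With the hypothetical `_keep_alive` list below, the repro's failing access appears to succeed:

```python
import gc
import torch

# Assumption: an external (non-cycle) reference keeps the tensor's
# PyObject, and thus its __dict__, alive across gc.collect().
_keep_alive = []

class IdentityFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        class Node:
            pass

        a, b = Node(), Node()
        a.b, b.a = b, a  # reference cycle, as in the repro

        s = torch.zeros(1)
        s._attrs = {"key": "value"}
        a.s = s  # tensor hangs off the cycle
        _keep_alive.append(s)  # workaround: non-cycle reference
        ctx.save_for_backward(s)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

t = torch.ones(1, requires_grad=True)
y = IdentityFunction.apply(t)
gc.collect()
print(y.grad_fn.saved_tensors[0]._attrs)
```

This only sidesteps the symptom by pinning the PyObject; the underlying preservation behavior still needs a proper fix.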

cc: @ezyang @albanD (potentially related to PyObject preservation)

Versions

main 462b727


Metadata


Assignees

No one assigned

    Labels

    - ezyang's list: Stuff ezyang doesn't want to lose
    - tensor subclass: Related to tensor subclasses
    - triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
