torch.compile + dynamic shapes + tensor subclass graph output is broken #124619

Description

@bdhirsh

This issue came up with NJT a few weeks ago (I believe @jbschlosser worked around it, but we should fix this).

Say you are in the following situation:
(1) you are compiling a model where one of the graph outputs is a tensor subclass
(2) the tensor subclass is authored in a way where the outer subclass shape is different from the shape of its inner tensor(s) (see the sketch below)
(3) dynamic shapes are turned on
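
For concreteness, here is a minimal sketch of condition (2), using only the public _make_wrapper_subclass API (the FlatView class is hypothetical and not part of the repro): a wrapper subclass whose outer shape is the flattened shape, while the inner tensor stays 2D.

import torch

# Hypothetical illustration: a wrapper subclass whose outer (public) shape
# differs from the shape of the inner tensor it wraps.
class FlatView(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        # The outer shape is the flattened numel; the inner tensor stays 2D.
        return torch.Tensor._make_wrapper_subclass(
            cls, (inner.numel(),), dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self.inner = inner

t = FlatView(torch.ones(2, 4))
print(t.shape)        # torch.Size([8])    <- outer shape
print(t.inner.shape)  # torch.Size([2, 4]) <- inner shape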

Today, we end up blindly generating incorrect shapes for the output subclass tensor. Why? The problem is that:

(a) at trace time, we generate a sympy expression for the shape of the output tensor subclass
(b) at runtime, inductor does a bunch of compute to generate the real shapes of any inner tensors, given the real tensor input shapes, but we have not actually compiled the integer compute for the outer subclass sizes (which are hidden from inductor)
(c) at runtime, we end up stamping out whatever value the traced sympy expression for the subclass output sizes happened to hold at trace time (I believe this should be a SymInt, although in the repro below I just get an incorrect integer value)

We have an existing interpreter for sympy expressions (thanks to @ezyang) that we could probably use to turn the sympy expression into an FX graph that we could execute at runtime.
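
As a rough sketch of that idea, using plain sympy for illustration (sympy.lambdify here is just a stand-in for building and executing an FX graph from the expression; s0/s1 are made-up symbols for the dynamic input dims):

import sympy

# s0, s1 stand in for the symbolic input dims captured at trace time.
s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
outer_numel = s0 * s1  # traced expression for the output's outer size

# Current (buggy) behavior: the expression's trace-time value gets baked in.
print(outer_numel.subs({s0: 2, s1: 4}))  # 8, reused for every later input

# Desired behavior: compile the integer compute once, then evaluate it
# against the real input sizes at runtime.
compute_numel = sympy.lambdify((s0, s1), outer_numel, "math")
print(compute_numel(3, 5))  # 15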

Simple repro below, using a patched version of TwoTensor that directly uses the outer_size argument in __tensor_unflatten__ (instead of dropping it):

import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing


# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.shape
        if outer_stride is None:
            outer_stride = a.stride()
        shape = outer_size
        stride = outer_stride
        kwargs = {}
        kwargs["strides"] = stride
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

        assert a.shape == b.shape
        assert a.stride() == b.stride()
        assert a.storage_offset() == b.storage_offset()
        return out

    def __init__(self, a, b, outer_size=None, outer_stride=None):
        self.a = a
        self.b = b

    def __repr__(self):
        a_repr = repr(self.a)
        b_repr = repr(self.b)
        return f"TwoTensor({a_repr}, {b_repr})"

    def __tensor_flatten__(self):
        return ["a", "b"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a, b = inner_tensors["a"], inner_tensors["b"]
        return TwoTensor(a, b, outer_size, outer_stride)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)

        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)

        out_a = func(*args_a, **kwargs_a)
        out_b = func(*args_b, **kwargs_b)
        assert type(out_a) == type(out_b)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_b_flat = pytree.tree_leaves(out_b)
        # for aten ops that return non-tensors, just assume that
        # our two inner tensors return the same value
        out_flat = [
            TwoTensor(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
            for o_a, o_b in zip(out_a_flat, out_b_flat)
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        return return_and_correct_aliasing(func, args, kwargs, out)


@torch.compile(backend='aot_eager')
def f(x):
    return x.view(-1) * 2

# First call: compiles f for the initial (static) input shapes.
x1_inner = torch.ones(2, 4)
x1 = TwoTensor(x1_inner, x1_inner.clone())
out1 = f(x1)

# Second call: the new input shape triggers a recompile with dynamic shapes.
x2_inner = torch.ones(3, 5)
x2 = TwoTensor(x2_inner, x2_inner.clone())
out2 = f(x2)
# Expected: torch.Size([15]). With the bug, an incorrect shape is printed.
print(out2.shape)

cc @Chillee @ezyang @zou3519 @albanD @samdow @msaroufim @anijain2305 @chauhang
