This issue came up with NJT a few weeks ago (I believe @jbschlosser worked around it but we should fix this).
Say you are in the following situation:
(1) you are compiling a model where one of the graph outputs is a tensor subclass
(2) the tensor subclass is authored such that the outer subclass shape differs from the shape of its inner tensors
(3) dynamic shapes are turned on
We will end up blindly generating incorrect shapes for the output subclass tensor today. Why? The problem is that:
(a) at trace time, we generate a sympy expression for the shape of the output tensor subclass
(b) at runtime, inductor does a bunch of compute to generate the real shapes of any inner tensors, given the real tensor input shapes, but we have not actually compiled the integer compute for the outer subclass sizes (which are hidden from inductor)
(c) at runtime, we end up blatting in whatever the traced sympy expression for the subclass output shape was (I believe this should be a SymInt, although in the repro below I just get an incorrect integer value)
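To make (a)-(c) concrete for the repro below: `f(x) = x.view(-1) * 2`, so under dynamic shapes the outer output shape is a single sympy expression like `s0*s1`. Here is a minimal sketch in plain sympy (not PyTorch internals) of the compute that should happen at runtime:

```python
# Plain-sympy sketch (not PyTorch internals): the outer output shape of x.view(-1)
# traced with dynamic shapes is the expression s0*s1.
import sympy

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)
outer_numel = s0 * s1  # shape expression recorded at trace time

# What should happen at runtime: re-evaluate the expression with the real input sizes.
print(outer_numel.subs({s0: 2, s1: 4}))  # 8  (first call's input shape)
print(outer_numel.subs({s0: 3, s1: 5}))  # 15 (second call's input shape -- what out2.shape should reflect)
```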
We have an existing interpreter for sympy expressions (thanks to @ezyang) that we could probably use to turn the sympy expression into an FX graph to execute at runtime.
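For illustration, here is a hedged sketch of that direction: turning a (product-only) sympy size expression into a tiny FX graph that can be run at runtime with the real input sizes. The helper name `sympy_mul_to_fx` and the restriction to `sympy.Mul` are simplifications for this sketch, not the actual interpreter's API:

```python
# Hedged sketch (hypothetical helper, not the real interpreter's API): compile a
# product-only sympy size expression into a small FX graph runnable at runtime.
import operator

import sympy
import torch
import torch.fx as fx


def sympy_mul_to_fx(expr, symbols):
    """Build an FX graph computing `expr`, assuming it is a product of symbols."""
    graph = fx.Graph()
    env = {sym: graph.placeholder(sym.name) for sym in symbols}
    assert isinstance(expr, sympy.Mul), "sketch only handles products like s0*s1"
    result = None
    for factor in expr.args:
        node = env[factor]
        result = node if result is None else graph.call_function(operator.mul, (result, node))
    graph.output(result)
    return fx.GraphModule(torch.nn.Module(), graph)


s0, s1 = sympy.symbols("s0 s1")
size_fn = sympy_mul_to_fx(s0 * s1, [s0, s1])
print(size_fn(3, 5))  # 15 -- the outer size the second call in the repro should produce
```

The real fix would presumably reuse the existing sympy interpreter to handle arbitrary expressions, and wire the resulting size graph up to the compiled artifact so it runs on every call.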
Simple repro below, using a patched version of TwoTensor that directly uses the outer_size argument in `__tensor_unflatten__` (instead of dropping it):
```python
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing


# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b, outer_size=None, outer_stride=None):
        # Default the outer metadata to the inner tensor's metadata when not provided.
        if outer_size is None:
            outer_size = a.shape
        if outer_stride is None:
            outer_stride = a.stride()
        shape = outer_size
        stride = outer_stride
        kwargs = {}
        kwargs["strides"] = stride
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
        assert a.shape == b.shape
        assert a.stride() == b.stride()
        assert a.storage_offset() == b.storage_offset()
        return out

    def __init__(self, a, b, outer_size=None, outer_stride=None):
        self.a = a
        self.b = b

    def __repr__(self):
        a_repr = repr(self.a)
        b_repr = repr(self.b)
        return f"TwoTensor({a_repr}, {b_repr})"

    def __tensor_flatten__(self):
        return ["a", "b"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a, b = inner_tensors["a"], inner_tensors["b"]
        # Patched: forward outer_size/outer_stride instead of dropping them.
        return TwoTensor(a, b, outer_size, outer_stride)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)
        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)
        out_a = func(*args_a, **kwargs_a)
        out_b = func(*args_b, **kwargs_b)
        assert type(out_a) == type(out_b)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_b_flat = pytree.tree_leaves(out_b)
        # For aten ops that return non-tensors, just assume that
        # our two inner tensors return the same value.
        out_flat = [
            TwoTensor(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
            for o_a, o_b in zip(out_a_flat, out_b_flat)
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        return return_and_correct_aliasing(func, args, kwargs, out)


@torch.compile(backend='aot_eager')
def f(x):
    return x.view(-1) * 2


x1_inner = torch.ones(2, 4)
x1 = TwoTensor(x1_inner, x1_inner.clone())
out1 = f(x1)

# Second call with different input shapes triggers a recompile with dynamic shapes.
x2_inner = torch.ones(3, 5)
x2 = TwoTensor(x2_inner, x2_inner.clone())
out2 = f(x2)
breakpoint()
# Expected torch.Size([15]) (3 * 5 elements), but the outer shape comes out wrong.
print(out2.shape)
```
cc @Chillee @ezyang @zou3519 @albanD @samdow @msaroufim @anijain2305 @chauhang