
[Nested Tensor / subclasses] view(-1) and splitting dimensions with view() failures #128649

@davidberard98

🐛 Describe the bug

Issue 1: -1 in view [Edit: maybe fixed in #128662]

If you have a subclass which is a .view(-1, ...) of a base subclass tensor and you pass it into a torch-compiled function, we fail when trying to create the fake tensor representing the subclass tensor:

Repro

import torch
from torch.nested._internal.nested_tensor import ViewNestedFromBuffer

def get_inputs():
    # 128 variable-length sequences; offsets[i] marks where sequence i starts
    lengths = torch.randint(2, 1000, (128,), device='cuda')
    offsets = torch.zeros(129, device='cuda', dtype=torch.int32)
    torch.cumsum(lengths, dim=0, out=offsets[1:])

    # flat buffer holding all sequences, with feature dim 4*32 = 128
    jagged_input = torch.randn((offsets[-1].item(), 4*32), device='cuda', requires_grad=True)
    return jagged_input, offsets

def get_nt():
    # jagged NestedTensor of shape (128, j0, 128) viewing the flat buffer
    nt = ViewNestedFromBuffer.apply(*get_inputs())
    return nt

def fn(nt):
    return nt.sin()

def fn2(nt):
    return nt.cos()

# Issue 1
torch.compile(fn)(get_nt().view(-1, -1, 128))

# Issue 2
nt = get_nt()
torch.compile(fn2)(nt.view(nt.shape[0], nt.shape[1], 4, 32))

Error

  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 2299, in <lambda>
    lambda: tx.fake_mode.from_tensor(
  File "/home/dberard/local/pytorch/torch/_subclasses/fake_tensor.py", line 1941, in from_tensor
    return self.fake_tensor_converter.from_real_tensor(
  File "/home/dberard/local/pytorch/torch/_subclasses/fake_tensor.py", line 332, in from_real_tensor
    out = self.meta_converter(
  File "/home/dberard/local/pytorch/torch/_subclasses/meta_utils.py", line 1621, in __call__
    r = self.meta_tensor(
  File "/home/dberard/local/pytorch/torch/_subclasses/meta_utils.py", line 1368, in meta_tensor
    r = view_from_base(base, t)
  File "/home/dberard/local/pytorch/torch/_subclasses/meta_utils.py", line 1038, in view_from_base
    fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn)
  File "/home/dberard/local/pytorch/torch/_subclasses/meta_utils.py", line 967, in symint_visitor_fn
    symbol = shape_env.create_symbol(s, sym_source)
  File "/home/dberard/local/pytorch/torch/fx/experimental/recording.py", line 245, in wrapper
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3457, in create_symbol
    assert not (positive and val < 0), f"positive set for negative value: {val}"
AssertionError: positive set for negative value: -1
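
The -1 here comes straight from the view's saved replay metadata: view() records its size arguments as passed, so a literal -1 survives into the replay and gets handed to symint_visitor_fn. A minimal sketch of where the visitor sees it, assuming Tensor._view_func is the Python hook behind the traceback's t.view_func call:

import torch

# requires_grad ensures the autograd view metadata (and its view func) is recorded
base = torch.randn(4, 32, requires_grad=True)
v = base.view(-1, 128)  # the literal -1 is stored in the view's replay metadata

def symint_visitor(s):
    # fakeification calls create_symbol(s, ...) at this point; on builds
    # affected by this issue, s can be the literal -1, tripping the assert
    print("visited size arg:", s)
    return s

# replay the view on the same base, visiting each saved size argument
v._view_func(base, symint_visitor, lambda t: t)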

We can apply a patch like the one below:

diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 4ea0db56aae..d1c6b562c94 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -964,7 +964,13 @@ class MetaConverter:
                 # assumption of it being simplified out will fail and it may be guarded on,
                 # which will hard error.
                 sym_source = EphemeralSource("symint_visitor_fn")
-                symbol = shape_env.create_symbol(s, sym_source)
+
+                positive = True
+                if isinstance(s, int) and s < 0:
+                    # e.g. if you .view(-1, ...)
+                    positive = False
+
+                symbol = shape_env.create_symbol(s, sym_source, positive=positive)
                 return shape_env.create_symintnode(symbol, hint=s, source=sym_source)

             real_to_fake_mapping = {}
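
For reference, here is the patched call pattern as a standalone sketch. The EphemeralSource import path is my assumption; the create_symbol/create_symintnode calls mirror the diff above:

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._dynamo.source import EphemeralSource  # import path is an assumption

shape_env = ShapeEnv()
sym_source = EphemeralSource("symint_visitor_fn")

# Unpatched, positive defaults to True and this asserts:
#   "positive set for negative value: -1"
# shape_env.create_symbol(-1, sym_source)

# Patched: relax positivity when the hint is negative, e.g. from view(-1, ...)
symbol = shape_env.create_symbol(-1, sym_source, positive=False)
s = shape_env.create_symintnode(symbol, hint=-1, source=sym_source)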

Issue 2: Can't simplify some symints from .view-ed subclasses

If you're trying to split a dimension, we'll run into issues simplifying the symints used when replaying the view. For example, if we call nt.view(nt.shape[0], nt.shape[1], 4, 32) on an nt of shape (128, j0, 128), then we'll hit this issue:

Repro: same code as Issue 1 above, but with the line marked "# Issue 1" commented out

Log snippet:

...
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4089, in produce_guards
    issue_guard(guard)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4053, in issue_guard
    guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 292, in doprint
    return self._str(self._print(expr))
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 778, in _print_Relational
    self._print(expr.rhs))
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 366, in _print_Mul
    a_str = [self.parenthesize(x, prec, strict=False) for x in a]
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 366, in <listcomp>
    a_str = [self.parenthesize(x, prec, strict=False) for x in a]
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 37, in parenthesize
    return self._print(item)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1494, in _print_Symbol
    assert self.symbol_to_source.get(expr), (
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}.  If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665

TwoTensor repro:

    # @unittest.expectedFailure
    def test_subclass_view_splits_dimension(self):
        def f(x):
            return x * 2

        f_compiled = torch.compile(f, backend="aot_eager", fullgraph=True)

        a, b = (torch.randn(4, 15) for _ in range(2))
        t = TwoTensor(a, b)
        
        t_view = t.view(t.size(0), 3, 5)
        torch._dynamo.mark_dynamic(t_view, 0)

        out_ref = f(t_view)
        out_test = f_compiled(t_view)

        self.assertEqual(out_ref, out_test)
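
For comparison, the analogous dense-tensor version appears to compile fine, which points at the subclass view-replay path specifically. A quick sanity check (not from the original report):

import torch

def f(x):
    return x * 2

f_compiled = torch.compile(f, backend="aot_eager", fullgraph=True)

x = torch.randn(4, 15)
x_view = x.view(x.size(0), 3, 5)
torch._dynamo.mark_dynamic(x_view, 0)

# plain tensors don't go through the subclass view_func replay path,
# so splitting a dimension with view() is expected to work here
out_ref = f(x_view)
out_test = f_compiled(x_view)
assert torch.equal(out_ref, out_test)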

Versions

main branch, ~June 11, GPU build

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @ezyang @albanD @anijain2305 @chauhang


Labels

module: dynamic shapes, module: nestedtensor, oncall: pt2, tensor subclass, triaged
