Open
Labels: module: dynamic shapes, module: nestedtensor, oncall: pt2, tensor subclass, triaged
Description
🐛 Describe the bug
Issue 1: -1 in view [Edit: maybe fixed in #128662]
If a subclass tensor is a .view(-1, ...) of a base subclass tensor and you pass it into a torch.compile'd function, compilation fails while creating the fake tensor that represents the subclass tensor:
Repro
```python
import torch
from torch.nested._internal.nested_tensor import ViewNestedFromBuffer


def get_inputs():
    # 128 sequences with random lengths in [2, 1000)
    lengths = torch.randint(2, 1000, (128,), device='cuda')
    # offsets[i]:offsets[i + 1] delimits sequence i inside the packed values
    offsets = torch.zeros(129, device='cuda', dtype=torch.int32)
    torch.cumsum(lengths, dim=0, out=offsets[1:])
    jagged_input = torch.randn((offsets[-1].item(), 4*32), device='cuda', requires_grad=True)
    return jagged_input, offsets


def get_nt():
    # Jagged NestedTensor of shape (128, j0, 128) built as a view of the packed buffer
    nt = ViewNestedFromBuffer.apply(*get_inputs())
    return nt


def fn(nt):
    return nt.sin()


def fn2(nt):
    return nt.cos()


# Issue 1
torch.compile(fn)(get_nt().view(-1, -1, 128))
# Issue 2
nt = get_nt()
torch.compile(fn2)(nt.view(nt.shape[0], nt.shape[1], 4, 32))
```

Error
```
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 2299, in <lambda>
    lambda: tx.fake_mode.from_tensor(
  File "/home/dberard/local/pytorch/torch/_subclasses/fake_tensor.py", line 1941, in from_tensor
    return self.fake_tensor_converter.from_real_tensor(
  File "/home/dberard/local/pytorch/torch/_subclasses/fake_tensor.py", line 332, in from_real_tensor
    out = self.meta_converter(
  File "/home/dberard/local/pytorch/torch/_subclasses/meta_utils.py", line 1621, in __call__
    r = self.meta_tensor(
  File "/home/dberard/local/pytorch/torch/_subclasses/meta_utils.py", line 1368, in meta_tensor
    r = view_from_base(base, t)
  File "/home/dberard/local/pytorch/torch/_subclasses/meta_utils.py", line 1038, in view_from_base
    fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn)
  File "/home/dberard/local/pytorch/torch/_subclasses/meta_utils.py", line 967, in symint_visitor_fn
    symbol = shape_env.create_symbol(s, sym_source)
  File "/home/dberard/local/pytorch/torch/fx/experimental/recording.py", line 245, in wrapper
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3457, in create_symbol
    assert not (positive and val < 0), f"positive set for negative value: {val}"
AssertionError: positive set for negative value: -1
```
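For context: a -1 in .view() is a placeholder resolved from the element count, so the viewed tensor's shape never contains -1, but the arguments passed to .view() still do. In the traceback above, the view is replayed through t.view_func with symint_visitor_fn turning the view's integer arguments into symbols, which is how the raw -1 reaches create_symbol. The sketch below uses a hypothetical helper, resolve_view_shape (not a PyTorch API), and models dense-tensor inference only; the repro's NJT view accepts two -1s, so nested-tensor semantics evidently differ.

```python
from math import prod


def resolve_view_shape(numel: int, sizes: tuple[int, ...]) -> tuple[int, ...]:
    """Resolve at most one -1 in `sizes` the way a dense .view() does.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    if sizes.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(s for s in sizes if s != -1)
    if -1 not in sizes:
        if known != numel:
            raise ValueError(f"shape {sizes} is invalid for input of size {numel}")
        return sizes
    if known == 0 or numel % known != 0:
        raise ValueError(f"shape {sizes} is invalid for input of size {numel}")
    # Replace the placeholder with the inferred size.
    return tuple(numel // known if s == -1 else s for s in sizes)


view_args = (-1, 128)  # raw arguments as passed to .view()
print(resolve_view_shape(4 * 128, view_args))  # (4, 128)
```

The resolved shape is all non-negative, while view_args keeps the -1; a visitor that sees only the raw arguments therefore has to cope with a negative value.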
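We can apply a patch like the one below, which avoids the assertion by not requesting a positive symbol for negative view arguments: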
```diff
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 4ea0db56aae..d1c6b562c94 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -964,7 +964,15 @@ class MetaConverter:
                 # assumption of it being simplified out will fail and it may be guarded on,
                 # which will hard error.
                 sym_source = EphemeralSource("symint_visitor_fn")
-                symbol = shape_env.create_symbol(s, sym_source)
+
+                positive = True
+                if isinstance(s, int) and s < 0:
+                    # e.g. if you .view(-1, ...)
+                    positive = False
+
+                print(f"!! {s}, positive={positive}")
+
+                symbol = shape_env.create_symbol(s, sym_source, positive=positive)
                 return shape_env.create_symintnode(symbol, hint=s, source=sym_source)
 
             real_to_fake_mapping = {}
```
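The effect of the patch can be modeled without PyTorch. ToyShapeEnv and this symint_visitor_fn are hypothetical stand-ins that copy only the positivity assertion from the traceback and the flag logic from the diff; they do not model symbol simplification, hints, or guards.

```python
import itertools


class ToyShapeEnv:
    """Hypothetical stand-in for ShapeEnv; it only reproduces the positivity check."""

    def __init__(self) -> None:
        self._ids = itertools.count()
        self.hints: dict[str, int] = {}

    def create_symbol(self, val: int, positive: bool = True) -> str:
        # Same condition and message as the AssertionError in the traceback.
        assert not (positive and val < 0), f"positive set for negative value: {val}"
        name = f"s{next(self._ids)}"
        self.hints[name] = val
        return name


def symint_visitor_fn(env: ToyShapeEnv, s: int) -> str:
    # Flag logic from the diff: request a positive symbol unless the replayed
    # view argument is a negative int such as the -1 in .view(-1, ...).
    positive = True
    if isinstance(s, int) and s < 0:
        positive = False
    return env.create_symbol(s, positive=positive)


env = ToyShapeEnv()
try:
    env.create_symbol(-1)  # unpatched call: positive defaults to True
except AssertionError as err:
    print(err)  # positive set for negative value: -1

sym = symint_visitor_fn(env, -1)  # patched path
print(sym, env.hints[sym])  # s0 -1
```

The first call reproduces the assertion message; the patched visitor succeeds because it passes positive=False for the negative argument. Whether the real ShapeEnv then handles a non-positive symbol correctly downstream is outside what this toy shows.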
Issue 2: Can't simplify some symints from .view-ed subclasses
When the view splits a dimension, we run into issues simplifying the SymInts used when replaying the view. For example, calling nt.view(nt.shape[0], nt.shape[1], 4, 32) on an nt of shape (128, j0, 128), which splits the last dimension into (4, 32), hits this error:
Repro: same code as Issue 1 above, but comment out the line marked as "Issue 1"
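For concreteness, the Issue 2 view keeps the first two sizes and replaces the last size, 128, with two sizes whose product is 128. The sketch below is plain shape bookkeeping under stated assumptions: split_dim is a hypothetical helper, the string "j0" stands in for the ragged size, and nothing here models how Dynamo allocates symbols. The TwoTensor repro further down performs the same kind of split, (4, 15) to (4, 3, 5).

```python
from math import prod
from typing import Union

Size = Union[int, str]  # a str such as "j0" stands in for a ragged size


def split_dim(shape: tuple[Size, ...], dim: int, parts: tuple[int, ...]) -> tuple[Size, ...]:
    """Replace shape[dim] by `parts`; hypothetical helper, not a PyTorch API."""
    old = shape[dim]
    # Only a concrete int size can be split, and the parts must multiply back to it.
    if not isinstance(old, int) or prod(parts) != old:
        raise ValueError(f"cannot split size {old!r} into {parts}")
    return shape[:dim] + parts + shape[dim + 1 :]


base = (128, "j0", 128)
viewed = split_dim(base, 2, (4, 32))
print(viewed)  # (128, 'j0', 4, 32)
# Neither 4 nor 32 occurs among the base sizes.
print([p for p in (4, 32) if p in base])  # []
```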
Log snippet:
```
...
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4089, in produce_guards
    issue_guard(guard)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4053, in issue_guard
    guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 292, in doprint
    return self._str(self._print(expr))
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 778, in _print_Relational
    self._print(expr.rhs))
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 366, in _print_Mul
    a_str = [self.parenthesize(x, prec, strict=False) for x in a]
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 366, in <listcomp>
    a_str = [self.parenthesize(x, prec, strict=False) for x in a]
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 37, in parenthesize
    return self._print(item)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1494, in _print_Symbol
    assert self.symbol_to_source.get(expr), (
AssertionError: s7 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s3: ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"], s0: ["L['nt'].size()[1]"], s6: ["L['nt'].size()[2]", "L['nt']._values.size()[1]"], s1: ["L['nt']._values.size()[0]"], s7: []}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
```
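The assertion fires in ShapeGuardPrinter._print_Symbol, which requires a non-empty symbol_to_source entry for every symbol it prints. The log shows that s7's only candidate sources are ephemeral (<ephemeral: symint_visitor_fn>) and that its entry in the mapping is empty. The hypothetical helper below applies the same check to the mapping copied from the message; the actual guard expression is not in the log, so the symbol pair ["s6", "s7"] is an assumption for illustration.

```python
def unprintable_symbols(
    guard_symbols: list[str], symbol_to_source: dict[str, list[str]]
) -> list[str]:
    """Return the guard symbols that have no source to print.

    Hypothetical helper mirroring the `assert self.symbol_to_source.get(expr)`
    check: a missing or empty entry means the symbol cannot be printed.
    """
    return [sym for sym in guard_symbols if not symbol_to_source.get(sym)]


# Mapping copied from the AssertionError message (symbols written as strings).
symbol_to_source = {
    "s3": ["L['nt']._values.size()[0]", "L['nt']._values.size()[0]"],
    "s0": ["L['nt'].size()[1]"],
    "s6": ["L['nt'].size()[2]", "L['nt']._values.size()[1]"],
    "s1": ["L['nt']._values.size()[0]"],
    "s7": [],
}

# Assumed guard symbols; the real guard is not shown in the log.
print(unprintable_symbols(["s6", "s7"], symbol_to_source))  # ['s7']
```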
TwoTensor repro:
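```python
import torch
import torch._dynamo
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.two_tensor import TwoTensor


class TestSubclassViewSplit(TestCase):
    # @unittest.expectedFailure
    def test_subclass_view_splits_dimension(self):
        def f(x):
            return x * 2
        f_compiled = torch.compile(f, backend="aot_eager", fullgraph=True)
        a, b = (torch.randn(4, 15) for _ in range(2))
        t = TwoTensor(a, b)
        # Split dim 1 (size 15) into (3, 5); dim 0 is marked dynamic below
        t_view = t.view(t.size(0), 3, 5)
        torch._dynamo.mark_dynamic(t_view, 0)
        out_ref = f(t_view)
        out_test = f_compiled(t_view)
        self.assertEqual(out_ref, out_test)


if __name__ == "__main__":
    run_tests()
```

Versions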
main branch (~June 11), GPU build
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @ezyang @albanD @anijain2305 @chauhang
jbschlosser