Description
Repro:

```python
import torch
from torch.nested._internal.nested_tensor import jagged_from_list

def get_jagged_tensor(nested_size, offsets):
    D = nested_size[1]
    out = []
    for s in nested_size[0]:
        out.append(torch.randn(s, D, requires_grad=False, dtype=torch.float64))
    return jagged_from_list(out, offsets)

@torch.compile(backend="aot_eager")
def f(nt):
    nested_size = ((2, 3, 4), 5)
    offsets = None
    nt2, _ = get_jagged_tensor(nested_size, offsets)
    nt3 = torch.cat([nt2, nt], dim=-1)
    return nt3.sin() * nt3.size(1)

nested_size = ((2, 3, 4), 5)
offsets = None
nt, _ = get_jagged_tensor(nested_size, offsets)
out = f(nt)
```
This gives the (cursed) "untracked proxy" error:

```
RuntimeError: s0 (140510816656096)is not tracked with proxy for <torch.fx.experimental.proxy_tensor.PythonKeyTracer object at 0x7fcb3a6b7760>
```
Interestingly, in this particular case you can work around the error by flipping the order of the arguments to `cat`, i.e. `torch.cat([nt, nt2], ...)`. That appears to work because the NJT lowering for `aten.cat` uses the `nested_int` from the first argument and discards the nested ints of all other arguments.
I noticed this while trying to understand why Guilherme's recent dynamic shapes changes didn't fix the NJT issues from here. The original suspicion was that Guilherme's changes would mean that we can delete `maybe_enable_thunkify` (see link), since we can now handle the case where we have subclass graph inputs/outputs that have dynamic shapes. It looks like this is not sufficient to handle `nested_int` creation, though.
It looks like the underlying problem is that:

(1) the function above generates a new `nested_int` inside of the compiled region, and returns an NJT with that `nested_int` as part of its shape.

(2) compile doesn't really seem to be able to deal with fresh `nested_int`s being constructed mid-graph. We need to generate a "subclass-free" graph to compile, and (somehow) one of the outputs to this graph needs to be the newly-constructed `nested_int`. Alternatively, we need to arrange for one of the outputs of the graph to be an offsets tensor, so NJT can generate the `nested_int` on the fly at runtime.
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @ezyang @chauhang @penguinwu