
NJT + compile can't handle nested_ints created in the middle of the graph #138496

@bdhirsh

Description


repro:

import torch
from torch.nested._internal.nested_tensor import jagged_from_list

def get_jagged_tensor(nested_size, offsets):
    # Build a jagged NJT from a list of (s, D) dense tensors.
    # nested_size = ((s0, s1, ...), D); offsets=None lets
    # jagged_from_list create fresh offsets (and a fresh nested_int).
    D = nested_size[1]
    out = []
    for s in nested_size[0]:
        out.append(torch.randn(s, D, requires_grad=False, dtype=torch.float64))
    return jagged_from_list(out, offsets)

@torch.compile(backend="aot_eager")
def f(nt):
    nested_size = ((2, 3, 4), 5)
    offsets = None
    nt2, _ = get_jagged_tensor(nested_size, offsets)
    nt3 = torch.cat([nt2, nt], dim=-1)
    return nt3.sin() * nt3.size(1)

nested_size = ((2, 3, 4), 5)
offsets = None
nt, _ = get_jagged_tensor(nested_size, offsets)
out = f(nt)

gives the (cursed) "untracked proxy" error:

RuntimeError: s0 (140510816656096)is not tracked with proxy for <torch.fx.experimental.proxy_tensor.PythonKeyTracer object at 0x7fcb3a6b7760>

Interestingly, in this particular case you can work around the error by flipping the order of the arguments to cat, i.e. torch.cat([nt, nt2], ...). That appears to be because the NJT lowering for aten.cat uses the nested_int from the first argument and discards all other nested ints, so the graph-input NJT's nested_int (which is already tracked) wins.

I noticed this while trying to understand why Guilherme's recent dynamic shapes changes didn't fix the NJT issues from here. The original suspicion was that Guilherme's changes would let us delete maybe_enable_thunkify (see link), since we can now handle subclass graph inputs/outputs with dynamic shapes. It turns out this is not sufficient to handle nested_int creation, though.

It looks like the underlying problem is that:

(1) the function above generates a new nested_int inside of the compiled region, and returns an NJT with that nested_int as part of its shape.

(2) compile can't deal with fresh nested_ints being constructed mid-graph. We need to generate a "subclass-free" graph to compile, and (somehow) one of the outputs of this graph needs to be the newly-constructed nested_int. Alternatively, we need to arrange for one of the outputs of the graph to be an offsets tensor, so NJT can generate the nested_int on the fly at runtime.

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @ezyang @chauhang @penguinwu

Labels: module: nestedtensor, oncall: pt2, triaged
