
Device transfer for NJT within torch.compile allocates a new nested int #137275

@jbschlosser

Description

Within torch.compile, transferring an NJT to a different device allocates a new nested int for the output. Note that this differs from eager behavior for njt.to(device), where the nested int from the input shape is purposefully included in the output shape, as expected.

Repro:

import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
# fails; torch.Size([3, j1, 5]) for out and torch.Size([3, j2, 5]) for out_compile
assert out.shape == out_compile.shape
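
For reference, the eager output keeps the input's nested int while the compiled output gets a fresh one, even though the underlying data should agree. A minimal sketch of that check, assuming the jagged-layout accessors values() and offsets() (exact accessor names may differ across versions):

# Eager .to(device) reuses the input's nested int, so the symbolic shape is preserved.
assert out.shape == nt.shape

# The compiled output carries a new nested int, but the underlying buffers should
# still match; values()/offsets() are assumed jagged-layout accessors here.
assert torch.equal(out.values(), out_compile.values())
assert torch.equal(out.offsets(), out_compile.offsets())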

cc @cpuhrsch @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @ezyang @chauhang @penguinwu

Labels: module: nestedtensor, oncall: pt2, triaged
