Labels: module: nestedtensor, oncall: pt2, triaged
Description
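Under torch.compile, calling njt.to(device) on a nested jagged tensor (NJT) returns an output whose shape contains a new nested int instead of the input's nested int. Note that this differs from eager behavior for njt.to(device), where the nested int from the input shape is purposefully included in the output shape, as expected.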
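The eager invariant described above can be pictured with a toy model. This is a hypothetical sketch, not PyTorch's code: it assumes, as a simplification, that a nested int is keyed by the identity of the offsets object. An eager-style `.to(device)` aliases the copied offsets to the original symbol, while treating the copy as a brand-new offsets object mints a fresh symbol, which gives the `j1` vs `j2` mismatch seen in the repro. All class and function names here are made up for illustration.

```python
import itertools


class Offsets:
    """Hypothetical stand-in for an NJT offsets tensor (hashed by identity)."""

    def __init__(self, values, device):
        self.values = tuple(values)
        self.device = device


class NestedIntRegistry:
    """Hypothetical map from an offsets object to its nested int symbol."""

    def __init__(self):
        self._counter = itertools.count(1)
        self._by_offsets = {}  # Offsets object -> symbol name such as "j1"

    def get(self, offsets):
        # Mint a new symbol the first time a given offsets object is seen.
        if offsets not in self._by_offsets:
            self._by_offsets[offsets] = f"j{next(self._counter)}"
        return self._by_offsets[offsets]

    def alias(self, new_offsets, old_offsets):
        # Make a copied offsets object share the original's symbol.
        self._by_offsets[new_offsets] = self.get(old_offsets)


def shape_of(offsets, registry, inner_dim):
    # (batch, ragged dim as nested int, inner dim)
    return (len(offsets.values) - 1, registry.get(offsets), inner_dim)


def to_device_eager(offsets, device, registry):
    # Eager-style copy: carry the input's nested int over to the output.
    new = Offsets(offsets.values, device)
    registry.alias(new, offsets)
    return new


def to_device_fresh(offsets, device, registry):
    # Copy with no aliasing: a new symbol is minted on first use.
    return Offsets(offsets.values, device)


reg = NestedIntRegistry()
offs = Offsets([0, 2, 5, 9], "cuda")  # ragged lengths 2, 3, 4
print(shape_of(offs, reg, 5))
print(shape_of(to_device_eager(offs, "cpu", reg), reg, 5))
print(shape_of(to_device_fresh(offs, "cpu", reg), reg, 5))
```

In this model the first two shapes share `j1` and the third gets `j2`, mirroring the eager vs compiled outputs in the repro.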
Repro:
```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

# NJT with ragged dim lengths 2, 3, 4, created on CUDA
nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
# fails; torch.Size([3, j1, 5]) for out and torch.Size([3, j2, 5]) for out_compile
assert out.shape == out_compile.shape
```

cc @cpuhrsch @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @ezyang @chauhang @penguinwu