Closed
Labels: module: dynamic shapes, module: inductor, module: nestedtensor, oncall: pt2, tensor subclass, triaged
Description
🐛 Describe the bug
There are two issues here:
- First, the NT/PT2 stack should "lift" these SymInts as inputs to the FX graph (@bdhirsh).
- Second, this issue only reproduces with FX graph caching enabled. The caching path should not require this lifting, since the non-caching path doesn't need it either (@masnesral / @oulgen).
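To illustrate why lifting matters for caching, here is a deliberately simplified, hypothetical model. It is not PyTorch's implementation, and the names `lookup_guard`, `guard`, and `s0` are assumptions chosen for illustration. The idea it sketches: when a cached graph is reused, its shape guards must be re-checked against the new inputs. That only works if every symbol those guards mention can be bound from some graph input.

```python
from typing import Callable, Dict, List, Sequence

Guard = Callable[[Dict[str, int]], bool]


def lookup_guard(
    input_symbols: List[str],
    input_sizes: Sequence[int],
    guard_symbols: List[str],
    guard: Guard,
) -> bool:
    """Hypothetical cache-hit check: bind symbols from graph inputs, then run the guard."""
    # Each graph input contributes one concrete size for its symbol.
    bindings = dict(zip(input_symbols, input_sizes))
    # A symbol used by the guard but not bound by any input cannot be evaluated.
    missing = [s for s in guard_symbols if s not in bindings]
    if missing:
        raise KeyError(f"guard uses symbols that are not graph inputs: {missing}")
    return guard(bindings)


def guard(bindings: Dict[str, int]) -> bool:
    # A guard shaped like the one a size-dependent branch (as in the repro) might install.
    return bindings["s0"] % 16 != 0


# Lifted: s0 is a graph input, so the guard can be checked against new sizes.
print(lookup_guard(["s0"], [20], ["s0"], guard))  # True

# Not lifted: s0 exists only inside the nested tensor's metadata, so no input binds it.
try:
    lookup_guard([], [], ["s0"], guard)
except KeyError as err:
    print("cache lookup failed:", err)
```

In this toy, the non-caching path would never call `lookup_guard` at all, which mirrors the second point: only the cache-reuse check needs every guard symbol to be recoverable from the inputs.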
Repro:
```python
import torch
import torch._inductor.config

# or as an environment variable, TORCHINDUCTOR_FX_GRAPH_CACHE=1
torch._inductor.config.fx_graph_cache = True


def gen_nt(r):
    # Jagged nested tensor with 5 components; the total row count r is the last offset
    values = torch.randn(r, 16)
    offsets = torch.tensor([0, 2, 3, 6, 13, r])
    return torch.nested.nested_tensor_from_jagged(values, offsets)


def fn(nt):
    # Branch on the size of the underlying values buffer
    if nt.values().size(0) % 16 == 0:
        return nt.sin()
    return nt.cos()


torch.compile(fn)(gen_nt(19))
torch.compile(fn)(gen_nt(20))
```

Versions
PyTorch main branch (~Jul 3)
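The branch in the repro's `fn` depends only on the total number of rows in `values()`, which for a jagged layout equals the last offset, `r`. The plain-Python sketch below needs no `torch`, and its helper names `component_lengths` and `branch_taken` are hypothetical. It works through that arithmetic and shows that both calls take the `cos` branch, so what differs between the two calls is the size `r` itself rather than which branch runs.

```python
from typing import List


def component_lengths(offsets: List[int]) -> List[int]:
    """Length of each jagged component: consecutive differences of the offsets."""
    return [end - start for start, end in zip(offsets, offsets[1:])]


def branch_taken(total_rows: int) -> str:
    # Mirrors the condition in fn: values().size(0) % 16 == 0 selects sin.
    return "sin" if total_rows % 16 == 0 else "cos"


for r in (19, 20):
    offsets = [0, 2, 3, 6, 13, r]
    lengths = component_lengths(offsets)
    # values() holds every row of every component, so its size(0) equals offsets[-1] == r.
    assert sum(lengths) == r
    print(r, lengths, branch_taken(r))
```

Running it prints `19 [2, 1, 3, 7, 6] cos` and then `20 [2, 1, 3, 7, 7] cos`.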
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @ezyang @msaroufim @albanD @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire