Symbolic shape fails symbol_to_source when caching is enabled #127970

@davidberard98

🐛 Describe the bug

There are two issues here:

  • First, the NT/PT2 stack should "lift" these SymInts as inputs to the FX graph (@bdhirsh); see the sketch after this list.
  • Second, this issue only reproduces with caching enabled; the caching path should not require this, since the non-caching path doesn't need it either (@masnesral / @oulgen).
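
For context, here is a minimal sketch (independent of this repro; the backend name show_graph is ours) of what "lifting" a size as a graph input looks like: with dynamic shapes, Dynamo passes the symbolic size as an explicit SymInt placeholder to the FX graph, alongside the tensor input.

import torch

# A toy debugging backend that prints the captured FX graph, then runs it eagerly.
def show_graph(gm, example_inputs):
    gm.graph.print_tabular()
    return gm.forward

@torch.compile(backend=show_graph, dynamic=True)
def f(x):
    return x.sin()

# With dynamic=True, the printed graph includes a SymInt placeholder
# (e.g. s0) for the dynamic size, in addition to the tensor input.
f(torch.randn(19, 16))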

Repro:

import torch
import torch._inductor.config

# or as an environment variable, TORCHINDUCTOR_FX_GRAPH_CACHE=1
torch._inductor.config.fx_graph_cache = True

def gen_nt(r):
    values = torch.randn(r, 16)
    offsets = torch.tensor([0, 2, 3, 6, 13, r])
    return torch.nested.nested_tensor_from_jagged(values, offsets)

def fn(nt):
    if nt.values().size(0) % 16 == 0:
        return nt.sin()
    return nt.cos()

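# First call: compiles fn with a concrete ragged size and populates the FX graph cache.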
torch.compile(fn)(gen_nt(19))
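# Second call: a different ragged size triggers a dynamic recompile; with
# fx_graph_cache enabled, this is where symbol_to_source fails.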
torch.compile(fn)(gen_nt(20))

Versions

PyTorch main branch, ~July 3

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @ezyang @msaroufim @albanD @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire
