
tl.constexpr inputs to user-defined triton kernels should not be dynamic #136504

@davidberard98


🐛 Describe the bug

Right now, tl.constexpr arguments to user-defined Triton kernels can be passed dynamic (symbolic) values under torch.compile, and compilation then fails.
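tl.constexpr parameters are compile-time constants. Triton compiles and caches a separate kernel for each distinct combination of constexpr values, so NUMEL in the repro must be a concrete integer when the kernel is compiled. A symbolic size such as Inductor's s0 has no such value. The toy model below illustrates the idea. SpecializingCache and fake_compile are hypothetical names, and this is a simplified sketch, not Triton's actual caching code.

```python
from typing import Callable, Dict, Tuple


# Hypothetical, simplified model of a specializing JIT cache: every distinct
# tuple of constexpr values gets its own compiled kernel. This is NOT
# Triton's real implementation; it only shows why constexpr arguments must
# be concrete values at compile time.
class SpecializingCache:
    def __init__(self, compile_fn: Callable[[Tuple[int, ...]], str]):
        self.compile_fn = compile_fn
        self.cache: Dict[Tuple[int, ...], str] = {}
        self.compiles = 0

    def get(self, *constexprs: int) -> str:
        # A symbolic placeholder cannot be baked into generated code.
        for value in constexprs:
            if not isinstance(value, int):
                raise TypeError(f"constexpr value must be a concrete int, got {value!r}")
        if constexprs not in self.cache:
            self.compiles += 1
            self.cache[constexprs] = self.compile_fn(constexprs)
        return self.cache[constexprs]


def fake_compile(constexprs: Tuple[int, ...]) -> str:
    numel, block_size = constexprs
    # A real compiler would embed these values in the generated kernel.
    return f"kernel<NUMEL={numel}, BLOCK_SIZE={block_size}>"


cache = SpecializingCache(fake_compile)
k1 = cache.get(517, 256)  # x1.numel() in the repro is 517
k2 = cache.get(775, 256)  # x2.numel() is 775: a second, separate kernel
k3 = cache.get(517, 256)  # same constexpr values: reuses the first kernel
```

In this model, every new NUMEL value costs a compile, which is the trade-off of marking a size argument as tl.constexpr.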

repro:

import torch
import triton
import triton.language as tl

@triton.jit
def triton_(x_ptr, y_ptr, NUMEL: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = BLOCK_SIZE*pid + tl.arange(0, BLOCK_SIZE)
    mask = offsets < NUMEL

    data = tl.load(x_ptr + offsets, mask)
    result = data * data

    tl.store(y_ptr + offsets, result, mask)


def fn(x):
    y = torch.empty_like(x)
    BLOCK_SIZE = 256
    numel = x.numel()
    grid = (triton.cdiv(numel, BLOCK_SIZE),)
    triton_[grid](x, y, numel, BLOCK_SIZE)
    return y

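# 256*2 + 5 = 517 and 256*3 + 7 = 775 elements, so the grids are (3,) and (4,)
x1 = torch.randn(256*2 + 5, device="cuda")
x2 = torch.randn(256*3 + 7, device="cuda")

# eager mode: both calls work
fn(x1)
fn(x2)

fn_c = torch.compile(fn)

fn_c(x1)  # first compile: numel is a static 517
fn_c(x2)  # numel changed, so this recompiles with dynamic shapes; NUMEL becomes symbolic (s0) and fails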
...
  File "/home/dberard/local/pytorch/torch/_inductor/compile_fx.py", line 574, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/home/dberard/local/pytorch/torch/_inductor/compile_fx.py", line 882, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/home/dberard/local/pytorch/torch/_inductor/graph.py", line 1948, in compile_to_fn
    return self.compile_to_module().call
  File "/home/dberard/local/pytorch/torch/_inductor/graph.py", line 1874, in compile_to_module
    return self._compile_to_module()
  File "/home/dberard/local/pytorch/torch/_inductor/graph.py", line 1902, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/home/dberard/local/pytorch/torch/_inductor/codecache.py", line 2949, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/home/dberard/local/pytorch/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_dberard/22/c22qdhxpnvhexva72nxhp7p2dhuh5k3bi7mxu6jvoglb4iowjpu7.py", line 34, in <module>
    triton__0 = async_compile.triton('triton_', '''
  File "/home/dberard/local/pytorch/torch/_inductor/async_compile.py", line 198, in triton
    kernel = TritonCodeCache.load(kernel_name, source_code)
  File "/home/dberard/local/pytorch/torch/_inductor/codecache.py", line 2999, in load
    return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
  File "/home/dberard/local/pytorch/torch/_inductor/codecache.py", line 2936, in load
    return cls.load_by_key_path(key, path, linemap, attrs)
  File "/home/dberard/local/pytorch/torch/_inductor/codecache.py", line 2949, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/home/dberard/local/pytorch/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_dberard/na/cnax67r5zmslz7bvdfizteaepj7fajpjallb3bu2gyetjcdqtbzj.py", line 14, in <module>
    triton_meta={'signature': {0: '*fp32', 1: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=132, warp_size=32), 'constants': {2: s0, 3: 256}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 's0' is not defined
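The last frames show the mechanism: the generated module writes triton_meta as Python source, and the symbolic size is emitted as the bare name s0 under 'constants', so executing the module raises NameError. The sketch below reproduces that failure mode in plain Python and shows one possible direction consistent with the issue title: substituting the concrete value for the symbol before rendering. Sym and render_constants are hypothetical names; this is not Inductor's codegen, and it says nothing about how Inductor should guard on a specialized value.

```python
from dataclasses import dataclass
from typing import Dict, Union


# Hypothetical stand-in for a symbolic size such as Inductor's `s0`.
# `hint` is the concrete value observed for the current input.
@dataclass(frozen=True)
class Sym:
    name: str
    hint: int

    def __repr__(self) -> str:
        # Renders as a bare identifier, the way `s0` appears in the traceback.
        return self.name


def render_constants(constants: Dict[int, Union[int, Sym]], specialize: bool) -> str:
    """Render a constants dict as Python source, as a code generator might."""
    items = []
    for idx, value in constants.items():
        if isinstance(value, Sym) and specialize:
            value = value.hint  # replace the symbol with its concrete value
        items.append(f"{idx}: {value!r}")
    return "{" + ", ".join(items) + "}"


constants = {2: Sym("s0", 775), 3: 256}

naive = render_constants(constants, specialize=False)
specialized = render_constants(constants, specialize=True)

try:
    eval(naive, {})  # same failure mode as the traceback above
except NameError as err:
    print("naive:", err)

print("specialized:", eval(specialized, {}))
```

Here naive is the string "{2: s0, 3: 256}", which cannot be evaluated, while specialized is "{2: 775, 3: 256}", which evaluates to an ordinary dict of integers.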

Versions

PyTorch viable/strict as of Sep 23, running on an H100

cc @ezyang @chauhang @penguinwu @oulgen @aakhundov

Labels: module: dynamic shapes, module: user triton, oncall: pt2, triaged