[triton 3.2] std::bad_alloc: torch.compile breaks with Triton built from source #140423

@mobicham

🐛 Describe the bug

torch.compile fails when Triton is built from source (master branch as of Nov 12):

How to reproduce:

  1. Build Triton from the master branch
  2. Run torch.compile on a model that contains Triton kernels. In my case, this is the llama generate.py script from ao (torchao/_models/llama/generate.py in the traceback below):

The same script works fine with triton==3.1.0. With the source build, it fails with the following output:

 Time to load model: 2.09 seconds
Compiling Model
/opt/conda/lib/python3.10/contextlib.py:103: FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
  self.gen = func(*args, **kwds)
Traceback (most recent call last):
  File "/root/zmore/ao/torchao/_models/llama/generate.py", line 711, in <module>
    main(
  File "/root/zmore/ao/torchao/_models/llama/generate.py", line 538, in main
    generate(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/zmore/ao/torchao/_models/llama/generate.py", line 132, in generate
    generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
  File "/root/zmore/ao/torchao/_models/llama/generate.py", line 71, in decode_n_tokens
    next_token, next_prob = decode_one_token(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 554, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1401, in __call__
    return self._torchdynamo_orig_callable(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 546, in __call__
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 979, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 705, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 740, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 659, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2909, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in run
    while self.step():
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1027, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3100, in RETURN_VALUE
    self._return(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3085, in _return
    self.output.compile_subgraph(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1164, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1401, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1448, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1497, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1478, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 2275, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1398, in compile_fx
    return compile_fx(
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1675, in compile_fx
    return aot_autograd(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1105, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1081, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 528, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 780, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 196, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1495, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1564, in _fw_compiler_base
    return inner_compile(
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 572, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 100, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 724, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1479, in load
    compiled_graph = compile_fx_fn(
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 635, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 942, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2028, in compile_to_fn
    return self.compile_to_module().call
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1950, in compile_to_module
    return self._compile_to_module()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1956, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1895, in codegen
    self.scheduler.codegen()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 3457, in codegen
    return self._codegen()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 3536, in _codegen
    self.get_backend(device).codegen_node(node)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 80, in codegen_node
    return self._triton_scheduling.codegen_node(node)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1204, in codegen_node
    return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1399, in codegen_node_schedule
    src_code = kernel.codegen_kernel()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/triton.py", line 3106, in codegen_kernel
    triton_meta["configs"] = [config_of(signature)]
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/triton_utils.py", line 176, in config_of
    return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/runtime/hints.py", line 55, in AttrsDescriptorWrapper
    res = AttrsDescriptor.from_dict(kwargs)
  File "/root/zmore/triton/python/triton/backends/compiler.py", line 167, in from_dict
    attrs_descriptor = _descriptor_table[data["cls"]]()
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: 'cls'
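
The last two frames suggest the failure mode: `from_dict` looks up a class in a registry using `data["cls"]`, and the dict passed in from Inductor apparently has no `"cls"` entry. The sketch below reproduces that pattern in isolation. `FakeAttrsDescriptor`, `_descriptor_table`, and `build_descriptor` are hypothetical stand-ins written for illustration; they are not Triton's or Inductor's actual code, and only the shape of the lookup is taken from the traceback.

```python
# Hypothetical stand-ins for illustration only; not Triton or Inductor code.


class FakeAttrsDescriptor:
    """Toy descriptor that is rebuilt from a dict via a class registry."""

    def __init__(self):
        self.divisible_by_16 = ()
        self.equal_to_1 = ()

    @staticmethod
    def from_dict(data):
        # Same lookup shape as the failing traceback line: the dict must
        # name the class to instantiate under the "cls" key.
        descriptor = _descriptor_table[data["cls"]]()
        descriptor.divisible_by_16 = tuple(data.get("divisible_by_16", ()))
        descriptor.equal_to_1 = tuple(data.get("equal_to_1", ()))
        return descriptor


# Registry mapping class names to classes (assumed layout).
_descriptor_table = {"FakeAttrsDescriptor": FakeAttrsDescriptor}


def build_descriptor(divisible_by_16, equal_to_1, include_cls):
    """Build the kwargs dict the way a caller might, with or without "cls"."""
    kwargs = {"divisible_by_16": divisible_by_16, "equal_to_1": equal_to_1}
    if include_cls:
        kwargs["cls"] = "FakeAttrsDescriptor"
    return FakeAttrsDescriptor.from_dict(kwargs)


if __name__ == "__main__":
    # A caller that knows about the "cls" key succeeds.
    print(build_descriptor((0, 1), (2,), include_cls=True).divisible_by_16)
    # A caller written against an older dict format hits the KeyError.
    try:
        build_descriptor((0, 1), (2,), include_cls=False)
    except KeyError as err:
        print("KeyError:", err)
```

If this reading is right, the caller and the callee disagree about the dict format, which would be consistent with a PyTorch nightly and a newer Triton checkout being out of sync.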

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241028+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.0
Libc version: glibc-2.35

Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 555.58.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.6.77
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] optree==0.11.0
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241028+cu121
[pip3] torchao==0.7.0+git03c3889c
[pip3] torchaudio==2.3.1
[pip3] torchelastic==0.2.2
[pip3] torchvision==0.19.0
[pip3] triton==3.2.0+git9d6736a5
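
The list above shows both `pytorch-triton==3.1.0+cf34004b8a` and `triton==3.2.0+git9d6736a5` installed. Whether that overlap matters for this bug is an assumption, not something the report establishes, but it is cheap to check which Triton distributions an environment has. A minimal sketch using only the standard library; the helper names are hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip list in this report.
CANDIDATES = ("triton", "pytorch-triton")


def installed_triton_distributions(names=CANDIDATES):
    """Return {distribution name: version} for each name that is installed."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Not installed in this environment; skip it.
            continue
    return found


def has_multiple_triton_distributions(found):
    """True when more than one Triton distribution is installed."""
    return len(found) > 1


if __name__ == "__main__":
    found = installed_triton_distributions()
    for name, ver in found.items():
        print(f"{name}=={ver}")
    if has_multiple_triton_distributions(found):
        print("More than one Triton distribution is installed.")
```

Running this next to `python -c "import triton; print(triton.__file__)"` shows which installation is actually imported.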

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bertmaher @int3 @davidberard98 @nmacchioni @chenyang78 @embg @peterbell10 @aakhundov

Labels

high priority, oncall: pt2, triaged, upstream triton