-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Description
🐛 Describe the bug
torch.compile fails with `KeyError: 'cls'` (raised from Triton's `AttrsDescriptor.from_dict`) when Triton is built from source (master as of Nov 12):
How to reproduce:
- Build Triton from the master branch (the build used here is `triton==3.2.0+git9d6736a5`; see Versions below)
- Run torch.compile on a model that generates Triton kernels. In my case this is the torchao Llama generation script (`torchao/_models/llama/generate.py`).
The same script works fine with triton==3.1.0. With the Triton source build, it fails with the following output:
Time to load model: 2.09 seconds
Compiling Model
/opt/conda/lib/python3.10/contextlib.py:103: FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
self.gen = func(*args, **kwds)
Traceback (most recent call last):
File "/root/zmore/ao/torchao/_models/llama/generate.py", line 711, in <module>
main(
File "/root/zmore/ao/torchao/_models/llama/generate.py", line 538, in main
generate(
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/zmore/ao/torchao/_models/llama/generate.py", line 132, in generate
generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
File "/root/zmore/ao/torchao/_models/llama/generate.py", line 71, in decode_n_tokens
next_token, next_prob = decode_one_token(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 554, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1401, in __call__
return self._torchdynamo_orig_callable(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 546, in __call__
return _compile(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 979, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 705, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 740, in _compile_inner
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 659, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2909, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in run
while self.step():
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1027, in step
self.dispatch_table[inst.opcode](self, inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3100, in RETURN_VALUE
self._return(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3085, in _return
self.output.compile_subgraph(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1164, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1401, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1448, in call_user_compiler
return self._call_user_compiler(gm)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1497, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1478, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/__init__.py", line 2275, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1398, in compile_fx
return compile_fx(
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1675, in compile_fx
return aot_autograd(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1105, in aot_module_simplified
compiled_fn = dispatch_and_compile()
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1081, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 528, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 780, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 196, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1495, in fw_compiler_base
return _fw_compiler_base(model, example_inputs, is_inference)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1564, in _fw_compiler_base
return inner_compile(
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 572, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 100, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 724, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1479, in load
compiled_graph = compile_fx_fn(
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 635, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 942, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2028, in compile_to_fn
return self.compile_to_module().call
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1950, in compile_to_module
return self._compile_to_module()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1956, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1895, in codegen
self.scheduler.codegen()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 3457, in codegen
return self._codegen()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 3536, in _codegen
self.get_backend(device).codegen_node(node)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 80, in codegen_node
return self._triton_scheduling.codegen_node(node)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1204, in codegen_node
return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1399, in codegen_node_schedule
src_code = kernel.codegen_kernel()
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/triton.py", line 3106, in codegen_kernel
triton_meta["configs"] = [config_of(signature)]
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codegen/triton_utils.py", line 176, in config_of
return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/runtime/hints.py", line 55, in AttrsDescriptorWrapper
res = AttrsDescriptor.from_dict(kwargs)
File "/root/zmore/triton/python/triton/backends/compiler.py", line 167, in from_dict
attrs_descriptor = _descriptor_table[data["cls"]]()
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: 'cls'
Versions
Collecting environment information...
PyTorch version: 2.6.0.dev20241028+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.0
Libc version: glibc-2.35
Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 555.58.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.6.77
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] optree==0.11.0
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241028+cu121
[pip3] torchao==0.7.0+git03c3889c
[pip3] torchaudio==2.3.1
[pip3] torchelastic==0.2.2
[pip3] torchvision==0.19.0
[pip3] triton==3.2.0+git9d6736a5
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bertmaher @int3 @davidberard98 @nmacchioni @chenyang78 @embg @peterbell10 @aakhundov