🐛 Describe the bug
This error was encountered while trying to implement a version of Triton's Autotuner.prune_configs. The function was modified to operate on a list instead of a dict (a dict keyed by configs is likewise not supported).
A minimal repro would look something like:
```python
est_timing: List[Tuple[triton.runtime.Config, float]]
est_timing = [
    (config, perf_model(**named_args, **kwargs, **config.all_kwargs()))
    for config in configs
]
configs = sorted(est_timing, key=lambda x: est_timing[1])[:top_k]
```

Here is the complete function which triggered the error (for reproducibility):
```python
def call_prune_configs(  # type: ignore[no-untyped-def]
    autotuner,
    early_config_prune: Callable,
    perf_model: Callable,
    top_k: float,
    is_top_k_float: bool,
    configs: List,
    named_args: Dict,
    kwargs: Dict,
):
    if early_config_prune:
        configs = early_config_prune(configs, named_args, **kwargs)
    if perf_model:
        # we assert top_k is a float before calling this
        if is_top_k_float and top_k <= 1.0:
            top_k = int(len(configs) * top_k)
        if len(configs) > top_k:
            est_timing = [
                (config, perf_model(**named_args, **kwargs, **config.all_kwargs()))
                for config in configs
            ]
            configs = sorted(est_timing, key=lambda x: est_timing[1])[:top_k]
    return configs
```
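As a side note, the sort key above indexes the outer list (`est_timing[1]`) rather than the tuple element, so every element receives the same key. A minimal plain-Python sketch of the presumably intended pruning (sorting by the estimated time `x[1]`), with plain dicts standing in for `triton.runtime.Config` and a hypothetical `perf_model`:

```python
# Standalone sketch; `perf_model` and the dict configs are stand-ins,
# not the actual Triton objects from the repro above.
def perf_model(**kwargs):
    # Hypothetical cost model: pretend a larger BLOCK means slower.
    return float(kwargs["BLOCK"])

configs = [{"BLOCK": 128}, {"BLOCK": 32}, {"BLOCK": 64}]
top_k = 2

# Same shape as the repro: (config, estimated_time) pairs.
est_timing = [(cfg, perf_model(**cfg)) for cfg in configs]

# Sort by the per-element timing x[1], keep the top_k fastest.
pruned = [cfg for cfg, _ in sorted(est_timing, key=lambda x: x[1])[:top_k]]
print(pruned)  # -> [{'BLOCK': 32}, {'BLOCK': 64}]
```

Note that the repro's `key=lambda x: est_timing[1]` is kept verbatim above for reproducibility, since that closure is what Dynamo actually traced.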
```python
# Called in torch/_higher_order_ops/triton_kernel_wrap.py
pruned_configs = self.call_user_defined_fn(
    call_prune_configs,
    [
        variable,
        wrapped_early_configs_prune,
        wrapped_perf_model,
        wrapped_configs_top_k,
        wrapped_is_top_k_float,
        wrapped_configs,
        named_args,
        kwargs,
    ],
    {},
    tx,
    variable.source,
)
```

Here is a stack trace of the generated bytecode leading up to the error:
```
"/data/users/ginzburg/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 1023> [BuiltinVariable(sorted), ListVariable(length=2), TupleVariable(length=1)]
V1218 08:22:05.910000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_CONST call_prune_configs.<locals>.<lambda> [BuiltinVariable(sorted), ListVariable(length=2), TupleVariable(length=1), ConstantVariable(code: <code object <lambda> at 0x7f9e3c5fbb50, file "/data/users/ginzburg/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 1023>)]
V1218 08:22:05.910000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE MAKE_FUNCTION 8 [BuiltinVariable(sorted), ListVariable(length=2), TupleVariable(length=1), ConstantVariable(code: <code object <lambda> at 0x7f9e3c5fbb50, file "/data/users/ginzburg/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 1023>), ConstantVariable(str: 'call_prune_configs.<locals>.<lambda>')]
V1218 08:22:05.910000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('key',) [BuiltinVariable(sorted), ListVariable(length=2), NestedUserFunctionVariable()]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 2 [BuiltinVariable(sorted), ListVariable(length=2), NestedUserFunctionVariable(), TupleVariable(length=1)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_DEREF est_timing []
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [ListVariable(length=2)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [ListVariable(length=2), ConstantVariable(int: 1)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TupleVariable(length=2)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_DEREF est_timing []
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [ListVariable(length=2)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [ListVariable(length=2), ConstantVariable(int: 1)]
V1218 08:22:05.912000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TupleVariable(length=2)]
```
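The repeated `LOAD_DEREF est_timing` / `BINARY_SUBSCR` entries come from the lambda's body referencing the enclosing `est_timing` rather than its argument `x`, which makes `est_timing` a free variable of the key function. This can be checked in plain Python (names here are stand-ins for the repro's objects):

```python
def make_key():
    # Mirrors the repro's closure: the lambda ignores x and loads the
    # enclosing est_timing via LOAD_DEREF, then subscripts it -- the
    # same instruction sequence Dynamo traces just before failing.
    est_timing = [("cfg_a", 2.0), ("cfg_b", 1.0)]
    return lambda x: est_timing[1]

key = make_key()
print(key.__code__.co_freevars)  # -> ('est_timing',)
print(key(None))                 # -> ('cfg_b', 1.0), the same value for every x
```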
inline_call [('sort with non-constant keys', 1)]

Versions
```
PyTorch version: 2.6.0a0+git28e242f
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-2)
Clang version: 18.1.8 (CentOS 18.1.8-3.el9)
CMake version: version 3.26.4
Libc version: glibc-2.34
Python version: 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnxscript==0.1.0.dev20240817
[pip3] optree==0.13.0
[pip3] pytorch-triton==3.2.0+git35c6c7c6
[pip3] torch==2.6.0a0+git28e242f
[conda] blas 1.0 mkl
[conda] magma-cuda121 2.6.1 1 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-include 2023.1.0 h06a4308_46344
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.10 py310h5eee18b_0
[conda] mkl_random 1.2.7 py310h1128e8f_0
[conda] numpy 1.26.4 py310h5f9d8c6_0
[conda] numpy-base 1.26.4 py310hb5e798b_0
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-triton 3.2.0+git35c6c7c6 pypi_0 pypi
[conda] torch 2.6.0a0+git28e242f dev_0
[conda] torchfix 0.4.0 pypi_0 pypi
```
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames