[Dynamo] torch._dynamo.exc.Unsupported: sort with non-constant keys #143505

@SamGinzburg

🐛 Describe the bug

This error was encountered while implementing a version of Triton's Autotuner.prune_configs.

The function was changed to operate on a list instead of a dict, because a dict keyed by configs is also not supported.

A minimal repro would look something like:

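from typing import List, Tuple

import triton

est_timing: List[Tuple[triton.runtime.Config, float]]
est_timing = [
    (config, perf_model(**named_args, **kwargs, **config.all_kwargs()))
    for config in configs
]
# Note: the key reads est_timing[1] rather than x[1], so every element gets the same key.
configs = sorted(est_timing, key=lambda x: est_timing[1])[:top_k]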
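The key in the repro ignores its argument x and returns est_timing[1], the second (config, timing) pair, for every element. The plain-Python sketch below runs outside Dynamo and uses hypothetical string stand-ins for triton.runtime.Config plus illustrative (not measured) timings. Because every element gets the same key and sorted() is stable, the input order is kept unchanged. Keying on x[1], which is presumably what was intended (Triton's dict-based version sorts configs by their estimated timing), orders the configs by cost.

```python
from typing import List, Tuple

# Hypothetical stand-ins: strings instead of triton.runtime.Config objects,
# and illustrative (not measured) timing estimates.
est_timing: List[Tuple[str, float]] = [
    ("config_a", 3.0),
    ("config_b", 1.0),
    ("config_c", 2.0),
]
top_k = 2


def buggy_key(x):
    # Same as the repro's key: ignores x, returns the same pair every time.
    return est_timing[1]


print([buggy_key(e) for e in est_timing])
# [('config_b', 1.0), ('config_b', 1.0), ('config_b', 1.0)]

# All keys compare equal, so the stable sort keeps the input order.
buggy = sorted(est_timing, key=buggy_key)[:top_k]

# Keying on each element's timing orders by estimated cost; the configs are
# then extracted from the (config, timing) pairs.
fixed = [config for config, _ in sorted(est_timing, key=lambda x: x[1])[:top_k]]

print(buggy)  # [('config_a', 3.0), ('config_b', 1.0)]
print(fixed)  # ['config_b', 'config_c']
```

Note that even with the corrected key, the repro slices the list of pairs, so the extraction step is needed to get configs back rather than tuples.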

Here is the complete function that triggered the error (for reproducibility):

from typing import Callable, Dict, List


def call_prune_configs(  # type: ignore[no-untyped-def]
    autotuner,
    early_config_prune: Callable,
    perf_model: Callable,
    top_k: float,
    is_top_k_float: bool,
    configs: List,
    named_args: Dict,
    kwargs: Dict,
):
    if early_config_prune:
        configs = early_config_prune(configs, named_args, **kwargs)

    if perf_model:
        # we assert top_k is a float before calling this
        if is_top_k_float and top_k <= 1.0:
            top_k = int(len(configs) * top_k)
        if len(configs) > top_k:
            est_timing = [
                (config, perf_model(**named_args, **kwargs, **config.all_kwargs()))
                for config in configs
            ]
            # The key reads est_timing[1], not x[1]; this sorted() call raises the error.
            configs = sorted(est_timing, key=lambda x: est_timing[1])[:top_k]
    return configs


# Called in torch/_higher_order_ops/triton_kernel_wrap.py
pruned_configs = self.call_user_defined_fn(
    call_prune_configs,
    [
        variable,
        wrapped_early_configs_prune,
        wrapped_perf_model,
        wrapped_configs_top_k,
        wrapped_is_top_k_float,
        wrapped_configs,
        named_args,
        kwargs,
    ],
    {},
    tx,
    variable.source,
)
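One possible workaround, untested under Dynamo and therefore only a sketch, is to avoid passing a key function at all: sort (timing, index) pairs, whose elements are plain numbers, then map the indices back to configs. The index breaks ties, so config objects are never compared. Whether Dynamo accepts a key-free sorted() over such tuples is not verified here. The helper name prune_by_estimated_timing and the estimate callable are hypothetical; estimate stands in for the perf_model(**named_args, **kwargs, **config.all_kwargs()) call.

```python
from typing import Callable, List, Sequence, TypeVar

C = TypeVar("C")


def prune_by_estimated_timing(
    configs: Sequence[C],
    estimate: Callable[[C], float],  # hypothetical stand-in for the perf_model call
    top_k: float,
) -> List[C]:
    """Keep the top_k configs with the lowest estimated timing (sketch)."""
    # As in call_prune_configs, a float top_k <= 1.0 is a fraction of len(configs).
    if isinstance(top_k, float) and top_k <= 1.0:
        top_k = int(len(configs) * top_k)
    top_k = int(top_k)  # so that slicing below always gets an int
    if len(configs) <= top_k:
        return list(configs)
    # Sort (timing, index) pairs without a key function; the index breaks ties
    # so the configs themselves are never compared.
    ranked = sorted((estimate(c), i) for i, c in enumerate(configs))
    return [configs[i] for _, i in ranked[:top_k]]


# Usage with illustrative timings; config_b and config_d tie at 1.0.
timings = {"config_a": 3.0, "config_b": 1.0, "config_c": 2.0, "config_d": 1.0}
print(prune_by_estimated_timing(list(timings), timings.__getitem__, 0.5))
# ['config_b', 'config_d']
```

Ties are resolved by original position, which matches what a stable sorted(..., key=...) on the timings would produce.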

Here is the Dynamo bytecode trace leading up to the error:

"/data/users/ginzburg/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 1023> [BuiltinVariable(sorted), ListVariable(length=2), TupleVariable(length=1)]
V1218 08:22:05.910000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_CONST call_prune_configs.<locals>.<lambda> [BuiltinVariable(sorted), ListVariable(length=2), TupleVariable(length=1), ConstantVariable(code: <code object <lambda> at 0x7f9e3c5fbb50, file "/data/users/ginzburg/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 1023>)]
V1218 08:22:05.910000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE MAKE_FUNCTION 8 [BuiltinVariable(sorted), ListVariable(length=2), TupleVariable(length=1), ConstantVariable(code: <code object <lambda> at 0x7f9e3c5fbb50, file "/data/users/ginzburg/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 1023>), ConstantVariable(str: 'call_prune_configs.<locals>.<lambda>')]
V1218 08:22:05.910000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('key',) [BuiltinVariable(sorted), ListVariable(length=2), NestedUserFunctionVariable()]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 2 [BuiltinVariable(sorted), ListVariable(length=2), NestedUserFunctionVariable(), TupleVariable(length=1)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_DEREF est_timing []
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [ListVariable(length=2)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [ListVariable(length=2), ConstantVariable(int: 1)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TupleVariable(length=2)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_DEREF est_timing []
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [ListVariable(length=2)]
V1218 08:22:05.911000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [ListVariable(length=2), ConstantVariable(int: 1)]
V1218 08:22:05.912000 934875 torch/_dynamo/symbolic_convert.py:956] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TupleVariable(length=2)]
inline_call [('sort with non-constant keys', 1)]
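The repeated LOAD_DEREF est_timing / LOAD_CONST 1 / BINARY_SUBSCR sequence in the trace is the body of the key lambda. It reads the closed-over est_timing list instead of its argument, so each call returns the same (config, timing) tuple, the TupleVariable(length=2) seen on RETURN_VALUE. This can be checked in plain CPython with the dis module. The enclosing make_key function is a hypothetical reduction of call_prune_configs, and exact opcode names can differ between Python versions.

```python
import dis


def make_key():
    # Hypothetical reduction of call_prune_configs: est_timing is a local
    # that the lambda closes over, as in the original function.
    est_timing = [("config_a", 3.0), ("config_b", 1.0)]
    return lambda x: est_timing[1]


key = make_key()

# The lambda's only free variable is the closed-over list.
print(key.__code__.co_freevars)  # ('est_timing',)

# The disassembly loads the closure cell (LOAD_DEREF); x is never loaded.
dis.dis(key)

# Every call returns the same 2-tuple, whatever the argument.
print(key(("anything", 0.0)))  # ('config_b', 1.0)
```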

Versions

PyTorch version: 2.6.0a0+git28e242f
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-2)
Clang version: 18.1.8 (CentOS 18.1.8-3.el9)
CMake version: version 3.26.4
Libc version: glibc-2.34

Python version: 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnxscript==0.1.0.dev20240817
[pip3] optree==0.13.0
[pip3] pytorch-triton==3.2.0+git35c6c7c6
[pip3] torch==2.6.0a0+git28e242f
[conda] blas 1.0 mkl
[conda] magma-cuda121 2.6.1 1 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-include 2023.1.0 h06a4308_46344
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.10 py310h5eee18b_0
[conda] mkl_random 1.2.7 py310h1128e8f_0
[conda] numpy 1.26.4 py310h5f9d8c6_0
[conda] numpy-base 1.26.4 py310hb5e798b_0
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-triton 3.2.0+git35c6c7c6 pypi_0 pypi
[conda] torch 2.6.0a0+git28e242f dev_0
[conda] torchfix 0.4.0 pypi_0 pypi

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames
