int_mm seems broken due to Triton upgrade #144705

@cpuhrsch

🐛 Describe the bug

import torch
from torch._higher_order_ops.out_dtype import out_dtype


def quantized_matmul(x_vals_int8, x_scales, w_vals_int8):
    return out_dtype(torch.ops.aten.mm.default, torch.int32, x_vals_int8, w_vals_int8) * x_scales


x_vals_int8 = torch.randn(65536, 144).to(dtype=torch.int8).cuda()
x_scales = torch.randn(65536, 1).to(dtype=torch.float32).cuda()
w_vals_int8 = torch.randn(432, 144).to(dtype=torch.int8).cuda().t()

qcm = torch.compile(quantized_matmul, mode='max-autotune-no-cudagraphs')

qcm(x_vals_int8, x_scales, w_vals_int8)

produces

python: /root/.triton/llvm/llvm-86b69c31-almalinux-x64/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::FloatAttr, From = mlir::Attribute]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Aborted (core dumped)

This works on nightly20241126py312 with pytorch-triton 3.1.0+cf34004b8a. Can do more fine-grained bisection if needed.
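For a bisection it may help to have an eager reference to compare against on builds where the compiled call does not abort. The following is a minimal sketch, not part of the original repro: `quantized_matmul_reference` is a hypothetical helper that widens the int8 inputs to int32 on CPU, multiplies them, and applies the per-row scales, which is assumed to match what `quantized_matmul` is meant to compute.

```python
import torch


def quantized_matmul_reference(x_vals_int8, x_scales, w_vals_int8):
    # Hypothetical reference helper (not from the original report).
    # Widen to int32 first so products accumulate in int32, not int8.
    acc = torch.matmul(x_vals_int8.to(torch.int32), w_vals_int8.to(torch.int32))
    # int32 * float32 promotes to float32, one scale per row of x.
    return acc * x_scales


# Tiny CPU example with hand-checkable values.
x = torch.tensor([[1, -2], [3, 4]], dtype=torch.int8)
w = torch.tensor([[5, 6], [7, -8]], dtype=torch.int8)
scales = torch.tensor([[0.5], [2.0]], dtype=torch.float32)

out = quantized_matmul_reference(x, scales, w)
print(out)
```

On a build where `qcm` runs, its output could be moved to CPU and compared with this reference using `torch.testing.assert_close`.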

Versions

Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-triton==3.2.0+git0d4682f0
[pip3] torch==2.7.0.dev20250113+cu124
[pip3] torchaudio==2.6.0.dev20250113+cu124
[pip3] torchvision==0.22.0.dev20250113+cu124
[conda] numpy                     2.1.2                    pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pytorch-triton            3.2.0+git0d4682f0          pypi_0    pypi
[conda] torch                     2.7.0.dev20250113+cu124          pypi_0    pypi
[conda] torchaudio                2.6.0.dev20250113+cu124          pypi_0    pypi
[conda] torchvision               0.22.0.dev20250113+cu124          pypi_0    pypi
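When comparing a working and a failing environment, the torch and Triton versions are the lines that matter most in the list above. A small sketch for printing just those, using only the standard library; `installed_versions` is a hypothetical helper, and the distribution names passed to it are assumptions taken from the package names shown in this report.

```python
from importlib import metadata


def installed_versions(packages=("torch", "pytorch-triton", "triton")):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Nightly builds ship Triton as pytorch-triton, so a missing
            # "triton" distribution is expected in some environments.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```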

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @bertmaher @int3 @davidberard98 @nmacchioni @embg @peterbell10
