
torch.compile bug when using resize #155209

@remi-or


🐛 Describe the bug

On AMD GPUs, if a module calls torchvision.transforms.v2.functional.resize on a torch.uint8 input, its output changes when the module is wrapped with torch.compile.

Running on an MI300 GPU with torch 2.7.0a0+git3a58512 and torchvision 0.19.1a0+6194369 (the torch 2.7.0 release itself failed to even compile, so I had to pick the right commit), here is a minimal reproducible example:

import torch
from torch import nn
from torchvision.transforms.v2 import functional as F

class Resizer(nn.Module):
    """Bicubic, antialiased resize of an image tensor to a fixed size."""

    def __init__(self, height: int, width: int):
        super().__init__()
        self.height = height
        self.width = width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.resize(x, (self.height, self.width), interpolation=F.InterpolationMode.BICUBIC, antialias=True)

if __name__ == "__main__":
    # Pseudo-random uint8 image; values wrap around on the cast, which is fine here.
    x = torch.randn((3, 41, 70)).mul(256).to(torch.uint8).to("cuda:0")
    resizer = Resizer(384, 384)
    y = resizer(x)                     # eager output
    resizer = torch.compile(resizer)
    y_compiled = resizer(x)            # compiled output
    print(f"Max difference: {y.float().sub(y_compiled.float()).abs().max().item()}")

# stdout: Max difference: 255.0
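Since the max difference is exactly 255, I suspect the mismatched pixels are flipping between the two ends of the uint8 range. Here is a quick diagnostic on top of the repro above (a sketch I have not run in isolation) that shows where the divergence lands:

# Diagnostic sketch (appended to the repro above): locate the mismatches
# and print the eager vs. compiled values at those pixels.
diff = y.float().sub(y_compiled.float()).abs()
mismatch = diff > 0
print(f"Fraction of differing pixels: {mismatch.float().mean().item():.4f}")
print("Eager values at mismatches:   ", y[mismatch][:8].tolist())
print("Compiled values at mismatches:", y_compiled[mismatch][:8].tolist())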

I think it has something to do with the graph produced by torch.compile; here is a gist where the issue is fixed: https://gist.github.com/remi-or/eb8936ca093d54c186fb5b67f15334eb
If I use torch.clamp instead of the two masked_fill calls, the issue is still there, which might be related.
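For reference, my mental model of the eager uint8 bicubic path (an assumption based on reading torchvision's resize implementation, not verified against the actual kernels) is: upcast to float, interpolate with antialias, then round and clamp back into uint8 range. Bicubic kernels overshoot, so if the compiled graph mishandles the clamp or the round, values near 0 or 255 can wrap, which would produce differences of exactly 255:

import torch
import torch.nn.functional as nnF

def resize_uint8_reference(x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    # Hypothetical reference path: upcast, bicubic interpolate with antialias,
    # then round and clamp before casting back. The clamp is load-bearing,
    # since bicubic interpolation overshoots the [0, 255] range.
    xf = x.float().unsqueeze(0)  # interpolate expects a batch dimension
    yf = nnF.interpolate(xf, size=size, mode="bicubic", antialias=True)
    return yf.round_().clamp_(0, 255).to(torch.uint8).squeeze(0)

In the meantime, a stopgap (untested sketch) is to keep the resize out of the compiled graph with torch.compiler.disable:

class EagerResizer(Resizer):
    # torch.compiler.disable makes dynamo skip this frame, so the resize
    # runs eagerly even when the surrounding model is compiled.
    @torch.compiler.disable
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x)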

Versions

Collecting environment information...
PyTorch version: 2.7.0a0+git3a58512
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.3.42133-1b9c17779

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 18.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-6.3.1 24491 1e0fda770a2079fbd71e4b70974d74f62fd3af10)
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.11.11 (main, Feb 12 2025, 14:51:05) [Clang 19.1.6 ] (64-bit runtime)
Python platform: Linux-5.15.0-116-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI300X (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.3.42133
MIOpen runtime version: 3.3.0
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 1
Core(s) per socket: 96
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3707.8120
CPU min MHz: 1500.0000
BogoMIPS: 4792.43
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 6 MiB (192 instances)
L1i cache: 6 MiB (192 instances)
L2 cache: 192 MiB (192 instances)
L3 cache: 768 MiB (24 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-95
NUMA node1 CPU(s): 96-191
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy-protobuf==3.6.0
[pip3] numpy==1.26.4
[pip3] torch==2.7.0a0+git3a58512
[pip3] torchvision==0.19.1a0+6194369
[pip3] triton==3.2.0+gite5be006a
[conda] Could not collect

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

Metadata

Labels: module: inductor, module: rocm, oncall: pt2, triaged

Status: Done
