
ROCm MI300X sum() way slower than H100 #132964

@functionstackx

Description

🐛 Describe the bug

Even though Tensor.copy_ shows a major bandwidth improvement on MI300X compared to H100, a similarly memory-bandwidth-bound op like sum() reaches only 1757.8 GByte/s of read bandwidth on MI300X, versus 3136 GByte/s on H100 SXM.

Repro

import torch
from triton.testing import do_bench

x = torch.randn(2**30, device='cuda')

# Time the reduction (do_bench reports milliseconds)
ms = do_bench(lambda: x.sum(dim=-1))

# Bytes read by the reduction, in GByte
bandwidth_gbyte = x.numel() * x.dtype.itemsize / (10**9)

time_s = ms / 1000
bw_per_second = bandwidth_gbyte / time_s

print(bw_per_second)
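
For comparison with the Tensor.copy_ numbers mentioned above, here is a minimal sketch of the analogous copy_ measurement (not from the original report; it assumes effective bandwidth counts both the read of the source and the write of the destination, hence the factor of 2):

import torch
from triton.testing import do_bench

src = torch.randn(2**30, device='cuda')
dst = torch.empty_like(src)

# Time an in-place copy from src into dst
ms = do_bench(lambda: dst.copy_(src))

# copy_ reads src and writes dst, so count the bytes twice
gbyte_moved = 2 * src.numel() * src.dtype.itemsize / (10**9)
print(gbyte_moved / (ms / 1000))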

Versions

MI300X

latest rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 image

sudo docker run --privileged --network=host --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 192G -v $(pwd):/var/lib/jenkins -it rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

H100 SXM

latest nvcr.io/nvidia/pytorch:24.07-py3 image

sudo docker run -it --ipc=host --ulimit memlock=-1 --ulimit stack=6710886 --privileged --gpus all -v $(pwd):/workspace nvcr.io/nvidia/pytorch:24.07-py3
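
Not from the original report: a quick sanity check to confirm which backend and device each container actually exposes (ROCm builds of PyTorch report torch.version.hip, CUDA builds report torch.version.cuda):

import torch

print(torch.__version__)
print(torch.version.hip or torch.version.cuda)  # HIP version on MI300X, CUDA version on H100
print(torch.cuda.get_device_name(0))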

cc @msaroufim @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang


Labels

module: performance - Issues related to performance, either of kernel code or framework glue
module: rocm - AMD GPU support for PyTorch
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
