Distributed Tensor raises error with torch 2.5 #138742

@aturker1

I'm encountering an IndexError when multiplying two distributed tensors (DTensors) on a 2D (4, 1) device mesh after upgrading to PyTorch 2.5.0. The same code worked in PyTorch 2.4.0.

Repro:

import torch
from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor, Replicate

if __name__ == "__main__":
    mesh = init_device_mesh("cpu", (4, 1))
    left = torch.randn(256, 128)
    right = torch.randn(128, 256)

    left_d = distribute_tensor(left, mesh, [Shard(dim=0), Replicate()])
    right_d = distribute_tensor(right, mesh, [Replicate(), Replicate()])

    res = left_d @ right_d
    print(res)
    

Launched with:

torchrun --standalone --nnodes=1 --nproc-per-node=4 file.py

Throws:

[rank0]: Traceback (most recent call last):
[rank0]: File "/Users/synnada/projects/composite-ml-new/file.py", line 12, in <module>
[rank0]: res = left_d @ right_d
[rank0]: ~~~~~~~^~~~~~~~~
[rank0]: File "/Users/synnada/anaconda3/envs/py12/lib/python3.12/site-packages/torch/_compile.py", line 32, in inner
[rank0]: return disable_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/Users/synnada/anaconda3/envs/py12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/Users/synnada/anaconda3/envs/py12/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 340, in torch_dispatch
[rank0]: return DTensor._op_dispatcher.dispatch(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/Users/synnada/anaconda3/envs/py12/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 181, in dispatch
[rank0]: self.redistribute_local_args(
[rank0]: File "/Users/synnada/anaconda3/envs/py12/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 317, in redistribute_local_args
[rank0]: resharded_local_tensor = redistribute_local_tensor(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/Users/synnada/anaconda3/envs/py12/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py", line 177, in redistribute_local_tensor
[rank0]: transform_infos = _gen_transform_infos(current_spec, target_spec)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/Users/synnada/anaconda3/envs/py12/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py", line 100, in _gen_transform_infos
[rank0]: target = target_placements[mesh_dim]
[rank0]: ~~~~~~~~~~~~~~~~~^^^^^^^^^^
[rank0]: IndexError: list index out of range
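For reading the last frame: the IndexError is raised at target_placements[mesh_dim], so that mesh-dimension index fell outside the target placement list. One plausible reading (an assumption, not a confirmed diagnosis) is that the target spec chosen for one matmul input carries fewer placements than the (4, 1) mesh has dimensions. The pure-Python sketch below models only that failure condition; gen_transform_infos_model is a hypothetical stand-in, not PyTorch's _gen_transform_infos, and plain strings stand in for Shard/Replicate objects.

```python
from collections.abc import Sequence


def gen_transform_infos_model(
    current_placements: Sequence[str], target_placements: Sequence[str]
) -> list[tuple[int, str, str]]:
    """Hypothetical, simplified model of the failing lookup (not PyTorch code).

    For each mesh dimension, pairs the current placement with the target
    placement and records the dimensions whose placement must change.
    """
    steps = []
    for mesh_dim in range(len(current_placements)):
        current = current_placements[mesh_dim]
        # Raises IndexError when target_placements has fewer entries than
        # current_placements, the same error type as the traceback's last frame.
        target = target_placements[mesh_dim]
        if current != target:
            steps.append((mesh_dim, current, target))
    return steps


# Matching lengths (2D mesh, two placements on each side): works.
print(gen_transform_infos_model(["Shard(0)", "Replicate()"], ["Replicate()", "Replicate()"]))
# → [(0, 'Shard(0)', 'Replicate()')]

# Target with a single placement on a 2D mesh: same error as in the report.
try:
    gen_transform_infos_model(["Shard(0)", "Replicate()"], ["Replicate()"])
except IndexError as exc:
    print(type(exc).__name__, exc)
# → IndexError list index out of range
```

The model only shows why a length mismatch between the two placement lists produces this exact message; it does not explain why PyTorch 2.5.0 would produce such a mismatch for this matmul.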

Versions

PyTorch version: 2.5.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.0 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: version 3.28.1
Libc version: N/A

Python version: 3.12.0 | packaged by Anaconda, Inc. | (main, Oct 2 2023, 12:22:05) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Versions of relevant libraries:
[pip3] mypy==1.12.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.2
[pip3] torch==2.5.0
[pip3] torchaudio==2.3.1
[pip3] torchvision==0.18.1
[conda] numpy 2.1.2 pypi_0 pypi
[conda] torch 2.5.0 pypi_0 pypi
[conda] torchaudio 2.3.1 pypi_0 pypi
[conda] torchvision 0.18.1 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l @XilunWu
