
Segmentation fault using all_reduce with cuda:1 (MPI) #21922

@Xyand

🐛 Bug

Using all_reduce under CUDA-aware MPI with a CUDA device other than cuda:0 causes a segmentation fault. I managed to bypass this for a very specific use case by setting CUDA_VISIBLE_DEVICES=1 and then using cuda:0 within PyTorch; a sketch of that workaround follows.
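
A minimal sketch of that workaround, assuming the process is launched with CUDA_VISIBLE_DEVICES already set (for example by a per-rank wrapper script), so the only GPU PyTorch can see is enumerated as cuda:0:

# Hypothetical launch: CUDA_VISIBLE_DEVICES=1 mpirun -np 1 python workaround.py
# With only one visible GPU, PyTorch enumerates it as cuda:0, so no
# torch.cuda.set_device(1) call is needed.
import torch
import torch.distributed as dist

dist.init_process_group('mpi')   # rank and world size come from the MPI runtime
t = torch.rand(1).cuda()         # allocated on cuda:0, the only visible device
gather_t = [torch.ones_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(gather_t, t)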

To Reproduce

import os
import torch
import torch.distributed as dist


def run(rank, size):
    # The tensor lives on the current CUDA device (cuda:1, selected below).
    t = torch.rand(1).cuda()
    gather_t = [torch.ones_like(t) for _ in range(size)]
    dist.all_gather(gather_t, t)


def init_processes(rank, size, fn, backend='mpi'):
    """Initialize the distributed environment."""
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])

    # Selecting any device other than cuda:0 triggers the crash.
    torch.cuda.set_device(1)
    init_processes(world_rank, world_size, run, backend='mpi')

Steps to reproduce the behavior:

mpirun -np 4 --oversubscribe -host 127.0.0.1 python test.py

Segmentation fault:

[prototype:02781] *** Process received signal ***
[prototype:02781] Signal: Segmentation fault (11)
[prototype:02781] Signal code: Invalid permissions (2)
[prototype:02781] Failing at address: 0x7f4e64a00000
[prototype:02781] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x11390)[0x7f4f088a8390]
[prototype:02781] [ 1] /lib/x86_64-linux-gnu/libc.so.6(+0x14e045)[0x7f4f0861b045]
[prototype:02781] [ 2] /opt/openmpi-3.0.0/lib/libopen-pal.so.40(+0x49eec)[0x7f4eb2f8ceec]
[prototype:02781] [ 3] /opt/openmpi-3.0.0/lib/libmpi.so.40(ompi_datatype_sndrcv+0x53a)[0x7f4ee0b4713a]
[prototype:02781] [ 4] /opt/openmpi-3.0.0/lib/libmpi.so.40(ompi_coll_base_allgather_intra_recursivedoubling+0x8f)[0x7f4ee0b898bf]
[prototype:02781] [ 5] /opt/openmpi-3.0.0/lib/libmpi.so.40(MPI_Allgather+0x12e)[0x7f4ee0b47e3e]
[prototype:02781] [ 6] /root/anaconda/lib/python3.6/site-packages/torch/lib/libtorch_python.so(+0x71d826)[0x7f4ef9818826]
[prototype:02781] [ 7] /root/anaconda/lib/python3.6/site-packages/torch/lib/libtorch_python.so(_ZN4c10d15ProcessGroupMPI7runLoopEv+0x27c)[0x7f4ef9814c6c]
[prototype:02781] [ 8] /root/anaconda/lib/libstdc++.so.6(+0xb8678)[0x7f4ee259e678]
[prototype:02781] [ 9] /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba)[0x7f4f0889e6ba]
[prototype:02781] [10] /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d)[0x7f4f085d441d]
[prototype:02781] *** End of error message ***
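
For reference, the crash does not seem specific to the hard-coded set_device(1): the more common pattern of picking the device from the per-node local rank should presumably hit the same fault for any rank mapped to a device other than cuda:0. A minimal sketch of that pattern, assuming OpenMPI exports OMPI_COMM_WORLD_LOCAL_RANK and one process per GPU:

import os
import torch
import torch.distributed as dist

if __name__ == "__main__":
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    # Map the rank within the node to a GPU index (assumption: one process per GPU).
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank % torch.cuda.device_count())

    dist.init_process_group('mpi', rank=world_rank, world_size=world_size)
    t = torch.rand(1).cuda()  # lands on the GPU selected above
    gather_t = [torch.ones_like(t) for _ in range(world_size)]
    dist.all_gather(gather_t, t)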

Expected behavior

No crash

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.14.0

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Tesla P100-PCIE-16GB
GPU 1: Tesla P100-PCIE-16GB

Nvidia driver version: 418.67
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.16.4
[pip3] numpydoc==0.8.0
[pip3] torch==1.1.0
[pip3] torchvision==0.3.0a0+c94a158
[conda] blas 1.0 mkl
[conda] libmklml 2018.0.3 0
[conda] magma-cuda100 2.5.0 1 pytorch
[conda] mkl 2019.4 243
[conda] mkl-dnn 0.14 h6bb024c_0
[conda] mkl-include 2019.4 243
[conda] mkl-service 2.0.2 py36h7b6447c_0
[conda] mkl_fft 1.0.12 py36ha843d7b_0
[conda] mkl_random 1.0.2 py36hd81dba3_0
[conda] torch 1.1.0 pypi_0 pypi
[conda] torchvision 0.3.0a0+c94a158 pypi_0 pypi

OpenMPI 3.0.0 (CUDA-aware build)


Labels

module: cuda, oncall: distributed, triaged
