
ProcessGroupGloo reduce produces wrong result  #21480


Description

@mrshenli

🐛 Bug

ProcessGroupGloo reduce produces incorrect results on the non-root process: reduce should only modify the tensor on the root rank, but the non-root tensor ends up partially overwritten.

To Reproduce

Run the following code:

import torch
import torch.multiprocessing as mp
import torch.distributed as c10d

import tempfile

def opts(threads=2):
    opts = c10d.ProcessGroupGloo.Options()
    opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    opts.timeout = 5.0
    opts.threads = threads
    return opts

def reduce_process_gloo(rank, filename, world_size):
    store = c10d.FileStore(filename, world_size)
    pg = c10d.ProcessGroupGloo(store, rank, world_size, opts())
    x = torch.ones(2, 2).to(rank)
    pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
    print ("gloo rank ", rank, ": x = ", x)

def reduce_process_nccl(rank, filename, world_size):
    store = c10d.FileStore(filename, world_size)
    pg = c10d.ProcessGroupNCCL(store, rank, world_size)
    x = torch.ones(2, 2).to(rank)
    pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
    print ("nccl rank ", rank, ": x = ", x)

if __name__ == '__main__':
    with tempfile.NamedTemporaryFile(delete=False) as file:
        world_size = 2
        mp.spawn(reduce_process_gloo,
                 args=(file.name, world_size),
                 nprocs=world_size,
                 join=True)

    with tempfile.NamedTemporaryFile(delete=False) as file:
        world_size = 2
        mp.spawn(reduce_process_nccl,
                 args=(file.name, world_size),
                 nprocs=world_size,
                 join=True)

You will see output:

gloo rank  1 : x =  tensor([[1., 1.],
        [2., 2.]], device='cuda:1')
gloo rank  0 : x =  tensor([[2., 2.],
        [2., 2.]], device='cuda:0')
nccl rank  1 : x =  tensor([[1., 1.],
        [1., 1.]], device='cuda:1')
nccl rank  0 : x =  tensor([[2., 2.],
        [2., 2.]], device='cuda:0')

Gloo rank 1 should output [[1., 1.], [1., 1.]], as NCCL does; instead, half of the non-root tensor has been overwritten with the reduced values.
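For reference, reduce is only supposed to write the reduced result into the root's tensor, leaving non-root inputs untouched. Below is a minimal sketch (not part of the original repro) that illustrates the expected semantics using the higher-level torch.distributed API with the gloo backend on CPU tensors; the MASTER_ADDR/MASTER_PORT values are placeholders for a single-machine run.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def expected_reduce(rank, world_size):
    # Placeholder rendezvous settings for a single-machine run.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    x = torch.ones(2, 2)
    dist.reduce(x, dst=0, op=dist.ReduceOp.SUM)
    # Expected: rank 0 holds the sum ([[2., 2.], [2., 2.]]);
    # every other rank still holds its original input ([[1., 1.], [1., 1.]]).
    print("rank", rank, ":", x)
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(expected_reduce, args=(world_size,), nprocs=world_size, join=True)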

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23

Labels

high priority, oncall: distributed, triage review, triaged
