
Direct construction of FakeProcessGroup doesn't raise errors but has silent incorrectness #162129

@ezyang

Description


🐛 Describe the bug

Repro:

import torch
from torch.utils._python_dispatch import TorchDispatchMode
import torch.distributed as dist
from torch.distributed._distributed_c10d import FakeProcessGroup
from torch.testing._internal.distributed.fake_pg import FakeStore


class SimpleTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


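# Construct FakeProcessGroup directly, without calling init_process_group.
fake_pg = FakeProcessGroup(rank=0, world_size=3)
with SimpleTensorMode():
    tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    # Expected to dispatch c10d.allreduce_.default through the mode.
    dist.all_reduce(tensor, group=fake_pg)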

This prints only:

aten.lift_fresh.default

The output is expected to also include c10d.allreduce_.default. Instead, the collective never reaches the dispatch mode, and no error is raised.

If fake_pg is replaced with a properly initialized process group, for example:

fake_store = FakeStore()
torch.distributed.init_process_group(
    "fake", store=fake_store, rank=0, world_size=3
)

then the output includes c10d.allreduce_.default as expected.

Using FakeProcessGroup directly in this way should raise a clear error instead of failing silently.
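One possible shape for a louder failure, sketched in plain Python so it does not depend on PyTorch internals: collectives accept only groups that were created through a registration path (standing in for init_process_group) and raise otherwise. Every name below (SketchGroup, sketch_init_group, sketch_all_reduce, UninitializedGroupError, DISPATCH_LOG) is hypothetical and not a PyTorch API; this only illustrates the idea of failing at the call site, not how c10d does or should implement it.

```python
import weakref


class UninitializedGroupError(RuntimeError):
    """Hypothetical error for a group that never went through registration."""


class SketchGroup:
    """Hypothetical stand-in for a process group object."""

    def __init__(self, rank: int, world_size: int) -> None:
        self.rank = rank
        self.world_size = world_size


# Groups created via sketch_init_group. A WeakSet avoids keeping
# discarded groups alive just because they were registered.
_REGISTERED = weakref.WeakSet()

# Hypothetical stand-in for what a dispatch mode would observe.
DISPATCH_LOG: list[str] = []


def sketch_init_group(rank: int, world_size: int) -> SketchGroup:
    """Hypothetical analogue of init_process_group: create, then register."""
    group = SketchGroup(rank, world_size)
    _REGISTERED.add(group)
    return group


def sketch_all_reduce(values: list[float], group: SketchGroup) -> list[float]:
    """Hypothetical collective that refuses groups that were never registered."""
    if group not in _REGISTERED:
        raise UninitializedGroupError(
            "group was constructed directly; create it through "
            "sketch_init_group (the init_process_group analogue) instead"
        )
    DISPATCH_LOG.append("allreduce")
    # No communication happens in this sketch; the input is returned unchanged.
    return list(values)
```

With this check, passing a directly constructed SketchGroup to sketch_all_reduce raises UninitializedGroupError and leaves DISPATCH_LOG untouched, rather than appearing to succeed.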

Versions

main

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

Metadata


Assignees

No one assigned

    Labels

    module: c10d (Issues/PRs related to collective communications and process groups)
    module: dispatch (DispatchStub, Type, void pointer table, c10 dispatch)
    oncall: distributed (Add this issue/PR to distributed oncall triage queue)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
