Labels: module: c10d, module: dispatch, oncall: distributed
Description
🐛 Describe the bug
Repro:
import torch
from torch.utils._python_dispatch import TorchDispatchMode
import torch.distributed as dist
from torch.distributed._distributed_c10d import FakeProcessGroup
from torch.testing._internal.distributed.fake_pg import FakeStore

class SimpleTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)

fake_pg = FakeProcessGroup(rank=0, world_size=3)

with SimpleTensorMode():
    tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    dist.all_reduce(tensor, group=fake_pg)
This prints
aten.lift_fresh.default
But it's expected to also print c10d.allreduce_.default
If you replace fake_pg with a properly initialized process group, à la

fake_store = FakeStore()
torch.distributed.init_process_group(
    "fake", store=fake_store, rank=0, world_size=3
)

it works.
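For reference, here is a sketch that assembles the two pieces above into one runnable script (same code as in the report, just combined); with the default group set up via init_process_group, the collective should be visible to the dispatch mode:

import torch
import torch.distributed as dist
from torch.utils._python_dispatch import TorchDispatchMode
from torch.testing._internal.distributed.fake_pg import FakeStore

class SimpleTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)

# Set up the default group through init_process_group; per the report,
# collectives issued against a group created this way do show up in
# __torch_dispatch__ (unlike a bare FakeProcessGroup).
fake_store = FakeStore()
dist.init_process_group("fake", store=fake_store, rank=0, world_size=3)

with SimpleTensorMode():
    tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    dist.all_reduce(tensor)  # expected to also print c10d.allreduce_.default

dist.destroy_process_group()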
We should error more loudly when you use FakeProcessGroup directly in this way.
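Until such a check exists in c10d itself, a hypothetical user-side guard (the helper below is illustrative, not part of the API) could turn the silent case into an explicit error:

import torch.distributed as dist

def require_initialized_group():
    # Hypothetical guard: refuse to issue collectives unless a default
    # group was created via init_process_group, rather than passing a
    # bare FakeProcessGroup constructed directly.
    if not dist.is_initialized():
        raise RuntimeError(
            "No default process group; call dist.init_process_group('fake', ...) "
            "instead of constructing FakeProcessGroup directly."
        )

In the repro above, calling require_initialized_group() before dist.all_reduce would raise instead of silently taking the path that never reaches __torch_dispatch__.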
Versions
main
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta