-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Closed
Labels
oncall: distributed, oncall: pt2
Description
Torch-only repro (apply this diff to test/distributed/_tensor/test_dtensor_compile.py):
diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
index 91fbc396f8e..09a2bf8f183 100644
--- a/test/distributed/_tensor/test_dtensor_compile.py
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -544,12 +544,18 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
def test_dynamo_dtensor_from_local_redistribute(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ from torch.distributed._functional_collectives import AsyncCollectiveTensor
# pass in tensor as inputs/outputs, create DTensor and run redistribute
# (allgather collective) inside the fn
def fn(x):
dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
- return dt.redistribute(mesh, [Replicate()]).to_local() + 2
+ out = dt.redistribute(mesh, [Replicate()], async_op=True).to_local()
+ return out
+ if isinstance(out, AsyncCollectiveTensor):
+ return out.wait()
+ else:
+ return out
x = torch.ones(1)
ref = fn(x)
# run with `python test/distributed/_tensor/test_dtensor_compile.py -k test_dynamo_dtensor_from_local_redistribute`
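This fails with the error below. Note that the early `return out` in the diff makes the `isinstance`/`wait()` branch unreachable; remove that line to reach the `out.wait()` call shown in the traceback: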
File "/home/hirsheybar/local/a/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
out = PropagateUnbackedSymInts(mod).run(
File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 167, in run
self.env[node] = self.run_node(node)
File "/home/hirsheybar/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6670, in run_node
result = super().run_node(n)
File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 228, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 332, in call_method
return getattr(self_obj, target)(*args_tail, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='<torch._dynamo.testing.CompileCounterWithBackend object at 0x7f54e584b1c0>' raised:
AttributeError: 'FunctionalTensor' object has no attribute 'wait'
While executing %wait : [num_users=1] = call_method[target=wait](args = (%out,), kwargs = {})
Original traceback:
File "/home/hirsheybar/local/a/pytorch/test/distributed/_tensor/test_dtensor_compile.py", line 586, in fn
return out.wait()
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu
Metadata
Assignees
Labels
oncall: distributed, oncall: pt2