
Dynamo doesn't handle branching on AsyncCollectiveTensor well #142076


Description

@bdhirsh

See sgl-project/sglang#2352

torch-only repro:

diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
index 91fbc396f8e..09a2bf8f183 100644
--- a/test/distributed/_tensor/test_dtensor_compile.py
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -544,12 +544,17 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):

     def test_dynamo_dtensor_from_local_redistribute(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        from torch.distributed._functional_collectives import AsyncCollectiveTensor

         # pass in tensor as inputs/outputs, create DTensor and run redistribute
         # (allgather collective) inside the fn
         def fn(x):
             dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
-            return dt.redistribute(mesh, [Replicate()]).to_local() + 2
+            out = dt.redistribute(mesh, [Replicate()], async_op=True).to_local()
+            if isinstance(out, AsyncCollectiveTensor):
+                return out.wait()
+            else:
+                return out

         x = torch.ones(1)
         ref = fn(x)
         
# run with `python test/distributed/_tensor/test_dtensor_compile.py -k test_dynamo_dtensor_from_local_redistribute`
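
A self-contained sketch of the same repro, for readers who would rather not patch the test file, might look like the following. This is an assumption, not code from the issue: it presumes a multi-rank launch via `torchrun` (so the Shard(0) -> Replicate() redistribute actually issues a collective), a CPU/gloo process group, and the `aot_eager` backend standing in for the test's `CompileCounterWithBackend`.

```python
# Hypothetical standalone repro sketch (assumptions: launched with
# `torchrun --nproc_per_node=2 repro.py`, gloo backend on CPU, aot_eager
# compile backend; not taken verbatim from the issue).
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed._functional_collectives import AsyncCollectiveTensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))


@torch.compile(backend="aot_eager", fullgraph=True)
def fn(x):
    dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    out = dt.redistribute(mesh, [Replicate()], async_op=True).to_local()
    # Branching on the subclass type and calling .wait() is the pattern
    # this issue is about.
    if isinstance(out, AsyncCollectiveTensor):
        return out.wait()
    return out


print(fn(torch.ones(1)))
dist.destroy_process_group()
```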

Running the test repro fails with:

  File "/home/hirsheybar/local/a/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
  File "/home/hirsheybar/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6670, in run_node
    result = super().run_node(n)
  File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 228, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 332, in call_method
    return getattr(self_obj, target)(*args_tail, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='<torch._dynamo.testing.CompileCounterWithBackend object at 0x7f54e584b1c0>' raised:
AttributeError: 'FunctionalTensor' object has no attribute 'wait'

While executing %wait : [num_users=1] = call_method[target=wait](args = (%out,), kwargs = {})
Original traceback:
  File "/home/hirsheybar/local/a/pytorch/test/distributed/_tensor/test_dtensor_compile.py", line 586, in fn
    return out.wait()
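
The AttributeError is raised while AOTAutograd re-runs the Dynamo graph during functionalization: Dynamo recorded a `call_method` node targeting `wait` on the redistribute output, but at that stage the value is a plain `FunctionalTensor` rather than an `AsyncCollectiveTensor`, so there is no `wait` method to dispatch to.

One way callers can sidestep the problem today is to keep the type branch out of the compiled region, so Dynamo never traces the `isinstance`/`.wait()` pattern. The sketch below is a hedged workaround under that assumption, not a fix proposed in this issue; it reuses the mesh/DTensor setup from the standalone sketch above.

```python
# Hedged workaround sketch (an assumption, not a fix from this issue):
# do the AsyncCollectiveTensor check in eager code, outside torch.compile.
import torch
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed._functional_collectives import AsyncCollectiveTensor


def make_fn(mesh):
    @torch.compile(backend="aot_eager", fullgraph=True)
    def inner(x):
        dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        # No type branching inside the compiled region.
        return dt.redistribute(mesh, [Replicate()], async_op=True).to_local()

    def fn(x):
        out = inner(x)
        # Eager-side wait, never traced by Dynamo.
        if isinstance(out, AsyncCollectiveTensor):
            out = out.wait()
        return out

    return fn
```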

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu
