🐛 Describe the bug
Calling `.half()` on a `DistributedDataParallel`-wrapped module with the XLA backend fails inside `torch.utils.swap_tensors`:

```python
import torch
import torch.distributed
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" process group backend

torch.distributed.init_process_group(
    backend="xla",
    init_method="xla://",
)
m = torch.nn.Linear(10, 10).to(xm.xla_device())
m = torch.nn.parallel.DistributedDataParallel(m, device_ids=[0])
# Error happens here!
m.half()
```

```
[rank0]: Traceback (most recent call last):
[rank0]: File "torch/nn/modules/module.py", line 824, in _apply
[rank0]: torch.utils.swap_tensors(param, param_applied)
[rank0]: File "torch/utils/__init__.py", line 72, in swap_tensors
[rank0]: check_use_count(t1, 't1')
[rank0]: File "torch/utils/__init__.py", line 70, in check_use_count
[rank0]: raise RuntimeError(error_str)
[rank0]: RuntimeError: Expected use_count of t1 to be 1 or 2 with an AccumulateGrad node but got 4 make sure you are not holding references to the tensor in other places.
[rank0]: The above exception was the direct cause of the following exception:
[rank0]: Traceback (most recent call last):
[rank0]: File "examples/scratch.py", line 58, in <module>
[rank0]: m.half()
[rank0]: File "torch/nn/modules/module.py", line 1017, in half
[rank0]: return self._apply(lambda t: t.half() if t.is_floating_point() else t)
[rank0]: File "torch/nn/modules/module.py", line 779, in _apply
[rank0]: module._apply(fn)
[rank0]: File "torch/nn/modules/module.py", line 828, in _apply
[rank0]: raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e
[rank0]: RuntimeError: _apply(): Couldn't swap Linear.weight
```
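A possible workaround, going by the "not holding references to the tensor in other places" hint in the error rather than a verified fix: DDP itself keeps references to the wrapped module's parameters, so converting the dtype before wrapping should leave `swap_tensors` with only the module's own references when the use_count check runs. A minimal sketch under the same setup as the repro, with the ordering assumption called out in the comments:

```python
import torch
import torch.distributed
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" process group backend

torch.distributed.init_process_group(backend="xla", init_method="xla://")

# Assumption: running .half() before DDP wraps the module means DDP has not
# yet taken its own references to the parameters, so the use_count check
# inside swap_tensors can pass.
m = torch.nn.Linear(10, 10).to(xm.xla_device()).half()
m = torch.nn.parallel.DistributedDataParallel(m, device_ids=[0])
```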
Versions

- PyTorch version: f532854
- PyTorch/XLA version: aec273056a95d8119279c15d36c0f48f739fb810
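For reference, `torch.utils.swap_tensors` (the function raising above) exchanges the full payloads of two tensors in place, and per the error message it refuses to run when something else still holds a reference to either tensor. A minimal successful call on plain CPU tensors; this does not reproduce the DDP/XLA reference counting that fails here:

```python
import torch

t1 = torch.randn(2)
t2 = torch.ones(2)
torch.utils.swap_tensors(t1, t2)  # succeeds: nothing else references either tensor
print(t1)  # tensor([1., 1.]): t1 now carries what t2 held
```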
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @bdhirsh @miladm @JackCaoG