
Commit 3406ac2

XilunWu authored and pytorchmergebot committed
[BE] fix circular import in torch/distributed/utils.py (#136286)
**Summary** Fix a circular import in `torch/distributed/utils.py` found when running an internal test; see D62901023. Curious why this wasn't causing any issues before. Is the relevant code deprecated and no longer used?

Pull Request resolved: #136286
Approved by: https://github.com/Skylion007
1 parent 3bc073d commit 3406ac2

File tree

1 file changed: +8 -2 lines changed


torch/distributed/utils.py

Lines changed: 8 additions & 2 deletions
@@ -18,8 +18,6 @@
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch.nn.parallel._functions import _get_stream
-from torch.nn.parallel.scatter_gather import _is_namedtuple
 from torch.nn.utils.rnn import PackedSequence


@@ -123,6 +121,9 @@ def to_map(obj):
                 device_mod = getattr(torch, device.type, None)
                 if device.type == "cpu" or device_mod is None:
                     return (obj.to(target_device),)
+
+                from torch.nn.parallel._functions import _get_stream
+
                 # Perform CPU -> target_device copies in a background stream. This code is
                 # motivated from similar logic in torch/nn/parallel/_functions.py
                 stream = _get_stream(target_device)
@@ -141,6 +142,9 @@ def to_map(obj):
                     assert isinstance(output, torch.Tensor)
                     output.record_stream(current_stream)  # type: ignore[arg-type]
                 return (output,)
+
+        from torch.nn.parallel.scatter_gather import _is_namedtuple
+
         if _is_namedtuple(obj):
             return [type(obj)(*args) for args in zip(*map(to_map, obj))]
         if isinstance(obj, tuple) and len(obj) > 0:
@@ -228,6 +232,8 @@ def _apply_to_tensors(fn, container):
     """Recursively apply to all tensor in different kinds of container types."""

     def apply(x):
+        from torch.nn.parallel.scatter_gather import _is_namedtuple
+
         if isinstance(x, torch.Tensor):
             return fn(x)
         elif hasattr(x, "__dataclass_fields__"):
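The change itself is the standard deferred-import workaround for a circular dependency: instead of importing `_get_stream` and `_is_namedtuple` when `torch/distributed/utils.py` is first loaded, the imports now run inside the functions that use them, so the module no longer has to be importable before `torch.nn.parallel` has finished initializing. Below is a minimal, self-contained sketch of the pattern using hypothetical modules `cyc_a` and `cyc_b` (not part of the PyTorch tree); it illustrates the general technique, not the actual PyTorch import graph.

```python
# Sketch of a circular import and the deferred-import fix, using hypothetical
# modules cyc_a / cyc_b written to a temp directory so the example is runnable.
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

tmp = Path(tempfile.mkdtemp())
sys.path.insert(0, str(tmp))
sys.dont_write_bytecode = True  # keep the demo independent of cached .pyc files

# cyc_a imports cyc_b at module level, and cyc_b imports a name back from cyc_a.
# Whichever module is imported first is still half-initialized when the other
# asks it for a name, so the import fails.
(tmp / "cyc_a.py").write_text(textwrap.dedent("""\
    from cyc_b import helper      # module-level import (the problematic pattern)

    def work():
        return helper()
"""))
(tmp / "cyc_b.py").write_text(textwrap.dedent("""\
    from cyc_a import work        # re-enters cyc_a before `work` is defined

    def helper():
        return "ok"
"""))

try:
    import cyc_a
except ImportError as exc:
    print("circular import:", exc)

# The fix mirrors this commit: move the import into the function body, so it is
# resolved at call time, after both modules have finished initializing.
(tmp / "cyc_a.py").write_text(textwrap.dedent("""\
    def work():
        from cyc_b import helper  # deferred (function-local) import breaks the cycle
        return helper()
"""))
for name in ("cyc_a", "cyc_b"):
    sys.modules.pop(name, None)  # drop any partially initialized leftovers
importlib.invalidate_caches()

import cyc_a
print(cyc_a.work())  # -> ok
```

Function-local imports are cheap after the first call, since the imported module stays cached in `sys.modules`, which is why this is a common way to break an import cycle without restructuring the modules involved.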
