Commit 366bd1c

Update on "[dtensor][random] allow user to manual_seed different seed on device mesh; only sync RNG state in WORLD when manual_seed has not been called"
**Summary**

This PR proposes 3 changes to DTensor RNG management:

1. DTensor allows users to eagerly initialize the RNG tracker by calling `torch.distributed.tensor._random.manual_seed`.

2. DTensor `manual_seed` no longer checks the integrity of the `seed` argument. Users are responsible for setting the same seed on all ranks within an SPMD group, but if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a _different_ seed for each SPMD group. For cases like Pipeline Parallel, users can set a different initial seed for each pipeline stage by calling:

   ```
   world_mesh = init_device_mesh(
       device_type="cuda",
       mesh_shape=(2, 2, 2),
       mesh_dim_names=("pp", "dp", "tp"),
   )
   pp_mesh = world_mesh["pp"]
   pp_rank = pp_mesh.get_local_rank()
   # this flattening is only needed if you need to call collectives over this mesh
   spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd")
   torch.distributed.tensor._random.manual_seed(123 + pp_rank, spmd_mesh)
   ```

   In other words, if users choose to call `torch.distributed.tensor._random.manual_seed`, they are responsible for passing in the right value and DTensor won't perform any checks on it. If the current rank is not a part of the mesh, `manual_seed` raises an error rather than leaving that rank's DTensor RNG state uninitialized.

3. `OffsetBasedRNGTracker` still performs RNG state synchronization by broadcasting the RNG state on rank 0 to `WORLD`. The exception is when `torch.distributed.tensor._random.manual_seed` has been called: in that case, no broadcast happens.

**Motivation**

tl;dr

1. Lazily initializing the DTensor RNG tracker causes hangs in non-SPMD code such as Pipeline Parallel.
2. Users may want to set a different seed on ranks within one device mesh.
3. We want to keep the old behavior for users who prefer not to curate the RNG state and want DTensor to take care of it.

See details in #140301.

**Test**

`pytest test/distributed/_tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`

cc wanchaol tianyu-l wz337 d4l3k H-Huang awgu kwen2501 fegin fduwjj wconstab c-p-i-o

[ghstack-poisoned]
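To make the new contract concrete, here is a minimal sketch of the two seeding paths described above, intended to be launched with torchrun on GPUs. The mesh shape, the seed value, and the `uniform_` call are illustrative assumptions, not taken from this PR.

```python
import os

import torch
import torch.distributed as dist
import torch.distributed.tensor._random as dtensor_random
from torch.distributed._tensor import distribute_tensor, init_device_mesh, Replicate

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Eager path (new in this PR): DTensor uses exactly this seed for the SPMD
# group and skips the rank-0 -> WORLD RNG-state broadcast. Comment this line
# out to get the old default: the first DTensor random op lazily creates
# OffsetBasedRNGTracker and broadcasts rank 0's RNG state to WORLD.
dtensor_random.manual_seed(1234, mesh)

# A replicated DTensor random op that goes through the RNG tracker.
x = distribute_tensor(torch.empty(8, 8, device="cuda"), mesh, [Replicate()])
torch.nn.init.uniform_(x)

dist.destroy_process_group()
```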
2 parents 14be788 + 5754d68 commit 366bd1c

File tree: 2 files changed (+130, -7 lines)

2 files changed

+130
-7
lines changed

test/distributed/_tensor/test_random_ops.py (116 additions, 1 deletion)

```diff
@@ -6,13 +6,19 @@
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.tensor._random as random
+from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
 from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
 from torch.distributed._tensor.api import distribute_tensor
 from torch.distributed._tensor.placement_types import Replicate, Shard
 from torch.distributed.distributed_c10d import broadcast_object_list
-from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed
+from torch.distributed.tensor._random import (
+    is_rng_supported_mesh,
+    manual_seed,
+    OffsetBasedRNGTracker,
+)
 from torch.distributed.tensor.debug import CommDebugMode
+from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -118,6 +124,19 @@ def test_manual_seed(self):

         self.assertEqual(comm_mode.get_total_counts(), 0)

+    @with_comms
+    @skip_unless_torch_gpu
+    def test_manual_seed_submesh(self):
+        # the current rank is not a part of the mesh
+        single_rank_device_mesh = DeviceMesh(
+            self.device_type, [(self.rank + 1) % self.world_size]
+        )
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "manual_seed requires the current rank to be a part of the device mesh",
+        ):
+            manual_seed(self.rank, single_rank_device_mesh)
+
     @with_comms
     @skip_unless_torch_gpu
     def test_pipeline_parallel_manual_seed(self):
@@ -159,6 +178,102 @@ def test_pipeline_parallel_manual_seed(self):
                 tensor_gather[2 * other_rank : 2 * (other_rank + 1), :],
             )

+    @with_comms
+    @skip_unless_torch_gpu
+    def test_tp_model_meta_init(self):
+        # initialize the 1-d device mesh for TP
+        tp_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
+
+        # model meta init
+        with torch.device("meta"):
+            model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
+            self.assertEqual(model.weight.device, torch.device("meta"))
+            parallelize_module(model, tp_mesh, ColwiseParallel())
+            if random._rng_tracker is not None:
+                random._rng_tracker.distribute_region_enabled = True
+
+        self.assertEqual(model.weight.device, torch.device("meta"))
+
+        # actual initialization
+        device = torch.device("cuda", torch.cuda.current_device())
+        model.to_empty(device=device)
+        model.reset_parameters()
+        self.assertTrue(
+            random._rng_tracker is not None
+            and isinstance(random._rng_tracker, OffsetBasedRNGTracker)
+        )
+        self.assertEqual(model.weight.device, device)
+        assert isinstance(model.weight, DTensor)
+
+        # gather all the shards to compare initialization results
+        WORLD = torch.distributed.group.WORLD
+        assert WORLD is not None
+        weight_local = model.weight.to_local()
+        weight_gather = funcol.all_gather_tensor(
+            weight_local,
+            gather_dim=0,
+            group=WORLD,
+        )
+
+        # verify the weights are initialized differently on all ranks
+        for other_rank in range(self.world_size):
+            if self.rank != other_rank:
+                self.assertNotEqual(
+                    weight_local,
+                    weight_gather[other_rank : other_rank + 1, :],
+                )
+
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_fsdp_tp_model_meta_init(self):
+        # initialize the 2-d device mesh
+        global_mesh = init_device_mesh(
+            self.device_type,
+            mesh_shape=(self.world_size // 2, 2),
+            mesh_dim_names=("dp", "tp"),
+        )
+        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+
+        # model meta init
+        with torch.device("meta"):
+            model = torch.nn.Linear(self.world_size, self.world_size, bias=False)
+            self.assertEqual(model.weight.device, torch.device("meta"))
+            parallelize_module(model, tp_mesh, ColwiseParallel())
+            if random._rng_tracker is not None:
+                random._rng_tracker.distribute_region_enabled = True
+
+            fully_shard(model, mesh=dp_mesh)
+        self.assertEqual(model.weight.device, torch.device("meta"))
+
+        # actual initialization
+        device = torch.device("cuda", torch.cuda.current_device())
+        model.to_empty(device=device)
+        model.reset_parameters()
+        self.assertTrue(
+            random._rng_tracker is not None
+            and isinstance(random._rng_tracker, OffsetBasedRNGTracker)
+        )
+        self.assertEqual(model.weight.device, device)
+        assert isinstance(model.weight, DTensor)
+
+        # gather all the shards to compare initialization results
+        WORLD = torch.distributed.group.WORLD
+        assert WORLD is not None
+        weight_local = model.weight.to_local()
+        weight_gather = funcol.all_gather_tensor(
+            weight_local,
+            gather_dim=0,
+            group=WORLD,
+        )
+
+        # verify the weights are initialized differently on all ranks
+        for other_rank in range(self.world_size):
+            if self.rank != other_rank:
+                self.assertNotEqual(
+                    weight_local,
+                    weight_gather[other_rank : other_rank + 1, :],
+                )
+
     @with_comms
     @skip_unless_torch_gpu
     def test_deterministic_dropout_1d(self):
```
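Outside the test harness, the meta-device initialization flow exercised by the two new tests looks roughly like the sketch below. It is a sketch under assumptions: a torchrun launch, a 1-D TP mesh, and a 16x16 `nn.Linear` whose output dimension is divisible by the world size; none of these specifics come from the diff.

```python
import os

import torch
from torch.distributed._tensor import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Build the TP-sharded module on the meta device: no memory is allocated and
# no DTensor RNG tracker is created yet.
with torch.device("meta"):
    model = torch.nn.Linear(16, 16, bias=False)
    parallelize_module(model, tp_mesh, ColwiseParallel())

# Materialize and initialize for real. reset_parameters() runs DTensor random
# ops; without a prior manual_seed call, this lazily creates
# OffsetBasedRNGTracker and broadcasts rank 0's RNG state to WORLD.
model.to_empty(device=torch.device("cuda", torch.cuda.current_device()))
model.reset_parameters()
```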

torch/distributed/tensor/_random.py (14 additions, 6 deletions)

```diff
@@ -52,7 +52,10 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:

     Args:
         seed (int): The desired seed.
-        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.
+        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is
+            required that the ``device_mesh`` include the calling rank. This is
+            to ensure that the SPMD region maintains a synchronous RNG state, which
+            means no ranks should be initialized with values other than ``seed``.

     Returns:
         None
@@ -62,7 +65,7 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
     ensure on their own that the value passed in is the desired ``seed`` for ranks
     within ``device_mesh``.
     If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
-    ``manual_seed`` will not set its GPU device's generator seed.
+    ``manual_seed`` will throw an error.
     Current implementation only supports a GPU device mesh.
     """
     device_handle = _get_device_handle(device_mesh.device_type)
@@ -82,6 +85,12 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
     # the current rank is in mesh
     if device_mesh.get_coordinate() is not None:
         _rng_tracker._manual_seed(seed)
+    else:
+        raise RuntimeError(
+            "manual_seed requires the current rank to be a part of the device mesh "
+            "otherwise DTensor RNG state on the rank will not be initialized and "
+            "the behavior of DTensor random ops is undefined."
+        )


 class _RNGStateTracker:
@@ -130,8 +139,8 @@ def get_seed(self, name: str) -> int:
         return int(seed_tensor.item())

     def set_seed(self, name: str, seed: int) -> None:
-        seed_tensor = torch.tensor([seed]).view(torch.uint8)
-        offset_tensor = torch.tensor([0]).view(torch.uint8)
+        seed_tensor = torch.tensor([seed], device="cpu").view(torch.uint8)
+        offset_tensor = torch.tensor([0], device="cpu").view(torch.uint8)
         self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])

     def _distribute_region(self, spec: DTensorSpec):
@@ -198,7 +207,7 @@ def set_offset(self, name: str, offset: int) -> None:
         )

         seed_tensor = (self.rng_states[name])[0:8]
-        offset_tensor = torch.tensor([offset]).view(torch.uint8)
+        offset_tensor = torch.tensor([offset], device="cpu").view(torch.uint8)
         self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])

     def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
@@ -277,7 +286,6 @@ def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
             total_num_shards = 1
             # the tensor dim is sharded on more than 1 mesh dim
             if isinstance(mesh_dim, List):
-                assert isinstance(mesh_dim, List)
                 rank_coord = [mesh_coordinate[d] for d in mesh_dim]
                 num_shards = [mesh_size[d] for d in mesh_dim]
                 # compute the shard idx and total number of shards
```
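One note on the `device="cpu"` pins added to `set_seed` and `set_offset`: the tracker stores its per-mesh RNG state as CPU byte tensors, and presumably the explicit device keeps them on CPU even when a factory-function device context is active, such as the `torch.device("meta")` context used by the new meta-init tests. A standalone illustration of that behavior follows; it is not code from this PR.

```python
import torch

# Inside a device context, factory calls without an explicit device follow the
# context; pinning device="cpu" keeps the seed/offset bytes on a real device.
with torch.device("meta"):
    implicit = torch.tensor([42]).view(torch.uint8)
    pinned = torch.tensor([42], device="cpu").view(torch.uint8)

print(implicit.device)  # meta
print(pinned.device)    # cpu
```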
