Skip to content

Commit 67fbb31

Browse files
Revert "[FSDP2] support dataclass args/kwargs output (#173415)"
This reverts commit 63c5a68. Reverted #173415 on behalf of https://github.com/weifengpy due to a failing internal test; reverting first, root cause to be investigated later ([comment](#173415 (comment)))
1 parent d7f447c commit 67fbb31

File tree

3 files changed

+35
-127
lines changed

3 files changed

+35
-127
lines changed

test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py

Lines changed: 11 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -597,133 +597,28 @@ def assert_fn(output: torch.Tensor):
597597
loss.backward()
598598

599599
@skip_if_lt_x_gpu(1)
600-
def test_dataclass_input_output(self):
601-
from unittest.mock import patch
602-
603-
from torch.distributed._composable_state import _get_module_state
604-
600+
def test_dataclass_input(self):
605601
@dataclasses.dataclass
606602
class Input:
607603
x: torch.Tensor
608-
y: torch.Tensor
609-
610-
@dataclasses.dataclass
611-
class Output:
612-
x: torch.Tensor
613-
y: torch.Tensor
614-
615-
@dataclasses.dataclass
616-
class Scale:
617-
factor: torch.Tensor
618604

619605
class Model(nn.Module):
620606
def __init__(self, *args, **kwargs) -> None:
621607
super().__init__(*args, **kwargs)
622608
self._layer = nn.Linear(10, 10)
623609

624-
def forward(self, input: Input, *, scale: Scale | None = None):
625-
x = self._layer(input.x)
626-
y = self._layer(input.y)
627-
if scale is not None:
628-
x = x * scale.factor
629-
y = y * scale.factor
630-
return Output(x=x, y=y)
631-
632-
class TensorModel(nn.Module):
633-
def __init__(self, *args, **kwargs) -> None:
634-
super().__init__(*args, **kwargs)
635-
self._layer = nn.Linear(10, 10)
610+
def forward(self, input: Input):
611+
return self._layer(input.x)
636612

637-
def forward(self, x: torch.Tensor, *, scale: torch.Tensor | None = None):
638-
out = self._layer(x)
639-
if scale is not None:
640-
out = out * scale
641-
return out
642-
643-
# Test with different MixedPrecisionPolicy configurations
644-
mp_policies = [
645-
MixedPrecisionPolicy(
646-
param_dtype=torch.bfloat16,
647-
reduce_dtype=torch.bfloat16,
648-
),
649-
MixedPrecisionPolicy(
650-
param_dtype=torch.bfloat16,
651-
reduce_dtype=torch.float32,
652-
),
653-
]
654-
655-
for mp_policy in mp_policies:
656-
# Test with normal torch.Tensor as arg
657-
tensor_model = TensorModel()
658-
fully_shard(tensor_model, mp_policy=mp_policy)
659-
fsdp_state = _get_module_state(tensor_model)
660-
x = torch.randn(10, 10, device=device_type, requires_grad=True)
661-
with patch.object(
662-
fsdp_state, "_pre_backward", wraps=fsdp_state._pre_backward
663-
) as mock_pre_backward:
664-
loss = tensor_model(x).sum()
665-
loss.backward()
666-
mock_pre_backward.assert_called()
667-
668-
# Test with normal torch.Tensor as both arg and kwarg
669-
tensor_model.zero_grad()
670-
x = torch.randn(10, 10, device=device_type, requires_grad=True)
671-
scale = torch.randn(10, 10, device=device_type, requires_grad=True)
672-
with patch.object(
673-
fsdp_state, "_pre_backward", wraps=fsdp_state._pre_backward
674-
) as mock_pre_backward:
675-
loss = tensor_model(x, scale=scale).sum()
676-
loss.backward()
677-
mock_pre_backward.assert_called()
678-
679-
# Test with dataclass as positional arg only
680-
model = nn.Sequential(*[Model(), Model()])
681-
inp = Input(
682-
x=torch.randn(10, 10, device=device_type, requires_grad=True),
683-
y=torch.randn(10, 10, device=device_type, requires_grad=True),
684-
)
613+
mp_policy = MixedPrecisionPolicy(
614+
torch.bfloat16, torch.bfloat16, torch.bfloat16, True
615+
)
616+
model = Model()
617+
inp = Input(torch.randn(2, 10).to(device_type))
685618

686-
for layer in model:
687-
fully_shard(layer, mp_policy=mp_policy, reshard_after_forward=True)
688-
fully_shard(model, mp_policy=mp_policy, reshard_after_forward=True)
689-
690-
# Patch _pre_backward on all FSDP states
691-
layer0_state = _get_module_state(model[0])
692-
layer1_state = _get_module_state(model[1])
693-
root_state = _get_module_state(model)
694-
with (
695-
patch.object(
696-
layer0_state, "_pre_backward", wraps=layer0_state._pre_backward
697-
) as layer0_mock,
698-
patch.object(
699-
layer1_state, "_pre_backward", wraps=layer1_state._pre_backward
700-
) as layer1_mock,
701-
patch.object(
702-
root_state, "_pre_backward", wraps=root_state._pre_backward
703-
) as root_mock,
704-
):
705-
output = model(inp)
706-
loss = output.x.sum() + output.y.sum()
707-
loss.backward()
708-
layer0_mock.assert_called()
709-
layer1_mock.assert_called()
710-
root_mock.assert_called()
711-
712-
# Test with dataclass as both positional arg and kwarg
713-
inp = Input(
714-
x=torch.randn(10, 10, device=device_type, requires_grad=True),
715-
y=torch.randn(10, 10, device=device_type, requires_grad=True),
716-
)
717-
scale = Scale(
718-
factor=torch.randn(10, 10, device=device_type, requires_grad=True)
719-
)
720-
with patch.object(
721-
layer0_state, "_pre_backward", wraps=layer0_state._pre_backward
722-
) as layer0_mock:
723-
output = model[0](inp, scale=scale)
724-
loss = output.x.sum() + output.y.sum()
725-
loss.backward()
726-
layer0_mock.assert_called()
619+
fully_shard(model, mp_policy=mp_policy)
620+
loss = model(inp).sum()
621+
loss.backward()
727622

728623

729624
if __name__ == "__main__":

torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# mypy: allow-untyped-defs
22
import contextlib
3-
import functools
43
import logging
54
from collections.abc import Callable
65
from typing import Any, cast, NamedTuple
@@ -11,8 +10,8 @@
1110
from torch.distributed.device_mesh import _get_device_handle
1211
from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates
1312
from torch.distributed.tensor import Shard
14-
from torch.distributed.utils import _apply_to_tensors
1513
from torch.profiler import record_function
14+
from torch.utils._pytree import tree_flatten, tree_unflatten
1615
from torch.utils.hooks import RemovableHandle
1716

1817
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
@@ -715,11 +714,24 @@ def _register_post_backward_hook(
715714
return args, kwargs
716715
if not torch.is_grad_enabled():
717716
return args, kwargs
718-
register_post_backward_func = functools.partial(
719-
RegisterPostBackwardFunction.apply, self
720-
)
721-
args = _apply_to_tensors(lambda t: register_post_backward_func(t)[0], args)
722-
kwargs = _apply_to_tensors(lambda t: register_post_backward_func(t)[0], kwargs)
717+
args_list, args_spec = tree_flatten(args)
718+
kwargs_list, kwargs_spec = tree_flatten(kwargs)
719+
args_kwargs_list = list(args_list) + list(kwargs_list)
720+
inp_tensor_indices: list[int] = []
721+
inp_tensors: list[torch.Tensor] = []
722+
for i, obj in enumerate(args_kwargs_list):
723+
if torch.is_tensor(obj) and obj.requires_grad:
724+
inp_tensor_indices.append(i)
725+
inp_tensors.append(obj)
726+
if len(inp_tensors) == 0:
727+
return args, kwargs # no tensors that require gradients
728+
inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors)
729+
for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
730+
args_kwargs_list[inp_tensor_idx] = inp_tensor
731+
args_list = args_kwargs_list[: len(args_list)]
732+
kwargs_list = args_kwargs_list[len(args_list) :]
733+
args = tree_unflatten(args_list, args_spec)
734+
kwargs = tree_unflatten(kwargs_list, kwargs_spec)
723735
return args, kwargs
724736

725737
def _register_state_dict_hooks(self) -> None:

torch/distributed/fsdp/_fully_shard/_fsdp_state.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from torch.distributed.device_mesh import _get_device_handle
1919
from torch.distributed.utils import _apply_to_tensors, _to_kwargs
20+
from torch.utils._pytree import tree_flatten
2021

2122
from ._fsdp_api import MixedPrecisionPolicy
2223
from ._fsdp_common import (
@@ -349,10 +350,10 @@ def _finalize_backward(self) -> None:
349350
def _register_pre_backward_hook(self, output: Any) -> Any:
350351
if not torch.is_grad_enabled():
351352
return output
352-
_apply_to_tensors(
353-
lambda x: x.register_hook(self._pre_backward) if x.requires_grad else x,
354-
output,
355-
)
353+
flat_outputs, _ = tree_flatten(output)
354+
for t in flat_outputs:
355+
if torch.is_tensor(t) and t.requires_grad:
356+
t.register_hook(self._pre_backward)
356357
return output
357358

358359
def _register_root_post_backward_final_callback(self):

0 commit comments

Comments (0)