[FSDP2] Added shard_placement_fn arg
#136221
Conversation
This is WIP and quite messy right now.
This is WIP and quite messy right now. 2D train parity seems broken when using new code path.
This is WIP.
This is WIP. For `Shard(i)` and `i != 0`, the uneven sharding case must incur extra copies before/after all-gather/reduce-scatter in order for the sharded parameters and gradients to have contiguous strides.
```
self.padded_sharded_param_size = padded_sharded_param.size()
if sharded_param.numel() > 0:
    padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
padded_sharded_param.narrow(
```
I was worried about NF4 but found that `.narrow` and `[:]` both dispatch to `slice.Tensor`. That's good! Any ideas on how to find the mapping (other than printing torch dispatch)? I am a little bit surprised since I found `narrow` in native_functions.yaml.
```
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Print the function being dispatched and the arguments
        print(f"Dispatching function: {func.__name__}")
        return func(*args, **kwargs)

# Example usage
data = torch.rand(3, 3, device="cuda")

# Use LoggingMode for the duration of this block
with LoggingMode():
    data.narrow(dim=1, start=0, length=2)
    data[:2]
```
```
assert (
    unsharded_grad.size(shard_dim) % world_size == 0
), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
unsharded_grads[i] = torch.cat(chunks, dim=0)
```
Are we doing chunk + cat to make `unsharded_grads` contiguous? Does the `.cat` trigger copies on GPU?
We have to chunk -> cat to do a data shuffle. If I have some time, I may try to draw some diagrams to add to the PR description.
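A minimal sketch of that data shuffle with made-up sizes (2 ranks, a `(4, 6)` gradient sharded on dim 1; the shapes and names here are illustrative only): chunking along the shard dim and concatenating along dim 0 rearranges the gradient so each rank's reduce-scatter input is a contiguous dim-0 slab. Note that `torch.cat` allocates a new tensor and copies into it.
```
import torch

# Hypothetical sizes: 2 ranks, gradient of shape (4, 6), sharded on dim 1.
world_size, shard_dim = 2, 1
unsharded_grad = torch.arange(24.0).reshape(4, 6)

# Chunk along the shard dim ...
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
# ... then cat along dim 0, so rank r's shard is the contiguous slab
# shuffled[r * 4 : (r + 1) * 4], which is the dim-0 layout reduce-scatter expects.
shuffled = torch.cat(chunks, dim=0)  # shape (8, 3); cat copies the data

assert shuffled[:4].equal(unsharded_grad[:, :3])  # rank 0's shard
assert shuffled[4:].equal(unsharded_grad[:, 3:])  # rank 1's shard
```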
```
if fsdp_param.fsdp_placement.dim != 0:
    # Copy to a temporary and then chunk-cat into the final all-gather
    # output tensors
    param_all_gather_outputs = [
```
[old] It seems it's overwriting `init_all_gather_outputs`. Is it better to move the logic into `init_all_gather_outputs`, or is this only temporary?
Edited: sorry, it's not overwriting; it's a temporary place for the copy-out. Feel free to ignore.
```
post_param_size[shard_dim] *= world_size
cat_out = target_all_gather_output.view(post_param_size)
torch.cat(chunks, dim=shard_dim, out=cat_out)
torch._C._autograd._unsafe_set_version_counter(
```
Is this similar to `with _unsafe_preserve_version_counter(target_all_gather_output)`? Do we use `_unsafe_set_version_counter` because `param_all_gather_outputs` is defined in a `for ... in` loop and the API is easier?
Yes, `with _unsafe_preserve_version_counter(target_all_gather_output)` calls into this `_unsafe_set_version_counter`.
Since we know we just need to decrement once, I think it is simpler/lower overhead to just call the API directly.
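For context, a minimal illustration (made-up buffer, nothing FSDP-specific) of the version counter being managed here: in-place and `out=` writes bump `Tensor._version`, which autograd uses to detect tensors mutated after being saved for backward, so the copy-out above restores the counter to avoid a spurious bump.
```
import torch

buf = torch.zeros(4)
print(buf._version)  # 0 for a fresh tensor

# Writing into the buffer with an out= op bumps the version counter ...
torch.cat([torch.ones(2), torch.ones(2)], dim=0, out=buf)
print(buf._version)  # 1

# ... as does any other in-place op.
buf.copy_(torch.zeros(4))
print(buf._version)  # 2
```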
```
dp_shard_tp_placement = (
    (
-       _StridedShard(0, split_factor=split_factor)
+       _StridedShard(shard_dim, split_factor=split_factor)
```
I never truly understood strided sharding. I guess it's meant for `model.parameters()` outside of fwd/bwd? Mostly for DTensor ops outside of FSDP, like state dict, optimizer, and grad norm clipping?
yep
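To make the "outside of FSDP" part concrete, here is my rough mental model with made-up sizes (pure `torch`, no DTensor API): when TP and FSDP shard the same dim, each rank's piece is interleaved in the global tensor rather than contiguous, and `_StridedShard(dim, split_factor=...)` is what lets DTensor ops like state dict, optimizer, and grad norm clipping reassemble the full tensor correctly.
```
import torch

# Hypothetical 2D setup: TP size 2, FSDP shard size 2, parameter dim of size 8.
tp_size, dp_size = 2, 2
full = torch.arange(8.0)

# TP shards the dim first, then FSDP shards each TP shard again.
tp_shards = torch.chunk(full, tp_size, dim=0)                    # [0..3], [4..7]
dp_shards = [torch.chunk(s, dp_size, dim=0) for s in tp_shards]  # 2x2 pieces

# The piece owned by (dp_rank=1, tp_rank=0) is [2, 3]; a naive nested Shard(0)
# ordering over the (dp, tp) mesh would place [4, 5] there instead, so the real
# layout is strided/interleaved from the global tensor's point of view.
print(dp_shards[0][1])  # tensor([2., 3.])
```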
## TL;DR
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, then the tensor dim size must be divisible by the FSDP shard world size.
```
# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```
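As a quick sanity check (a hypothetical snippet continuing the example above with the same `module`, assuming plain FSDP with no TP), the chosen placement shows up on the resulting DTensor parameters:
```
# Continuing the example above: inspect the DTensor placements after fully_shard.
for name, param in module.named_parameters():
    # e.g. (Shard(1),) for a weight whose dim 1 is the largest
    print(name, param.placements)
```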
## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on a nonzero tensor dim (see the sketch below). Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on a nonzero tensor dim. yifuwang has ideas for adding additional split size args to the copy ops that allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
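A small sketch of the all-gather copy-out shuffle from the bullet above, with made-up sizes (2 ranks, a `(4, 6)` parameter sharded on dim 1); this shows only the extra data movement on the gathered output, not the collective itself:
```
import torch

# Hypothetical sizes: 2 ranks, parameter of shape (4, 6) sharded on dim 1.
world_size, shard_dim = 2, 1
param = torch.arange(24.0).reshape(4, 6)
shards = torch.chunk(param, world_size, dim=shard_dim)  # what each rank holds

# The all-gather output stacks the per-rank shards along dim 0 ...
gathered = torch.cat(shards, dim=0)  # shape (8, 3)
# ... so copy-out chunks along dim 0 and cats along the shard dim, which is the
# extra copy the follow-up wants to fuse into the existing copy-out kernel.
rebuilt = torch.cat(torch.chunk(gathered, world_size, dim=0), dim=shard_dim)

assert rebuilt.equal(param)
```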
cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o
Differential Revision: [D62964657](https://our.internmc.facebook.com/intern/diff/D62964657)
Re-opened in a different PR (#137496) since the test-config/distributed label seems to persist even after removing it.
Stack from ghstack (oldest at bottom):
- [FSDP2] Added `shard_placement_fn` arg #136221