
Conversation

awgu (Collaborator) commented Sep 17, 2024

Stack from ghstack (oldest at bottom):

TL;DR

This PR adds a shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] arg to fully_shard that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, the tensor dim size must be divisible by the FSDP shard world size.

# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)

Follow-Ups

  • Copy kernels: For all-gather copy-out, we currently copy out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on a nonzero tensor dim (a standalone sketch of this shuffle is below). Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on a nonzero tensor dim. @yifuwang has ideas for adding additional split-size args to the copy ops that would allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
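
A minimal standalone sketch of the data movement behind that extra copy, using plain torch ops and made-up shapes (world_size=4, an 8x16 parameter sharded on dim 1); this only illustrates the shuffle and is not the actual FSDP copy kernels:

```
import torch

world_size, shard_dim = 4, 1
param = torch.randn(8, 16)  # hypothetical unsharded parameter

# Shard(1): each rank owns one (8, 4) slice of the parameter.
shards = list(torch.chunk(param, world_size, dim=shard_dim))

# The all-gather output is laid out rank-major, i.e. as if the per-rank shards
# were stacked along dim 0.
all_gather_output = torch.cat(shards, dim=0)  # (32, 4)

# Copy-out today: chunk on dim 0, then cat along the shard dim -- this cat is
# the extra copy that the follow-up wants to fuse into the copy-out itself.
chunks = torch.chunk(all_gather_output, world_size, dim=0)
reconstructed = torch.cat(chunks, dim=shard_dim)  # (8, 16)
assert torch.equal(reconstructed, param)
```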

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

Differential Revision: D62964657

pytorch-bot commented Sep 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136221

Note: Links to docs will display an error until the docs builds have been completed.

❌ 40 Cancelled Jobs

As of commit 4942493 with merge base d1b87e2:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the oncall: distributed and release notes: distributed (fsdp) labels on Sep 17, 2024
This is WIP and quite messy right now.

cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
This is WIP and quite messy right now.

2D train parity seems broken when using new code path

cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
This is WIP.

cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Sep 18, 2024
ghstack-source-id: 60eedca
Pull Request resolved: #136221
awgu pushed a commit that referenced this pull request Oct 4, 2024
ghstack-source-id: acf7973
Pull Request resolved: #136221
This is WIP.


cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

Differential Revision: [D62964657](https://our.internmc.facebook.com/intern/diff/D62964657)

[ghstack-poisoned]
This is WIP.


For `Shard(i)` and `i != 0`, the uneven sharding case must incur extra copies before/after all-gather/reduce-scatter in order for the sharded parameters and gradients to have contiguous strides.


cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

Differential Revision: [D62964657](https://our.internmc.facebook.com/intern/diff/D62964657)

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Oct 7, 2024
ghstack-source-id: 04abf49
Pull Request resolved: #136221
awgu (Collaborator, Author) commented Oct 7, 2024

To-do:

  • Add some state dict test
  • (Optional) add some memory test (will use more memory though)
  • Test on 8-GPU devgpu in torchtitan and sanity check profiler trace

self.padded_sharded_param_size = padded_sharded_param.size()
if sharded_param.numel() > 0:
    padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
    padded_sharded_param.narrow(
weifengpy (Contributor) commented Oct 8, 2024:

I was worried about NF4 but found that .narrow and [:] both dispatch to slice.Tensor. That's good! Any ideas on how to find the mapping (other than printing torch dispatch)? I am a little bit surprised since I found narrow in native_functions.yaml.

Contributor replied:

import torch
from torch.utils._python_dispatch import TorchDispatchMode


class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Print the function being dispatched and the arguments
        print(f"Dispatching function: {func.__name__}")
        return func(*args, **kwargs)

# Example usage
data = torch.rand(3, 3, device="cuda")

# Use LoggingMode for the duration of this block
with LoggingMode():
    data.narrow(dim=1, start=0, length=2)
    data[:2]
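# Per the observation in the comment above, both calls here dispatch to
# slice.Tensor, so the mode logs the same op for .narrow and [:2]
# (assuming no other dispatch mode intercepts the calls first).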

assert (
    unsharded_grad.size(shard_dim) % world_size == 0
), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
unsharded_grads[i] = torch.cat(chunks, dim=0)
Contributor commented:

Are we doing chunk + cat to make unsharded_grads contiguous? Does the .cat trigger copies on GPU?

Collaborator Author replied:

We have to chunk -> cat to do a data shuffle (a standalone sketch is below). If I have some time, I may try to draw some diagrams to add to the PR description.
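
A rough standalone illustration of that shuffle (assumed shapes, world_size=4, gradient sharded on dim 1; plain torch ops, not the actual reduce-scatter copy-in):

```
import torch

world_size, shard_dim = 4, 1
unsharded_grad = torch.randn(8, 16)
assert unsharded_grad.size(shard_dim) % world_size == 0

# Chunk on the shard dim, then cat on dim 0 so the reduce-scatter input is one
# contiguous buffer whose dim-0 segments line up with the per-rank shards.
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)  # 4 x (8, 4)
reduce_scatter_input = torch.cat(chunks, dim=0)  # (32, 4); this cat is the extra copy

# Reduce-scatter splits its input evenly over dim 0, so rank r receives rows
# [r * 8, (r + 1) * 8), which is exactly rank r's Shard(1) slice of the gradient.
r = 2
assert torch.equal(reduce_scatter_input[r * 8 : (r + 1) * 8], chunks[r])
```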

if fsdp_param.fsdp_placement.dim != 0:
    # Copy to a temporary and then chunk-cat into the final all-gather
    # output tensors
    param_all_gather_outputs = [
weifengpy (Contributor) commented Oct 8, 2024:

[old] It seems this is overwriting init_all_gather_outputs. Is it better to move the logic into init_all_gather_outputs, or is this only temporary?

Edited: sorry, it's not overwriting; it's a temporary place for the copy-out. Feel free to ignore.

post_param_size[shard_dim] *= world_size
cat_out = target_all_gather_output.view(post_param_size)
torch.cat(chunks, dim=shard_dim, out=cat_out)
torch._C._autograd._unsafe_set_version_counter(
Contributor commented:

Is this similar to `with _unsafe_preserve_version_counter(target_all_gather_output)`? Do we use _unsafe_set_version_counter because param_all_gather_outputs is defined in a `for ... in` loop and the API is easier?

Collaborator Author replied:

Yes, `with _unsafe_preserve_version_counter(target_all_gather_output)` calls into this _unsafe_set_version_counter.
Since we know we just need to decrement once, I think it is simpler and lower overhead to call the API directly.

dp_shard_tp_placement = (
(
_StridedShard(0, split_factor=split_factor)
_StridedShard(shard_dim, split_factor=split_factor)
weifengpy (Contributor) commented Oct 8, 2024:

I never truly understood strided sharding. I guess it's meant for model.parameters() outside of fwd/bwd? Mostly for DTensor ops outside of FSDP, like state dict, optimizer, and grad norm clipping?

Collaborator Author replied:

yep

## TL;DR
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. If doing so, the tensor dim size must be divisible by the FSDP shard world size.

```
# Example:
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)

fully_shard(module, shard_placement_fn=shard_placement_fn)
```

## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on a nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on a nonzero tensor dim. yifuwang has ideas for adding additional split-size args to the copy ops that would allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.

cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

Differential Revision: [D62964657](https://our.internmc.facebook.com/intern/diff/D62964657)

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Oct 8, 2024
ghstack-source-id: 71ac63d
Pull Request resolved: #136221
awgu (Collaborator, Author) commented Oct 8, 2024

Re-opened as a different PR (#137496) since the test-config/distributed label seems to persist even after removing it.

awgu closed this on Oct 8, 2024
github-actions bot deleted the gh/awgu/640/head branch on November 21, 2024

Labels: oncall: distributed, release notes: distributed (fsdp2)