
Does FSDP support nested wrapping for MoE models with Expert Parallelism? #149396

@zigzagcai

Description


Hi,

I am trying to use FSDP together with Expert Parallelism (EP) to train MoE models that are quite large (e.g., the 670B DeepSeek-V3). Even with the fully sharded option we hit CUDA OOM during training, because the per-layer parameter size is so large. Therefore we implemented Expert Parallelism.

However, the MoE part (Expert Parallelism) and the non-MoE part use different process groups, so the two parts need to be wrapped separately. Detailed information on FSDP + EP can be found here: #114361
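For concreteness, here is a hedged sketch of how the two rank meshes could be derived (the layout and `ep_size` here are illustrative assumptions, not taken from the issue; each rank list would then be materialized with `torch.distributed.new_group(ranks=...)` on every rank):

```python
# Sketch: derive the rank lists for the two process groups, assuming a
# hypothetical layout where expert parallelism (EP) of degree `ep_size`
# stripes experts across the world.
def build_group_ranks(world_size: int, ep_size: int):
    """Return (data_group_ranks, expert_data_group_ranks).

    data_group_ranks: one flat group containing every rank, used for the
    non-MoE parameters (data_process_group in the snippet below).
    expert_data_group_ranks: ranks holding the same expert shard; rank r
    joins the group {r' | r' % ep_size == r % ep_size}
    (expert_data_process_group in the snippet below).
    """
    assert world_size % ep_size == 0
    data_group = [list(range(world_size))]
    expert_data_groups = [
        list(range(i, world_size, ep_size)) for i in range(ep_size)
    ]
    return data_group, expert_data_groups
```

Each inner list would be passed as `ranks` to `torch.distributed.new_group`; every rank must make the same sequence of `new_group` calls even for groups it does not belong to.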

I tried to wrap the model following the suggestion from @awgu:

```python
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

ignored_mod = []
for layer_id, layer in enumerate(model.layers):
    if layer_id >= config.first_k_dense_replace:
        # Inner FSDP: shard the experts over the expert-data process group.
        layer.feed_forward.moe_layer.experts = FSDP(
            layer.feed_forward.moe_layer.experts,
            process_group=expert_data_process_group,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            limit_all_gathers=True,
            use_orig_params=True,
        )
        ignored_mod.append(layer.feed_forward.moe_layer.experts)

# Outer FSDP: shard everything else over the data process group.
model = FSDP(
    module=model,
    process_group=data_process_group,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=ModuleWrapPolicy(wrap_cls),
    forward_prefetch=True,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    limit_all_gathers=True,
    use_orig_params=True,
    ignored_modules=ignored_mod,
)
```

But it seems that FSDP's auto wrapping cannot support nested wrapping with two process groups (one for the non-MoE parts and another for the MoE experts):

```
  File "/blahblah/zigzagcai/.conda/envs/my_dev_env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
    _auto_wrap(
  File "/blahblah/zigzagcai/.conda/envs/my_dev_env/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 45, in _auto_wrap
    _check_nested_wrapping(root_module)
  File "/blahblah/zigzagcai/.conda/envs/my_dev_env/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 107, in _check_nested_wrapping
    raise ValueError(
ValueError: FSDP auto wrapping requires modules to not already have FSDP applied but found model.layers.1.feed_forward.moe_layer.experts in
```

And when materializing the outer FSDP module, I cannot even put the already-wrapped inner FSDP modules into the `ignored_modules` list:

```
  File "/blahblah/zigzagcai/.conda/envs/my_dev_env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 442, in __init__
    _init_ignored_module_states(self, module, ignored_modules, ignored_states)
  File "/blahblah/zigzagcai/.conda/envs/my_dev_env/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 314, in _init_ignored_module_states
    state._ignored_modules = _get_ignored_modules(module, ignored_modules)
  File "/blahblah/zigzagcai/.conda/envs/my_dev_env/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 697, in _get_ignored_modules
    raise ValueError("`ignored_modules` should not include FSDP modules")
ValueError: `ignored_modules` should not include FSDP modules
```

Then I checked the FSDP source code and found that both assertions are on the relaxation TODO list:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_init_utils.py#L680-L683
https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_wrap_utils.py#L43-L45

So I removed the two assertions and training ran successfully. The question is: does FSDP support nested wrapping in this form?
(1) First, wrap the MoE expert part with `expert_data_process_group` and put the wrapped expert modules into `ignored_modules`.
(2) Then, wrap the non-MoE part with `data_process_group`.
Is my implementation correct for this case, given that the two assertions are removed?
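For reference, the nesting order I have in mind can be sketched without `auto_wrap_policy` at all. The traceback above goes through `_auto_wrap`, so the nested-wrapping check appears to be triggered only when `auto_wrap_policy` is passed; as I understand it, parameters already flattened by an inner FSDP instance are skipped by outer instances, so `ignored_modules` would then not be needed. This is a hypothetical helper, not code from the issue: the FSDP class is injected as a parameter purely so the wrapping order can be checked with a stub, and `first_k_dense` mirrors `config.first_k_dense_replace` from the snippet above.

```python
def wrap_moe_model(model, fsdp_cls, data_pg, expert_pg, first_k_dense,
                   **fsdp_kwargs):
    """Nest FSDP manually: experts -> each layer -> root (no auto_wrap_policy).

    `fsdp_cls` would normally be
    torch.distributed.fsdp.FullyShardedDataParallel; it is a parameter here
    so the wrapping order is unit-testable. `fsdp_kwargs` would carry the
    sharding_strategy / prefetch / use_orig_params settings shown earlier.
    """
    for layer_id, layer in enumerate(model.layers):
        if layer_id >= first_k_dense:
            # Inner wrap: experts sharded over the expert-data group.
            layer.feed_forward.moe_layer.experts = fsdp_cls(
                layer.feed_forward.moe_layer.experts,
                process_group=expert_pg,
                **fsdp_kwargs,
            )
        # Middle wrap: the layer itself over the data group; this replaces
        # auto_wrap_policy=ModuleWrapPolicy(wrap_cls) on the root.
        model.layers[layer_id] = fsdp_cls(
            layer, process_group=data_pg, **fsdp_kwargs
        )
    # Root wrap: remaining params (embeddings, norms, head) over the data group.
    return fsdp_cls(model, process_group=data_pg, **fsdp_kwargs)
```

Whether this ordering has the same semantics as removing the two assertions is exactly the question above.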

Thanks in advance if anybody could provide some insights!

cc
@awgu @zhaojuanmao @rohan-varma @liangluofb @fegin @lessw2020 @mrshenli @penguinwu @kwen2501

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @kwen2501 @c-p-i-o


Labels

module: fsdp, oncall: distributed, triaged
