
Conversation

@fegin fegin commented Sep 9, 2025

Stack from ghstack (oldest at bottom):

Motivation

Since FlexAttention and SDPA are functions, not modules, we have tried numerous mechanisms to dispatch them to customized call paths so that we can inject the CP logic. Unfortunately, each of these approaches has its own composability issues with other techniques.

Candidate Approaches

  1. Ask users to write a module to wrap FlexAttention/SDPA and use `parallelize_module` to install a forward hook.

    • Pros: This is similar to how we do TP.
    • Cons: 1) It is cumbersome for users, as they need to create a new module. 2) CP must be applied in two places, since a context_parallel context manager is still required for splitting the inputs.
  2. Provide a function wrapper.

    • Pros: Users just need to replace their FlexAttention/SDPA calls with the wrapper.
    • Cons: It is not the same API, though we can keep the wrapper's signature identical to the core API.

Summary

~~This PR implements approach 2 and refactors the code so that most of it can be reused by approach 1, which will be introduced in another PR.~~

We changed this PR to implement approach 1, as people preferred it for its consistency with the existing parallelisms. This PR can also serve as the foundation for implementing approach 2, which was the early version of this PR.

This PR also changes the `create_cp_block_mask` logic: since we now focus only on the module-wrapper approach, we no longer need to hack the `seq_len` field of a `BlockMask`.

This PR also removes the TorchFunctionMode-based dispatcher, as it doesn't compose well with SAC (selective activation checkpointing).
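
A minimal sketch of the module-wrapper flow described above. The `SDPAWrapper` module and the `ContextParallel` plan name are illustrative assumptions, not necessarily the exact API introduced by this PR:

```python
# Sketch only: SDPAWrapper and ContextParallel are illustrative names,
# not necessarily the exact API added by this PR.
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor.parallel import parallelize_module


class SDPAWrapper(nn.Module):
    """Wraps the functional SDPA call so a forward hook can be installed on it."""

    def forward(self, q, k, v, **kwargs):
        return F.scaled_dot_product_attention(q, k, v, **kwargs)


# Hypothetical usage: apply a CP plan to the wrapper module, mirroring how
# TP plans are applied with parallelize_module.
# parallelize_module(model.attn_wrapper, cp_mesh, ContextParallel(...))
```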

cc @H-Huang @awgu @wanchaol @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci


pytorch-bot bot commented Sep 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162542

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a18e4eb with merge base d41aa18:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Sep 9, 2025
fegin added a commit that referenced this pull request Sep 9, 2025
**Motivation**

Since FlexAttention and SDPA are functions, not modules, we have tried numerous mechanisms to dispatch them to customized call paths so that we can inject the CP logic. Unfortunately, each of these approaches has its own composability issues with other techniques.

**Candidate Approaches**

1. Ask users to write a module to wrap FlexAttention/SDPA and use `parallelize_module` to install a forward hook.

   - Pros: This is similar to how we do TP.
   - Cons: 1) It is cumbersome for users, as they need to create a new module. 2) CP must be applied in two places, since a context_parallel context manager is still required for splitting the inputs.

2. Provide a function wrapper.

   - Pros: Users just need to replace their FlexAttention/SDPA calls with the wrapper.
   - Cons: It is not the same API, though we can keep the wrapper's signature identical to the core API.

**Summary**

This PR implements approach 2, but uses an nn.Module to mimic a function wrapper so that we can record CP state on the wrapper module instead of leaking it into the `_attention` module (see the sketch below).


ghstack-source-id: 0142083
Pull-Request-resolved: #162542
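
A minimal sketch of that idea, under the assumption that the "function wrapper" is simply a callable nn.Module instance holding CP state; the names here are illustrative, not the PR's actual code:

```python
# Illustrative sketch: a module instance used like the SDPA function, so that
# CP state can live on the instance rather than in module-level globals.
import torch.nn as nn
import torch.nn.functional as F


class _SDPAFunctionWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self._cp_enabled = False  # hypothetical per-wrapper CP state

    def forward(self, q, k, v, **kwargs):
        # With CP enabled, the installed hook/dispatch logic would reroute this
        # call through the CP path; otherwise it falls through to plain SDPA.
        return F.scaled_dot_product_attention(q, k, v, **kwargs)


# Call sites replace F.scaled_dot_product_attention(q, k, v) with:
# sdpa = _SDPAFunctionWrapper(); out = sdpa(q, k, v)
```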
@fegin fegin changed the title [CP] Introduce flex_attention_wrapper [CP][WIP] Introduce flex_attention_wrapper Sep 9, 2025
fegin added a commit that referenced this pull request Sep 10, 2025
fegin added a commit that referenced this pull request Sep 10, 2025
fegin added a commit to pytorch/torchtitan that referenced this pull request Sep 10, 2025
fegin added a commit that referenced this pull request Sep 10, 2025
fegin added a commit to pytorch/torchtitan that referenced this pull request Sep 12, 2025
Similar to #1696, but this PR uses `parallelize_module`, similar to TP/SP.

This PR also requires pytorch/pytorch#162542
fegin added a commit that referenced this pull request Sep 12, 2025
@fegin fegin changed the title [CP][WIP] Introduce flex_attention_wrapper [CP][WIP] Introduce flex_attention_wrapper and ContextParallel parallel plan Sep 12, 2025
fegin added a commit that referenced this pull request Sep 15, 2025
fegin added a commit that referenced this pull request Sep 16, 2025
fegin added a commit that referenced this pull request Sep 16, 2025

fegin commented Oct 9, 2025

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.


fegin commented Oct 10, 2025

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: Raised by workflow job

Failing merge rule: Core Maintainers

fegin added 2 commits October 9, 2025 23:27

fegin commented Oct 10, 2025

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 13, 2025

`context_parallel()` being a context manager has annoyed users. Now that we plan to redesign CP's UX to explicitly ask users to:

1. Wrap the attention op into an `nn.Module`
2. Lift any buffers that are not sequence-agnostic to inputs (see the sketch after this commit message)

We can replace `context_parallel()` with two functional APIs: `_context_parallel_shard` and `_enable_context_parallel_dispatcher`.

Pull Request resolved: #164500
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542
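
A minimal sketch of what "lift any buffers that are not sequence-agnostic to inputs" means in practice; the module and argument names are illustrative assumptions, not the actual torchtitan/PyTorch code:

```python
# Illustrative only: instead of registering a sequence-length-dependent buffer
# (e.g. an attention mask) on the module, pass it in as a forward argument so
# the CP sharding step can split it alongside q/k/v.
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def forward(self, q, k, v, attn_mask=None):
        # attn_mask is "lifted" to an input rather than stored as self.mask.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```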
    # TODO: Reverify atol and rtol after
    # https://github.com/pytorch/pytorch/pull/163185 is landed. The accuracy
    # issue happens on the gradients.
    torch.use_deterministic_algorithms(True)

Please don't do things like this. It might subtly break other tests, since it globally changes the deterministic setting. Please enable deterministic mode only in the tests that need it, and use a context manager so that it resets back to the prior state.
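
A minimal sketch of the suggested pattern, using torch's public query/set APIs; the `deterministic_algorithms` helper name here is my own, not an existing PyTorch utility:

```python
# Hypothetical helper illustrating the reviewer's suggestion: scope the
# deterministic setting and restore the previous state on exit.
import contextlib
import torch


@contextlib.contextmanager
def deterministic_algorithms(enabled: bool = True):
    prev = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(enabled)
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(prev)


# Usage inside a single test:
# with deterministic_algorithms(True):
#     run_the_numerics_check()
```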

pytorchmergebot pushed a commit that referenced this pull request Oct 13, 2025
The custom op fetches the required K and V. Currently, the forward pass is just an all-gather and the backward pass is a reduce-scatter (see the sketch after this commit message). While the logic is the same as `all_gather_tensor_autograd`, the custom op avoids the autograd warning that `wait_tensor()` is registered to autograd.

For the next step, we should explore how to infer the required communication from the information in the BlockMask.

Pull Request resolved: #163185
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542, #164500
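
A minimal sketch, under my own assumptions (not the PR's actual custom op), of an autograd-aware K/V fetch whose forward is an all-gather and whose backward is a reduce-scatter along the sharded sequence dimension (dim 0 here):

```python
# Sketch only: assumes a 1D process group and sharding along dim 0.
import torch
import torch.distributed as dist


class _AllGatherKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, local_kv: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        ctx.group = group
        world_size = dist.get_world_size(group)
        out = local_kv.new_empty((world_size * local_kv.shape[0], *local_kv.shape[1:]))
        # Forward: every rank gathers the full K/V it needs for attention.
        dist.all_gather_into_tensor(out, local_kv.contiguous(), group=group)
        return out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        group = ctx.group
        world_size = dist.get_world_size(group)
        local_grad = grad_out.new_empty((grad_out.shape[0] // world_size, *grad_out.shape[1:]))
        # Backward: sum-reduce the gathered gradient back to each rank's shard.
        dist.reduce_scatter_tensor(local_grad, grad_out.contiguous(), op=dist.ReduceOp.SUM, group=group)
        return local_grad, None
```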
pytorchmergebot pushed a commit that referenced this pull request Oct 14, 2025

No logic change; just polishes the docstrings and comments and removes unused variables.

Pull Request resolved: #165039
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542, #164500, #163185
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 15, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
@github-actions github-actions bot deleted the gh/fegin/318/head branch November 13, 2025 02:17

Labels

ciflow/inductor, ciflow/trunk, Merged, module: context parallel, oncall: distributed, release notes: context parallel
