[CP] Introduce ContextParallel plan for parallelize_module() #162542
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162542
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit a18e4eb with merge base d41aa18. This comment was automatically generated by Dr. CI and updates every 15 minutes.
**Motivation** Since FlexAttention and SDPA are both functions, not modules, we have tried numerous mechanisms to dispatch FlexAttention and SDPA to customized call paths so that we can inject the CP logic. Unfortunately, all of these approaches have their own composability issues with different techniques.
**Candidate Approaches** 1. Ask users to write a module to wrap FlexAttention/SDPA and use `parallelize_module` to install a forward hook. - Pros: This is similar to how we do TP. - Cons: 1) It is cumbersome for users, as they need to create a new module. 2) We need two places to parallelize the CP, as a context_parallel context manager is still required for splitting the inputs. 2. Provide a function wrapper. - Pros: Users just need to replace their FlexAttention/SDPA calls with the wrapper. - Cons: It is not the same API, though we can keep the API signature the same as the core API.
**Summary** This PR implements approach 2 but uses an nn.Module to mimic a function wrapper so that we can record CP state in the module instead of leaking it to the `_attention` module. ghstack-source-id: 0142083 Pull-Request-resolved: #162542
This PR requires pytorch/pytorch#162542
Similar to #1696, but this PR uses parallelize_module, similar to TP/SP. This PR also requires pytorch/pytorch#162542.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 3 jobs have failed; the first few are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 4, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu), trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 2, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu), trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 5, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…164500) `context_parallel()` being a context manager has annoyed users. Now that we plan to redesign CP's UX to explicitly ask users to 1) wrap the attention op into an `nn.Module` and 2) lift any buffers that are not sequence-agnostic to inputs, we can replace `context_parallel()` with two functional APIs: `_context_parallel_shard` and `_enable_context_parallel_dispatcher`. Pull Request resolved: #164500 Approved by: https://github.com/XilunWu ghstack dependencies: #162542
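For a sense of the intended shape only, here is a tiny sketch of the new flow; the two function names are taken from the message above, but their call signatures (and everything else in the snippet) are assumptions for illustration, not the code from this PR:

```python
import torch.nn.functional as F

def sharded_attention_step(q, k, v, cp_mesh):
    # Assumed, illustrative calls -- only the two names come from the commit message:
    #   q, k, v = _context_parallel_shard(mesh=cp_mesh, buffers=(q, k, v), seq_dims=(2, 2, 2))
    #   _enable_context_parallel_dispatcher()
    # Once the buffers are sharded and the dispatcher is enabled, the attention
    # call itself stays unchanged; no with-block is needed around each step.
    return F.scaled_dot_product_attention(q, k, v)
```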
```python
# TODO: Reverify atol and rtol after
# https://github.com/pytorch/pytorch/pull/163185 is landed. The accuracy
# issue happens on the gradients.
torch.use_deterministic_algorithms(True)
```
Please don't do things like this. It might subtly break other tests, since this globally changes the deterministic setting. Please enable determinism only in the tests that need it, and use a context manager so that the setting resets to its prior state.
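One way to follow that suggestion is a small helper that saves and restores the global flags around the test body; a minimal sketch (the helper name here is made up):

```python
import contextlib
import torch

@contextlib.contextmanager
def deterministic_algorithms(enabled: bool = True):
    # Remember the current global settings so other tests are unaffected.
    prev_mode = torch.are_deterministic_algorithms_enabled()
    prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
    torch.use_deterministic_algorithms(enabled)
    try:
        yield
    finally:
        # Restore whatever was configured before entering the block.
        torch.use_deterministic_algorithms(prev_mode, warn_only=prev_warn_only)
```

A test would then wrap only the assertions that need determinism in `with deterministic_algorithms(): ...`.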
The custom op fetches the required K and V. Currently, the forward pass is just an all-gather, and the backward pass is a reduce-scatter. While the logic is the same as all_gather_tensor_autograd, the custom op avoids the autograd warning that wait_tensor() is registered to autograd. As a next step, we should explore how to derive the required communication from the information in the BlockMask. Pull Request resolved: #163185 Approved by: https://github.com/XilunWu ghstack dependencies: #162542, #164500
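The op that landed isn't reproduced here, but the forward/backward pairing it describes can be sketched with a plain `autograd.Function`; this is an illustrative reimplementation that assumes the K/V shards are gathered along dim 0, not the code from #163185:

```python
import torch
import torch.distributed as dist

class AllGatherSeq(torch.autograd.Function):
    """Forward: all-gather local shards along dim 0; backward: reduce-scatter the gradient."""

    @staticmethod
    def forward(ctx, local, group):
        ctx.group = group
        world_size = dist.get_world_size(group)
        gathered = local.new_empty((world_size * local.shape[0], *local.shape[1:]))
        dist.all_gather_into_tensor(gathered, local.contiguous(), group=group)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        grad_local = grad_output.new_empty(
            (grad_output.shape[0] // world_size, *grad_output.shape[1:])
        )
        # Each rank receives the summed gradient slice that corresponds to its own shard.
        dist.reduce_scatter_tensor(grad_local, grad_output.contiguous(), group=ctx.group)
        return grad_local, None
```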
…orch#165039) No logic changes; this just polishes the docstrings and comments and removes unused variables. Pull Request resolved: pytorch#165039 Approved by: https://github.com/XilunWu ghstack dependencies: pytorch#162542, pytorch#164500, pytorch#163185
…162542) **Motivation** Since FlexAttention and SDPA are both functions, not modules, we have tried numerous mechanisms to dispatch FlexAttention and SDPA to customized call paths so that we can inject the CP logic. Unfortunately, all of these approaches have their own composability issues with different techniques.
**Candidate Approaches** 1. Ask users to write a module to wrap FlexAttention/SDPA and use `parallelize_module` to install a forward hook. - Pros: This is similar to how we do TP. - Cons: 1) It is cumbersome for users, as they need to create a new module. 2) We need two places to parallelize the CP, as a context_parallel context manager is still required for splitting the inputs. 2. Provide a function wrapper. - Pros: Users just need to replace their FlexAttention/SDPA calls with the wrapper. - Cons: It is not the same API, though we can keep the API signature the same as the core API.
**Summary** ~~This PR implements approach 2 and refactors the code in such a way that most of it can be reused by approach 1, which will be introduced in another PR.~~ We changed this PR to implement option 1, since people prefer it for its consistency with the existing parallelisms. This PR can also serve as the foundation for option 2, which was the earlier version of this PR. This PR also changes the `create_cp_block_mask` logic, since we now focus only on the ModuleWrapper approach, which doesn't require hacking the seq_len field in a BlockMask. This PR also removes the TorchFunctionMode dispatcher mode, as it doesn't work well with SAC. Pull Request resolved: pytorch#162542 Approved by: https://github.com/XilunWu
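To make the wrapper-plus-plan shape of approach 1 concrete, here is a rough sketch; the wrapper class and helper below are illustrative, and the CP plan object is taken as an argument because its exact constructor isn't shown in this thread:

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

class SDPAWrapper(nn.Module):
    """User-written wrapper so that parallelize_module() has a module to hook."""
    def forward(self, q, k, v, **kwargs):
        return F.scaled_dot_product_attention(q, k, v, **kwargs)

def parallelize_attention(attn: SDPAWrapper, cp_plan, world_size: int) -> nn.Module:
    # cp_plan stands in for the ContextParallel plan this PR introduces; it is
    # passed in here because its exact constructor isn't spelled out in this thread.
    cp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",))
    return parallelize_module(attn, cp_mesh, cp_plan)
```

As with TP/SP, the parallelization is expressed by handing `parallelize_module` a plan, rather than by entering a context manager at every attention call site.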
Stack from ghstack (oldest at bottom):
**Motivation**
Since FlexAttention and SDPA are both functions, not modules, we have tried numerous mechanisms to dispatch FlexAttention and SDPA to customized call paths so that we can inject the CP logic. Unfortunately, all of these approaches have their own composability issues with different techniques.
**Candidate Approaches**
1. Ask users to write a module to wrap FlexAttention/SDPA and use `parallelize_module` to install a forward hook.
   - Pros: This is similar to how we do TP.
   - Cons: 1) It is cumbersome for users, as they need to create a new module. 2) We need two places to parallelize the CP, as a context_parallel context manager is still required for splitting the inputs.
2. Provide a function wrapper.
   - Pros: Users just need to replace their FlexAttention/SDPA calls with the wrapper.
   - Cons: It is not the same API, though we can keep the API signature the same as the core API.
**Summary**
~~This PR implements approach 2 and refactors the code in such a way that most of it can be reused by approach 1, which will be introduced in another PR.~~ We changed this PR to implement option 1, since people prefer it for its consistency with the existing parallelisms. This PR can also serve as the foundation for option 2, which was the earlier version of this PR.
This PR also changes the `create_cp_block_mask` logic, since we now focus only on the ModuleWrapper approach, which doesn't require hacking the seq_len field in a BlockMask.
This PR also removes the TorchFunctionMode dispatcher mode, as it doesn't work well with SAC.
cc @H-Huang @awgu @wanchaol @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci