[FlexAttention] add support for learnable biases in Inductor #137452
Stack from ghstack (oldest at bottom):
# Summary
The follow-up to #137526. In this PR, we update the lowerings for the flex_attention backwards kernel to generate fused backward gradient calculations for any captured buffers that require grads.

We do this with tl.atomic_add, scattering the gradients into a zeroed-out buffer for each captured buffer that requires grad. We added many test cases and, along the way, found and fixed some masking bugs.

There are likely some performance cliffs here, particularly for certain dtypes and on different GPUs. We plan to profile the current strategy in a follow-up. For now we are explicitly choosing reduced memory over increased performance.

By using atomics, we do not need to realize a full attention-scores matrix. However, this comes with two downsides: one, it is potentially slower in some cases, and two, the gradient calculation for any captured buffers is non-deterministic.
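For context, here is a minimal sketch of the end-to-end pattern this enables: a compiled flex_attention call whose score_mod captures a grad-requiring bias. Shapes and names are illustrative, not from this PR.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 256, 64
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", requires_grad=True) for _ in range(3)
)

# A captured buffer that requires grad; with this PR its gradient is
# computed by fused tl.atomic_add scatters in the compiled backward.
rel_bias = torch.randn(S, S, device="cuda", requires_grad=True)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + rel_bias[q_idx, kv_idx]

compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=score_mod)
out.sum().backward()
# rel_bias.grad is now populated (non-deterministically, via atomics).
```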
## Worked Example
Let's take the case where you read from one buffer that doesn't require grad (an offset tensor) and use it to index into a bias that does require grad.
ScoreMod:
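```python
bias = torch.randn(
    params.seq_length,
    device=self.device,
    dtype=params.dtype,
    requires_grad=True,
)
offset = torch.randint(
    0,
    params.seq_length,
    (params.seq_length,),
    device=self.device,
)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[offset[q_idx]]
```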
I am removing all but the new subgraph injected into the backwards:
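```python
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
grad_scores = (dsT)

# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
idx_b = off_z
idx_h = off_hq
idx_m = m
idx_n = n
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
tmp4 = (dsT).to(tl.float32)
tl.atomic_add(
    out_ptr1 + (tl.broadcast_to(tl.load(in_ptr16 + idx_m), tmp4.shape)),
    tmp4,
    scatter_mask,
    sem='relaxed',
)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```

In eager terms, this scatter computes something like the following (an illustrative reference using the tensors from the ScoreMod snippet, not the actual lowering):

```python
# grad_scores stands in for the (S_q, S_kv) gradient flowing into the
# scores; here it is random just to make the reference self-contained.
grad_scores = torch.randn(
    params.seq_length, params.seq_length, device=bias.device
)
# Accumulate in fp32 regardless of the forward dtype, mirroring the kernel.
bias_grad = torch.zeros(
    params.seq_length, device=bias.device, dtype=torch.float32
)
# bias[offset[q_idx]] is added to every score in row q_idx, so each row's
# score-grads sum and then scatter-accumulate into bias_grad[offset[q_idx]].
bias_grad.index_add_(0, offset, grad_scores.sum(dim=-1).to(torch.float32))
```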
## Key points

* We always accumulate into float32 grad buffers regardless of the dtype used in the forward. This is because we normally do all intra-kernel computation with fp32 accumulation, and we want the same behavior for the atomic additions.
* We are currently restricted to one scatter in the kernel (see the sketch below). I have some ideas on fx rewrites that would remove this restriction, but for now there is a clear error message with a workaround; lifting it is left as a follow-up.
* More extensive performance/memory profiling will be done in a follow-up.
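For example, a score_mod that captures two grad-requiring buffers would currently hit the one-scatter restriction (a hypothetical illustration, not from this PR's tests):

```python
bias_q = torch.randn(S, device="cuda", requires_grad=True)
bias_kv = torch.randn(S, device="cuda", requires_grad=True)

def score_mod(score, b, h, q_idx, kv_idx):
    # Two captured buffers requiring grad would need two atomic-add
    # scatters in the backward, which the current lowering rejects.
    return score + bias_q[q_idx] + bias_kv[kv_idx]
```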
### Toy E2E example
I have a toy E2E training example PR in the gym for now: meta-pytorch/attention-gym#84
I plan to update it to a realistic learnable bias before landing.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @Chillee @yanboliang @BoyuanFeng