
Conversation

@drisspg
Contributor

@drisspg drisspg commented Oct 7, 2024

Stack from ghstack (oldest at bottom):

Summary

The follow-up PR to #137526. In this PR we update the lowerings for the flex_attention backwards kernel to generate fused backward gradient calculations for any captured buffers that require grads.

We do this with tl.atomic_add, scattering the correct gradients into a zeroed-out buffer for each captured buffer that requires grads. Added many test cases, and along the way found and fixed some masking bugs.

There are likely some performance cliffs here, specifically with certain dtypes and on different GPUs. We plan to profile the current strategy in a follow-up. We are explicitly choosing reduced memory over increased performance right now.

By using atomics, we do not need to realize a full attention scores matrix. However, this comes with two downsides: one, it is potentially slower in some cases, and two, the gradient calculation for any captured buffers is non-deterministic.
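To make the memory tradeoff concrete, here is a rough back-of-the-envelope sketch; the shapes are hypothetical and chosen only for illustration, not taken from this PR:

```Python
# Illustrative only: compare materializing the full scores gradient
# (what atomics let us avoid) against the captured-buffer grad we keep.
B, H, Q_LEN, KV_LEN, SEQ = 8, 16, 4096, 4096, 4096  # made-up shapes
BYTES_FP32 = 4

scores_grad_bytes = B * H * Q_LEN * KV_LEN * BYTES_FP32  # full (B, H, Q, KV) buffer
bias_grad_bytes = SEQ * BYTES_FP32                        # one fp32 slot per bias element

print(f"full scores grad: {scores_grad_bytes / 2**30:.1f} GiB")  # 8.0 GiB
print(f"bias grad buffer: {bias_grad_bytes / 2**10:.1f} KiB")    # 16.0 KiB
```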

Worked Example

Let's do the case where you read from one bias that doesn't require grad and use it to index into another bias that does.

ScoreMod:

bias = torch.randn(
    params.seq_length,
    device=self.device,
    dtype=params.dtype,
    requires_grad=True,
)

offset = torch.randint(
    0,
    params.seq_length,
    (params.seq_length,),
    device=self.device,
)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[offset[q_idx]]
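
For context, a minimal sketch of how this score_mod might be driven end to end; the shapes and the torch.compile invocation here are assumptions for illustration, not code from this PR:

```Python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical shapes; any (B, H, SEQ, D) layout works the same way.
B, H, SEQ, D = 2, 4, 256, 64
device, dtype = "cuda", torch.float16

q, k, v = (
    torch.randn(B, H, SEQ, D, device=device, dtype=dtype, requires_grad=True)
    for _ in range(3)
)
bias = torch.randn(SEQ, device=device, dtype=dtype, requires_grad=True)
offset = torch.randint(0, SEQ, (SEQ,), device=device)

def score_mod(score, b, h, q_idx, kv_idx):
    # Index through a non-differentiable offset buffer into a learnable bias.
    return score + bias[offset[q_idx]]

flex = torch.compile(flex_attention)
out = flex(q, k, v, score_mod=score_mod)
out.sum().backward()
# With this PR, bias.grad is produced by the fused backward kernel.
print(bias.grad.shape)  # torch.Size([256])
```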
    

I am showing only the new subgraph injected into the backward kernel; everything else has been removed:

    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
    grad_scores = (dsT)


    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    idx_b = off_z
    idx_h = off_hq
    idx_m = m
    idx_n = n
    scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
    tmp4 = (dsT).to(tl.float32)
    tl.atomic_add(out_ptr1 + (tl.broadcast_to(tl.load(in_ptr16 + idx_m), tmp4.shape)), tmp4, scatter_mask, sem='relaxed')

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
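
For intuition, the scatter above computes (for this worked example) the same thing as the following eager reference. Here grad_scores stands for the full (B, H, Q_LEN, KV_LEN) gradient of the modified scores, which the real kernel never materializes; this is a hand-written sketch, not kernel code:

```Python
import torch

def reference_bias_grad(grad_scores, offset, bias_len):
    # score[b, h, q, kv] reads bias[offset[q]], so its gradient contributes
    # to bias.grad[offset[q]]. Reduce over batch, head and kv, then
    # scatter-add along the q dimension (the tl.atomic_add in the kernel).
    per_q = grad_scores.sum(dim=(0, 1, 3))  # (Q_LEN,)
    grad_bias = torch.zeros(
        bias_len, dtype=torch.float32, device=grad_scores.device
    )  # fp32 accumulation buffer, matching the kernel's choice
    grad_bias.index_add_(0, offset, per_q.to(torch.float32))
    return grad_bias
```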

Key points

  • We always accumulate into float32 grad buffers regardless of the forward dtype. We normally do all intra-kernel computation with fp32 accumulation, and we want the same behavior for the atomic additions (see the sketch after this list).
  • We are currently restricted to one scatter per kernel. I have some ideas on fx rewrites that would remove this restriction, but for now there is a clear error message with a workaround; lifting the restriction is left as a follow-up.
  • Will do more extensive performance/memory profiling in a follow-up.
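
A tiny illustration (not from the PR) of why the accumulation buffer stays fp32: repeatedly adding small contributions in a low-precision dtype stalls the sum, which is exactly the situation when many atomic adds land on one bias element:

```Python
import torch

# Many small gradient contributions landing on a single buffer element.
vals = torch.full((10_000,), 1e-3)

fp32_acc = vals.sum()                        # fp32 accumulation
bf16_acc = torch.zeros((), dtype=torch.bfloat16)
for v in vals.to(torch.bfloat16):            # naive low-precision accumulation
    bf16_acc = bf16_acc + v

print(fp32_acc.item())  # ~10.0
print(bf16_acc.item())  # much smaller: updates vanish once 1e-3 drops
                        # below half a bf16 ulp of the running sum
```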

Toy E2E example

I have a toy E2E training example PR in the gym for now: meta-pytorch/attention-gym#84
I plan to update it to a realistic learnable bias before landing.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @Chillee @yanboliang @BoyuanFeng

@pytorch-bot

pytorch-bot bot commented Oct 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137452

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5508c30 with merge base 259a00b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg marked this pull request as draft October 7, 2024 23:10
@drisspg drisspg removed the request for review from zou3519 October 7, 2024 23:39
drisspg added a commit that referenced this pull request Oct 8, 2024
drisspg added a commit that referenced this pull request Oct 8, 2024
drisspg added a commit that referenced this pull request Oct 17, 2024
drisspg added a commit that referenced this pull request Oct 29, 2024
drisspg added a commit that referenced this pull request Nov 7, 2024
@Chillee
Collaborator

Chillee commented Nov 23, 2024

!!

@drisspg drisspg added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 24, 2024
@drisspg drisspg added rocm This tag is for PRs from ROCm team ciflow/rocm Trigger "default" config CI on ROCm and removed rocm This tag is for PRs from ROCm team labels Nov 24, 2024
drisspg added a commit that referenced this pull request Nov 25, 2024
drisspg added a commit that referenced this pull request Nov 25, 2024
@drisspg
Contributor Author

drisspg commented Nov 25, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#137452
Approved by: https://github.com/Chillee
@github-actions github-actions bot deleted the gh/drisspg/61/head branch December 26, 2024 02:04

Labels

ciflow/inductor · ciflow/rocm · ciflow/trunk · Merged · module: cpu · module: flex attention · module: inductor · topic: not user facing
