Conversation

@bohnstingl
Collaborator

@bohnstingl bohnstingl commented Nov 6, 2024

@bohnstingl bohnstingl requested a review from zou3519 as a code owner November 6, 2024 23:39
@pytorch-bot

pytorch-bot bot commented Nov 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139939

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 4 Pending

As of commit 393a3bd with merge base ada43ed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bohnstingl
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Nov 6, 2024
@bdhirsh bdhirsh added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 8, 2024
@zou3519 zou3519 requested review from ydwu4 and removed request for zou3519 November 11, 2024 17:59
@bohnstingl bohnstingl changed the title Improvements for associative_scan - Autograd separated [associative_scan] Autograd separated Nov 19, 2024
@bhack
Contributor

bhack commented Dec 14, 2024

Any review on this?

@WeihanLikk

Thanks for your implementation! I have a question regarding the shape check for xs:

assert x.shape == shape, "All xs tensors must have the same shape"

Why does it require the tensors to have exactly the same shape? In the JAX implementation, only the size along the scanned dimension (axis) is required to match:

num_elems = int(elems_flat[0].shape[axis])
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):

@bohnstingl
Collaborator Author

Hi @WeihanLikk

Thank you for looking into this. I was slow working on this over the holidays, but I will pick up steam again. You are right, I don't think this is strictly required; only the scanned dimension needs to be identical across all xs. I will take a look.
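For illustration, a minimal sketch of what such a relaxed check could look like (the helper name check_xs and its dim argument are hypothetical, not part of this PR); only the size along the scanned dimension is compared, mirroring the JAX check quoted above:

import torch

def check_xs(xs_flat, dim):
    # Hypothetical relaxed check: only the scanned dimension has to match.
    scan_len = xs_flat[0].shape[dim]
    for x in xs_flat[1:]:
        assert x.shape[dim] == scan_len, \
            "All xs tensors must have the same size along the scanned dimension"

# Shapes differ everywhere except along dim=0, so this passes.
check_xs([torch.randn(5, 3), torch.randn(5, 7, 2)], dim=0)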

y_T = f(y_{T-1}, x_T)
The gradients of y_T with respect to the vector x are computed as:
dy_T / dx = dy_T/dx_1 + dy_T/dx_2 + ... + dy_T/dx_T
Contributor


I'm not understanding this expression:

dy_T / dx = dy_T/dx_1 + dy_T/dx_2 + ... + dy_T/dx_T

Is there some typo in here?

Collaborator Author


Well, I guess this should be the gradient of the element y_T with respect to the vector of inputs, i.e., with respect to every element x_1, x_2, ..., x_T of the vector. This gives the individual elements [dy_T/dx_1, dy_T/dx_2, ..., dy_T/dx_T], and to get the final gradient these elements are summed.
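As a small sanity check of this reading (an illustrative sketch, not code from the PR), take y = cumsum(x): the last output y_T depends on every x_t, autograd returns the per-element partials [dy_T/dx_1, ..., dy_T/dx_T], and summing them gives the quantity written as dy_T/dx above:

import torch

x = torch.arange(1.0, 5.0, requires_grad=True)  # x_1, ..., x_T with T = 4
y = torch.cumsum(x, dim=0)                      # y_t = x_1 + ... + x_t
y[-1].backward()                                # differentiate only y_T
print(x.grad)        # tensor([1., 1., 1., 1.]) -> [dy_T/dx_1, ..., dy_T/dx_T]
print(x.grad.sum())  # tensor(4.) -> the summed quantity dy_T/dx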

@bohnstingl
Collaborator Author

bohnstingl commented Mar 9, 2025

@garrett361 I've implemented a first version of the backward approach we discussed offline. The algorithm itself works, but there are still some things to sort out, in particular lifted-argument and partial-gradient support.

EDIT: Of course, there is still room to clean up the code and to adjust the documentation.

@garrett361
Contributor

lifted argument and partial gradient support

What does lifted argument mean?

Partial gradient = when only some inputs require_grad?

@bohnstingl
Collaborator Author

What does lifted argument mean?

In some cases, variables and other properties used inside the combine_fn are lifted as additional inputs. These can be external variables or, for example, symbolic shapes of tensors. In the following case

import torch

device = "cpu"  # placeholder device for the example
H = torch.rand(2, device=device)
def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
    return x * H + y * 2  # H is captured from the enclosing scope

H would become a lifted variable. I know how to handle those, but I believe @ydwu4 is currently reworking the autograd architecture for higher-order operators, such as associative_scan, which should simplify this handling.
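For illustration only (a rough sketch of what lifting means here, not how this PR implements it): tracing such a function, e.g. with torch.fx, cannot treat H as an ordinary call argument, so it surfaces in the graph as an extra captured input that has to be carried along:

import torch
import torch.fx

H = torch.rand(2)

def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
    return x * H + y * 2   # same function as above

gm = torch.fx.symbolic_trace(fct_freevars1)
print(gm.graph)  # H typically shows up as a get_attr tensor constant, not as a call argument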

Partial gradient = when only some inputs require_grad?

Yes, that is correct. Same applies here as well. I know how to handle it, but I wanted to wait for the autograd rework.
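A minimal sketch of the partial-gradient case (plain elementwise ops standing in for the combine_fn, not the actual associative_scan backward): only one of the operands requires gradients, so the backward pass should only produce a gradient for that operand:

import torch

a = torch.rand(4, requires_grad=True)
b = torch.rand(4)              # does not require gradients
out = (a * b).sum()
out.backward()
print(a.grad is not None)      # True  -> gradient is computed for a
print(b.grad)                  # None  -> no gradient is computed for b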

@windsornguyen

Currently using associative scan for a big research project related to linear attention. (or perhaps I should say, logarithmic attention 😉)

Is there an expected timeline for autograd support to be available? Really excited about this PR!!

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Sep 8, 2025
@pytorch-bot pytorch-bot bot dismissed ydwu4’s stale review September 8, 2025 20:42

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@bohnstingl
Collaborator Author

@huydhn I'm not aware of this failure, but of course it could be related.

@huydhn
Contributor

huydhn commented Sep 8, 2025

@pytorchbot merge -f 'Incorrect revert, this change is not related to the trunk failure'

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: PR #139939 has not been reviewed yet

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@huydhn
Contributor

huydhn commented Sep 8, 2025

@pytorchbot merge -f 'Incorrect revert, this change is not related to the trunk failure'

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
This reverts commit 103f725.

Reverted pytorch#139939 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I am seeing a weird failure after this lands in trunk ([comment](pytorch#139939 (comment)))
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
Labels

ci-no-td (Do not run TD on this PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: dynamo, module: inductor, open source, Reverted, Stale, topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
