[associative_scan] Autograd separated #139939
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139939
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 Cancelled Job, 4 Pending as of commit 393a3bd with merge base ada43ed. CANCELLED JOB - The following job was cancelled. Please retry.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Any review on this?
Thanks for your implementation! I have a question regarding the shape check:

```python
assert x.shape == shape, "All xs tensors must have the same shape"
```

Why does it require the tensors to have exactly the same shape? In the JAX implementation, only the size along the scan axis is required to match:

```python
num_elems = int(elems_flat[0].shape[axis])
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
```
Hi @WeihanLikk, thank you for looking into this. I was slow in working on this over the holidays but will pick up steam again. You are right, I don't think that this is necessarily required; only the scanned dimension needs to be identical across all inputs.
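A relaxed check in the spirit of the JAX snippet above might look like the following. This is a hypothetical sketch, not code from this PR; the names `check_scan_dim`, `xs`, and `dim` are assumptions for illustration.

```python
def check_scan_dim(xs, dim):
    """Verify that all scan inputs agree along the scanned dimension only.

    Unlike requiring identical full shapes, this mirrors the JAX-style
    check: tensors may differ in the other dimensions as long as they
    have the same number of elements along `dim`.
    """
    num_elems = xs[0].shape[dim]
    for x in xs[1:]:
        if x.shape[dim] != num_elems:
            raise ValueError(
                f"All xs tensors must have the same size along dim {dim}: "
                f"expected {num_elems}, got {x.shape[dim]}"
            )
    return num_elems
```

The check only inspects `.shape`, so it works for any tensor-like object and leaves the remaining dimensions unconstrained.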
> y_T = f(y_{T-1}, x_T)
> The gradients of y_T with respect to the vector x are computed as:
> dy_T / dx = dy_T/dx_1 + dy_T/dx_2 + ... + dy_T/dx_T
I'm not understanding this expression:
dy_T / dx = dy_T/dx_1 + dy_T/dx_2 + ... + dy_T/dx_T
Is there some typo in here?
Well, I guess this should be the gradient of the element y_T with respect to the vector of inputs, i.e., with respect to every element x_1, x_2, ..., x_T of that vector. This gives the individual partials [dy_T/dx_1, dy_T/dx_2, ..., dy_T/dx_T], and to get the final gradient these elements are summed.
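As a numeric sanity check of the summation discussed above, here is a small example (hypothetical, not part of the PR) for a multiplicative scan y_t = y_{t-1} * x_t. It compares finite-difference estimates of the partials dy_T/dx_t against the analytic rule that each partial is the product of the other elements, then sums them into the total dy_T/dx.

```python
def scan_last(xs):
    """Forward pass: return y_T for the scan y_t = y_{t-1} * x_t with y_0 = 1."""
    y = 1.0
    for x in xs:
        y *= x
    return y

def numeric_partials(xs, eps=1e-6):
    """Finite-difference estimates of [dy_T/dx_1, ..., dy_T/dx_T]."""
    base = scan_last(xs)
    grads = []
    for t in range(len(xs)):
        bumped = list(xs)
        bumped[t] += eps
        grads.append((scan_last(bumped) - base) / eps)
    return grads

xs = [2.0, 3.0, 4.0]
# Analytically, dy_T/dx_t is the product of all the other elements:
analytic = [3.0 * 4.0, 2.0 * 4.0, 2.0 * 3.0]  # [12.0, 8.0, 6.0]
numeric = numeric_partials(xs)
# Summing the per-element partials gives the combined quantity from the docstring.
total = sum(numeric)
```

Because y_T is multilinear in the inputs, the finite-difference partials match the analytic ones to high precision here.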
@garrett361 I've implemented a first version of the backward approach we discussed offline. The algorithm itself works, but there are still some things to sort out, in particular lifted arguments and partial gradient support. EDIT: Of course there is still further room to clean up the code and to adjust the documentation.
What does lifted argument mean? Partial gradient = when only some inputs require gradients?
In some cases, variables and other properties from the surrounding scope are closed over by the combine function and get lifted into explicit arguments of the traced subgraph.
Yes, that is correct. The same applies here as well. I know how to handle it, but I wanted to wait for the autograd rework.
Currently using associative scan for a big research project related to linear attention (or perhaps I should say, logarithmic attention 😉). Is there an expected timeline for autograd support to be available? Really excited about this PR!!
This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.
@huydhn I wouldn't know of this failure, but it could of course be unrelated.
@pytorchbot merge -f 'Incorrect revert, this change is not related to the trunk failure'

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: PR #139939 has not been reviewed yet.
@pytorchbot merge -f 'Incorrect revert, this change is not related to the trunk failure'

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR implements the Autograd feature of the associative_scan. Pull Request resolved: pytorch#139939 Approved by: https://github.com/ydwu4
This reverts commit 103f725. Reverted pytorch#139939 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I am seeing a weird failure after this lands in trunk ([comment](pytorch#139939 (comment)))
This PR implements the Autograd feature of the associative_scan. Pull Request resolved: pytorch#139939 Approved by: https://github.com/huydhn
This PR implements the Autograd feature of the associative_scan.
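For readers new to the op: an associative scan computes all prefix combinations y_t = x_1 ⊕ x_2 ⊕ ... ⊕ x_t for an associative operator ⊕, and its divide-and-conquer evaluation has O(log T) parallel depth (hence the "logarithmic attention" quip above). A minimal pure-Python sketch of the forward semantics follows; this is a hypothetical reference for intuition, not this PR's implementation, and `associative_scan_ref` is an assumed name.

```python
def associative_scan_ref(combine_fn, xs):
    """Inclusive scan: returns [x1, x1⊕x2, ..., x1⊕...⊕xT].

    Adjacent pairs are combined and the half-length sequence is scanned
    recursively, so the recursion depth is O(log T). Correctness of this
    regrouping is exactly why `combine_fn` must be associative.
    """
    n = len(xs)
    if n == 1:
        return list(xs)
    # Combine adjacent pairs (a trailing odd element is handled below).
    pairs = [combine_fn(xs[i], xs[i + 1]) for i in range(0, n - 1, 2)]
    scanned_pairs = associative_scan_ref(combine_fn, pairs)
    # Interleave: odd positions come straight from the pair scan,
    # even positions need one extra combine with the raw element.
    out = [xs[0]]
    for i in range(1, n):
        if i % 2 == 1:
            out.append(scanned_pairs[i // 2])
        else:
            out.append(combine_fn(scanned_pairs[i // 2 - 1], xs[i]))
    return out
```

In the PR itself the scan operates on batched tensors along a chosen dimension with a traced `combine_fn`, but the prefix-combination semantics are the same.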
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela @yf225 @ColinPeppler @desertfire @ydwu4