While loop autograd #124573
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124573
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
self,
cond_fn: Callable,
body_fn: Callable,
body_grad_fn: Callable,
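For context, here is a hypothetical sketch of how a caller might use an op with this signature; the WhileLoopOp class and the lambdas are invented for illustration and are not the code in this PR:

import torch
from typing import Callable

class WhileLoopOp:
    # Hypothetical: the caller supplies an explicit gradient function
    # for the loop body, matching the signature under review.
    def __init__(self, cond_fn: Callable, body_fn: Callable, body_grad_fn: Callable):
        self.cond_fn = cond_fn
        self.body_fn = body_fn
        self.body_grad_fn = body_grad_fn  # user-provided backward of body_fn

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        while self.cond_fn(x):
            x = self.body_fn(x)
        return x

op = WhileLoopOp(
    cond_fn=lambda x: x.sum() < 10.0,         # loop while predicate holds
    body_fn=lambda x: x * 2.0,                # one forward iteration
    body_grad_fn=lambda grad, x: grad * 2.0,  # d(body_fn)/dx applied to grad
)
print(op(torch.ones(2)))  # tensor([8., 8.])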
Ideally, the backward of while_loop should be another while_loop operator with the same function signature: cond_fn, body_fn, operands.
To get a backward formula for the backward cond and backward body, consider the following example:
- forward: inp0 -> body_fn -> inp1 -> body_fn -> … -> inpN-1 -> body_fn -> output.
- the corresponding backward would be: (grad_output, output, inpN-1) -> backward_body_fn -> (grad_inpN-1, inpN-1, inpN-2) -> backward_body_fn -> … -> (grad_inp1, inp1, inp0) -> backward_body_fn -> grad_inp0.
Specifically, one possible design could be:
def backward_cond_fn(grad, fw_outputs):  # fw_outputs holds inp0 … inpN
    return fw_outputs.size() > 1

def backward_body_fn(grad: Tensors, fw_outputs: TensorList):
    output = fw_outputs.pop()        # fw_outputs now holds inp0 … inpN-1
    inpN_minus_1 = fw_outputs.back()
    # Do a re-computation based on inpN-1, since we didn't save the
    # intermediates of each iteration. We could extend this by saving the
    # important intermediates when necessary.
    grad_N_minus_1 = fw_bw(grad, output, inpN_minus_1)
    return grad_N_minus_1, fw_outputs  # fw_outputs holds inp0 … inpN-1

The backward is then:

while_loop(backward_cond_fn, backward_body_fn, (grad_out, fw_outputs))
This might require us to support a dynamic list with unspecialized length. cc @zou3519
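To make the design concrete, here is a minimal eager-mode sketch of the same recompute-one-step-in-reverse idea in plain PyTorch; while_loop_fwd, while_loop_bwd, and the toy cond_fn/body_fn are made-up names for illustration, not this PR's API:

import torch

def while_loop_fwd(cond_fn, body_fn, x):
    # Save every iteration's input (inp0 … inpN) so the backward pass
    # can replay iterations in reverse order.
    saved = [x]
    while cond_fn(x):
        x = body_fn(x)
        saved.append(x)
    return x, saved

def while_loop_bwd(body_fn, saved, grad_out):
    # Walk inputs from inpN-1 down to inp0; recompute one body step and
    # backprop grad through just that step (mirrors backward_body_fn).
    grad = grad_out
    for inp in reversed(saved[:-1]):
        inp = inp.detach().requires_grad_(True)
        out = body_fn(inp)  # re-computation of a single iteration
        (grad,) = torch.autograd.grad(out, inp, grad)
    return grad

cond_fn = lambda x: x.sum() < 10.0
body_fn = lambda x: x * 2.0

x0 = torch.ones(2, requires_grad=True)
out, saved = while_loop_fwd(cond_fn, body_fn, x0)
grad_manual = while_loop_bwd(body_fn, saved, torch.ones_like(out))

out.backward(torch.ones_like(out))           # autograd through the unrolled loop
assert torch.allclose(grad_manual, x0.grad)  # both give tensor([8., 8.])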
I tried to follow your suggestion in my latest commit.
I'll just leave this here for future developments on this:
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
@pytorchbot label no-stale
❌ 🤖 pytorchbot command failed: Try
@pytorchbot label no-stale
Force-pushed b7aefd8 to eaf9d31 (Compare)
This PR needs a release notes: label.
Force-pushed eaf9d31 to 01bb249 (Compare)
This PR is an attempt to add autograd support to the while_loop functionality of PyTorch. @ydwu4
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec
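For reference, the existing forward-only while_loop higher-order op is used roughly as follows; this is a sketch, and the exact import path and argument constraints vary across PyTorch versions:

import torch
from torch._higher_order_ops.while_loop import while_loop

# Carried state is a tuple of tensors; cond_fn must return a boolean
# scalar tensor, and body_fn must return state of the same structure.
def cond_fn(it, x):
    return it < 5

def body_fn(it, x):
    return it + 1, x * 2.0

it0 = torch.tensor(0)
x0 = torch.ones(2)
it_final, x_final = while_loop(cond_fn, body_fn, (it0, x0))
print(x_final)  # tensor([32., 32.])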