Skip to content

Conversation

@bohnstingl
Copy link
Collaborator

@bohnstingl bohnstingl commented Apr 21, 2024

@pytorch-bot
Copy link

pytorch-bot bot commented Apr 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124573

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

self,
cond_fn: Callable,
body_fn: Callable,
body_grad_fn: Callable,
Copy link
Contributor

@ydwu4 ydwu4 Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, the backward of while_loop should be another while_loop operator with the same function signature cond_fn, body_fn, operands.

To get a backward formula for backward cond and backward body, consider the following example.

  • forward: Inp0 -> body_fn -> inp1 -> body_fn -> … inpN-1 -> body_fn -> output.
  • corresponding backward would be: grad_output, output, inpN-1-> backward_body_fn -> grad_inpN-1, inpN-1, , inpN-2 -> backward_body_fn -> …inp1, grad_inp1, inp0 -> backward_body_fn -> grad_inp0

Specifically, one possible design could be:

def backward_cond_fn(grad, fw_outputs(0N)):
  return fw_outputs.size() > 1
def backward_body_fn(grad: Tensors, fw_outputs: TensorList):
  Output = fw_outputs.pop() # (0…N-1)
  InpN-1 = fw_outputs.back()
  # do a re-computation based on inpN-1 since we didn't save the intermediates of each iteration.
  # we could extend this by saving the important intermediates when it's necessary.
  grad_N-1 = fw_bw(grad, output, inpN-1,)
  Return gradN-1, fw_outputs #(0…N-1)

The backward is then:

while_loop(backward_cond_fn, backward_body_fn, (grad_out, fw_outputs))

This might require us to support a dynamic list with unspecialized length cc @zou3519

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to follow your suggestion in my latest commit.

@bohnstingl
Copy link
Collaborator Author

I just leave this here for future developments on this:

  • I found some usecases where it is important to receive all intermediate outputs that are produced during the while loop operation. Do you think we could add a flag to indicate that?
  • Another usecase is that instead of just providing the initial input for the while loop with carried_inputs, it would be important to be able to provide "temporal inputs". This could be a only part of the carried_inputs as well. The rationale is that for RNN-like structure, one may want to feed a temporal input sequence, along with the initial states as inputs.

@github-actions
Copy link
Contributor

github-actions bot commented Sep 9, 2024

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Sep 9, 2024
@bohnstingl
Copy link
Collaborator Author

bohnstingl commented Sep 23, 2024

@pytorchbot label no-stale

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 23, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'no-stale' (choose from 'merge', 'revert', 'rebase', 'label', 'drci', 'cherry-pick', 'close')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick,close} ...

Try @pytorchbot --help for more info.

@bohnstingl
Copy link
Collaborator Author

@pytorchbot label no-stale

@pytorch-bot pytorch-bot bot added the no-stale label Sep 23, 2024
@bohnstingl bohnstingl force-pushed the while_loop_autograd branch from b7aefd8 to eaf9d31 Compare May 11, 2025 16:47
@github-actions
Copy link
Contributor

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@bohnstingl bohnstingl closed this May 11, 2025
@bohnstingl bohnstingl force-pushed the while_loop_autograd branch from eaf9d31 to 01bb249 Compare May 11, 2025 16:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants