Conversation

Contributor

@pietern pietern commented Jun 25, 2019

Reduction of gradients for unused parameters should happen as soon as
possible, because they can block reduction of gradients for used
parameters. This used to happen immediately when
`prepare_for_backward` was called and found parameters that didn't
contribute. As a result, if you had a model with unused parameters and
wanted to discard the model output (i.e. not call backward on some
loss), reduction of the gradients of those unused parameters would
already have been kicked off, and you'd see an error the next time you
called `forward`.

This commit changes that approach slightly: reduction of the gradients
of unused parameters is delayed until the first autograd hook is
called. This means you can now discard the model output whether or not
the model has unused parameters.

This is a prerequisite for making the `find_unused_parameters`
argument to DDP default to `True`.
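
For illustration, a minimal single-process sketch of the scenario described above (the toy model and the gloo/world_size=1 setup are invented for this example, not taken from the PR):

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)  # parameter that never participates

    def forward(self, x):
        return self.used(x)  # self.unused is not touched


def main():
    # Single-process "world" just to exercise the DDP code path on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = DDP(ToyModel(), find_unused_parameters=True)

    # Forward pass whose output is discarded: backward is never called.
    # Per the description above, DDP previously kicked off reduction of the
    # unused parameter's gradient here, so the next forward would error.
    _ = model(torch.randn(2, 4))

    # With this change, a later forward/backward works regardless of the
    # model having unused parameters.
    out = model(torch.randn(2, 4))
    out.sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```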

@pietern pietern added the oncall: distributed and triaged labels Jun 25, 2019
@pietern pietern requested review from apaszke and mrshenli as code owners June 25, 2019 19:10
@pietern pietern force-pushed the reducer-reduce-on-backward branch from 392b3c8 to 216c384 Compare June 25, 2019 19:39
Contributor

@mrshenli mrshenli left a comment

It breaks `test_no_used_parameters`, as no post hook will be called at all. It seems we need to add another special case for when all params are unused, or is it possible to use `queue_callback` to register it upfront?
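
A rough sketch of the corner case under discussion (plain module, no DDP wrapping, names invented for illustration): when the output does not depend on any parameter, backward runs but no parameter post hook ever fires.

```python
import torch
import torch.nn as nn


class NoUsedParams(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))  # never used in forward

    def forward(self, x):
        return x * 2  # parameter-free computation


model = NoUsedParams()
x = torch.randn(2, 4, requires_grad=True)
out = model(x)
out.sum().backward()  # autograd runs, but no hook on any parameter fires
assert model.weight.grad is None
```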

Contributor Author

pietern commented Jun 26, 2019

Regarding `test_no_used_parameters` -- I think we should nuke it. I added it to test the corner case of `find_unused_parameters=True`, but it implies that you can never discard the model output. This is not very practical if you only want to compute a grad through a DDP model instead of accumulating grads w.r.t. the model parameters. If we want to make `find_unused_parameters=True` the default (per some ad hoc discussions between @soumith, you, and me), then we need to wait for a signal that you want to compute and reduce gradients, i.e. a call to `backward`.
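
A rough sketch of the "compute a grad through the model" use case mentioned above (a plain `nn.Linear` is shown for brevity; in the discussion the module would be DDP-wrapped):

```python
import torch
import torch.nn as nn

# Differentiate the output w.r.t. the input instead of calling backward()
# and accumulating .grad on the model parameters.
model = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = model(x)
(grad_x,) = torch.autograd.grad(out.sum(), x)

# No parameter gradients were accumulated.
assert all(p.grad is None for p in model.parameters())
```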

Contributor Author

pietern commented Jun 26, 2019

@pytorchbot retest this please

Contributor Author

pietern commented Jun 26, 2019

@pytorchbot retest this please

Contributor Author

pietern commented Jun 27, 2019

After checking with CircleCI it is clear that the failure of `pytorch_linux_trusty_py3_6_gcc5_4_test` is a false negative. There is a relationship between the first try and subsequent retries kickstarted by @pytorchbot that makes this fail before even running the job.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@pietern is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@pietern has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@pietern merged this pull request in 7a40412.

@pietern pietern deleted the reducer-reduce-on-backward branch June 28, 2019 10:19
xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
…led (pytorch#22219)

Pull Request resolved: pytorch#22219

Differential Revision: D16028698

Pulled By: pietern

fbshipit-source-id: c6aec2cd39c4a77746495d9cb1c9fb9c5ac61983