Skip to content

Conversation

@SkBlaz
Copy link
Contributor

@SkBlaz SkBlaz commented Aug 17, 2024

According to the documentation, decay is a number in [0,1] range, i.e.

Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999.

An inspection of swa_utils.py indicates there are no checks for invalid values of decay. Adding asserts as suggested in this PR ensures valid compute range (one way to enforce correct behavior, there are perhaps more suitable ones). Papers torch cites for reference idea/implementation also consider exclusively this range (e.g., https://arxiv.org/pdf/2310.04415).

Fixes #133772

@SkBlaz SkBlaz requested review from albanD and janeyx99 as code owners August 17, 2024 18:21
@pytorch-bot
Copy link

pytorch-bot bot commented Aug 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133773

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit bdf0be1 with merge base 73fde0d (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla
Copy link

linux-foundation-easycla bot commented Aug 17, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 20, 2024
Copy link
Contributor

@janeyx99 janeyx99 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're looking to improve the UX of these functions, let's go all the way and expand the current documentation to include what decay should be used for in get_ema_multi_avg_fn and get_ema_avg_fn. Then, instead of using an assert, we can throw a ValueError if a bad decay was given as input, similar to how we handle bad inputs for optimizer constructors. https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L55

@albanD albanD removed their request for review August 21, 2024 21:40
@SkBlaz
Copy link
Contributor Author

SkBlaz commented Aug 25, 2024

Since we're looking to improve the UX of these functions, let's go all the way and expand the current documentation to include what decay should be used for in get_ema_multi_avg_fn and get_ema_avg_fn. Then, instead of using an assert, we can throw a ValueError if a bad decay was given as input, similar to how we handle bad inputs for optimizer constructors. https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L55

Great idea, will do that.

@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 24, 2024
@janeyx99
Copy link
Contributor

@SkBlaz do you still plan on bringing this over the finish line?

@SkBlaz
Copy link
Contributor Author

SkBlaz commented Oct 24, 2024 via email

@SkBlaz SkBlaz requested a review from janeyx99 October 25, 2024 13:26
@SkBlaz
Copy link
Contributor Author

SkBlaz commented Oct 25, 2024

@janeyx99 please lmk if this is what you had in mind.

@janeyx99
Copy link
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 28, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 1, 3, macos-m1-stable), trunk / macos-py3-arm64 / test (default, 3, 3, macos-m1-stable)

Details for Dev Infra team Raised by workflow job

@SkBlaz
Copy link
Contributor Author

SkBlaz commented Oct 31, 2024

@janeyx99 hi, the failing builds don't seem related to this PR. Any idea?

@janeyx99
Copy link
Contributor

@pytorchbot merge -r

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

SkBlaz and others added 4 commits October 31, 2024 15:20
According to the documentation, decay is a number in [0,1] range. An inspection of `swa_utils.py`  indicates there are no checks for invalid values of `decay`. Adding asserts as suggested in this PR ensures valid compute range (one way to enforce correct behavior, there are perhaps more suitable ones). Papers `torch` cites for reference idea/implementation also consider exclusively this range (e.g., https://arxiv.org/pdf/2310.04415).
@pytorchmergebot
Copy link
Collaborator

Successfully rebased patch-1 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout patch-1 && git pull --rebase)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
According to the documentation, decay is a number in [0,1] range,[ i.e.](https://pytorch.org/docs/stable/optim.html)
```
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999.
```
An inspection of `swa_utils.py`  indicates there are no checks for invalid values of `decay`. Adding asserts as suggested in this PR ensures valid compute range (one way to enforce correct behavior, there are perhaps more suitable ones). Papers `torch` cites for reference idea/implementation also consider exclusively this range (e.g., https://arxiv.org/pdf/2310.04415).

Fixes pytorch#133772

Pull Request resolved: pytorch#133773
Approved by: https://github.com/janeyx99
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: optim Stale triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

SWA optimizer decay boundary condition not checked

5 participants