[NAdam] Add capturable API and tests + fix differentiable #106615
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106615
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b2558de. This comment was automatically generated by Dr. CI and updates every 15 minutes.
mu_product_next = mu_product * mu_next
grad = grad * (-lr * (1. - mu) / (1. - mu_product))
exp_avg = grad * (-lr * (1. - mu_next) / (1. - mu_product_next))
exp_avg = exp_avg * (-lr * mu_next / (1. - mu_product_next))
This was actually just incorrect before. It is on my mind to add differentiable correctness tests as a part of the test revamp that I will get to one day.
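For readers following the fix above, here is a minimal scalar sketch of one NAdam step showing where the two scaled terms enter the parameter update. It is not the PR's code: the function name `nadam_step`, the default hyperparameters, and the `0.96` momentum schedule are assumptions made for illustration, and weight decay is left out.

```python
import math


def nadam_step(param, grad, exp_avg, exp_avg_sq, mu_product, step,
               lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8, momentum_decay=4e-3):
    """One scalar NAdam step (illustrative sketch, no weight decay)."""
    step += 1
    bias_correction2 = 1.0 - beta2 ** step

    # Momentum coefficients for this step and the next (assumed schedule).
    mu = beta1 * (1.0 - 0.5 * 0.96 ** (step * momentum_decay))
    mu_next = beta1 * (1.0 - 0.5 * 0.96 ** ((step + 1) * momentum_decay))
    mu_product = mu_product * mu
    mu_product_next = mu_product * mu_next

    # Running averages of the gradient and the squared gradient.
    exp_avg = beta1 * exp_avg + (1.0 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1.0 - beta2) * grad * grad
    denom = math.sqrt(exp_avg_sq / bias_correction2) + eps

    # The gradient term is scaled by (1 - mu); the momentum term scales
    # exp_avg (not grad) by mu_next.
    grad_term = grad * (-lr * (1.0 - mu) / (1.0 - mu_product))
    exp_avg_term = exp_avg * (-lr * mu_next / (1.0 - mu_product_next))
    param = param + grad_term / denom + exp_avg_term / denom
    return param, exp_avg, exp_avg_sq, mu_product, step


new_param, new_exp_avg, new_exp_avg_sq, new_mu_product, new_step = nadam_step(
    param=1.0, grad=0.5, exp_avg=0.0, exp_avg_sq=0.0, mu_product=1.0, step=0
)
```

With a positive gradient both terms are negative, so the parameter moves down, which is a quick sanity check on the signs.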
if len(state) == 0:
    state['step'] = torch.tensor(0.)
    state['mu_product'] = torch.tensor(1.)
    # note(crcrpar): [special device hosting for step]
Is note(crcrpar) a typo or something?
No, it's copied from here: https://github.com/pytorch/pytorch/blob/e35cb480f4df1cf440b8705c93546c1b15891a4b/torch/optim/adam.py#L88C1-L90C82
The same reasoning applies here
haha we can remove the note if you want your authorship to not be attached there @crcrpar
I was wondering if something like "see note - [special device hosting for step]" would be a little clearer. I just found it a bit surprising to see my ID as the comment author when the comment itself was "author"ed by someone else.
if optimizer_constructor.__name__ == "NAdam":
    # with capturable in NAdam, we have 3 extra intermediates for the
    # bias_correction, mus, and mu_nexts
    nintermediates = 5
Hmmm, the above comment has 2 extra leading to 3. And this one has 3 extra leading to 5?
NAdam needs 2 intermediates to start, so 3 extra brings it to 5.
Oh ok. The comments are just a bit confusing then :D
if capturable:
    step = step_t
else:
    step = _get_value(step_t)
Oh, should _get_value handle capturable?
One followup from using capturable is that the non-capturable path should not use _get_value or any other PT2-specific code (there's also _dispatch_sqrt). I was thinking of doing that in another PR, in case the assumption that the PT2 path should always go through the capturable impl does not actually hold. cc'ing @mlazos, who may have better context here!
denom = torch._foreach_sub(grouped_mu_products, 1.0)
torch._foreach_neg_(denom)
If only we had torch._foreach_rsub that matches the existing torch.rsub() :D
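The two-step pattern in the diff computes `1 - x` for every element as `-(x - 1)`. A minimal pure-Python sketch of that idea follows; `foreach_sub` and `foreach_neg_` are hypothetical list-based stand-ins written for this example, not the PyTorch functions, and the input values are made up.

```python
def foreach_sub(values, scalar):
    """Hypothetical stand-in: out-of-place subtract, returns a new list."""
    return [v - scalar for v in values]


def foreach_neg_(values):
    """Hypothetical stand-in: in-place negate, mutates the list it is given."""
    for i, v in enumerate(values):
        values[i] = -v


mu_products = [0.45, 0.2, 0.9]  # made-up values for illustration

# Step 1: x - 1 into a fresh list, so mu_products is left untouched.
denom = foreach_sub(mu_products, 1.0)
# Step 2: negate in place to get 1 - x without another allocation.
foreach_neg_(denom)
```

The out-of-place subtract followed by an in-place negate keeps the original list intact while allocating only one new list, which is what a reverse-subtract helper would do in a single call.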
albanD
left a comment
All sounds good!
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
@janeyx99 It looks like
Yes working on a fix rn |
Forward fixes #106615 by increasing tolerance in the test. The capturable implementation for foreach simply varies due to a different order of operations when updating params. I had also attempted to compare against fp64 but that introduced more disparity in the other optimizer configs. It is worth trying the fp64 comparison at a later point, but let's get the test passing first.
Pull Request resolved: #106887
Approved by: https://github.com/izaitsevfb
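As a small illustration of why a different order of operations can call for a looser tolerance, the sketch below adds the same three floats in two groupings. The values and the `rel_tol` are arbitrary choices for this example, not the tolerance used in the fix.

```python
import math

# Same three numbers, two groupings: floating-point addition is not associative.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

exactly_equal = left == right
# A tolerance-based comparison accepts the tiny rounding difference.
close_enough = math.isclose(left, right, rel_tol=1e-9)
```

Here `exactly_equal` is false while `close_enough` is true, which mirrors comparing two implementations that are mathematically equivalent but round differently.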
…6615)
This PR:
- adds a capturable API for NAdam similar to Adam(W)
- adds tests accordingly
- discovered and fixed bugs in the differentiable implementation (now tested through the capturable codepath).
Pull Request resolved: pytorch#106615
Approved by: https://github.com/albanD
This PR:
cc @mlazos -- once this lands you should be able to build on top of this implementation.
Stack from ghstack (oldest at bottom):