Conversation

@janeyx99
Contributor

@janeyx99 janeyx99 commented Aug 4, 2023

This PR:

  • adds a capturable API for NAdam similar to Adam(W)
  • adds tests accordingly
  • fixes bugs discovered in the differentiable implementation (now tested through the capturable codepath).

cc @mlazos -- once this lands you should be able to build on top of this implementation.

Stack from ghstack (oldest at bottom):

@janeyx99 janeyx99 requested a review from albanD as a code owner August 4, 2023 15:18
@pytorch-bot

pytorch-bot bot commented Aug 4, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106615

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b2558de:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: optimizer Relating to optimizers, torch.optim label Aug 4, 2023
@janeyx99 janeyx99 marked this pull request as draft August 4, 2023 15:20
@janeyx99 janeyx99 changed the title from "[NAdam] Add capturable API and tests" to "[WIP][NAdam] Add capturable API and tests" Aug 4, 2023
janeyx99 added a commit that referenced this pull request Aug 5, 2023
ghstack-source-id: 98bc871
Pull Request resolved: #106615
@janeyx99 janeyx99 changed the title from "[WIP][NAdam] Add capturable API and tests" to "[NAdam] Add capturable API and tests + fix differentiable" Aug 5, 2023
@janeyx99 janeyx99 added the topic: new features and topic: bug fixes labels Aug 5, 2023
```diff
  mu_product_next = mu_product * mu_next
  grad = grad * (-lr * (1. - mu) / (1. - mu_product))
- exp_avg = grad * (-lr * (1. - mu_next) / (1. - mu_product_next))
+ exp_avg = exp_avg * (-lr * mu_next / (1. - mu_product_next))
```
Contributor Author
This was actually just incorrect before. It is on my mind to add differentiable correctness tests as part of the test revamp that I will get to one day.
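For context, the corrected scaling factors can be sketched standalone in plain Python. This is a hypothetical helper, not the real torch.optim internals; the names mirror the snippet under review, and it assumes the NAdam momentum-decay schedule has already produced mu, mu_next, and mu_product:

```python
import math

def nadam_update_coeffs(lr, mu, mu_next, mu_product):
    # Hypothetical sketch of the differentiable-NAdam scaling factors
    # after the fix in this PR. The bug was that the exp_avg term reused
    # grad and a (1 - mu_next) numerator instead of exp_avg and mu_next.
    mu_product_next = mu_product * mu_next
    grad_coeff = -lr * (1. - mu) / (1. - mu_product)
    exp_avg_coeff = -lr * mu_next / (1. - mu_product_next)
    return grad_coeff, exp_avg_coeff
```

With these coefficients, the parameter step combines the scaled grad and exp_avg terms (each divided by the sqrt-of-second-moment denominator), matching the two scaled terms in the snippet above.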

@janeyx99 janeyx99 marked this pull request as ready for review August 5, 2023 01:47
@janeyx99 janeyx99 requested a review from mlazos August 5, 2023 01:51
```python
if len(state) == 0:
    state['step'] = torch.tensor(0.)
    state['mu_product'] = torch.tensor(1.)
    # note(crcrpar): [special device hosting for step]
```
Collaborator

Is "note(crcrpar)" a typo or something?
Contributor Author
Haha, we can remove the note if you don't want your authorship attached there, @crcrpar

Collaborator
I was wondering if something like "see note - [special device hosting for step]" would be a bit clearer. I just felt it a bit surprising to see my ID as the comment's author when the comment itself was authored by someone else.

```python
if optimizer_constructor.__name__ == "NAdam":
    # with capturable in NAdam, we have 3 extra intermediates for the
    # bias_correction, mus, and mu_nexts
    nintermediates = 5
```
Collaborator

Hmmm, the above comment has 2 extra leading to 3. And this one has 3 extra leading to 5?

Contributor Author

NAdam needs 2 intermediates to start.

Collaborator

Ho ok. The comments are just a bit confusing then :D

```python
if capturable:
    step = step_t
else:
    step = _get_value(step_t)
```
Collaborator

Ho, should _get_value handle capturable?

Contributor Author

One follow-up from using capturable is that the non-capturable path should not use _get_value or any other PT2-related code that differs (there's also _dispatch_sqrt). I was thinking of doing that in another PR, in case the assumption that the PT2 path always goes through the capturable impl does not actually hold. cc @mlazos, who may have better context here!
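To illustrate why the two branches exist, here is a minimal sketch with a stand-in tensor class (hypothetical, not the real torch.optim helper): under capturable=True the step must stay a tensor so arithmetic on it can be captured in a CUDA graph, while otherwise it can be unwrapped to a cheap Python scalar, which is roughly what _get_value does.

```python
class FakeScalarTensor:
    # Stand-in for a 0-dim tensor, so this sketch runs without torch.
    def __init__(self, value):
        self.value = value

    def item(self):
        return self.value

def get_step(step_t, capturable):
    # Hypothetical version of the branch under review: keep the tensor
    # when capturable (graph-safe), unwrap to a Python number otherwise.
    if capturable:
        return step_t
    return step_t.item()

step_t = FakeScalarTensor(3.0)
assert get_step(step_t, capturable=True) is step_t   # still a "tensor"
assert get_step(step_t, capturable=False) == 3.0     # plain scalar
```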

Comment on lines +409 to +410
```python
denom = torch._foreach_sub(grouped_mu_products, 1.0)
torch._foreach_neg_(denom)
```
Collaborator

If only we had torch._foreach_rsub that matches the existing torch.rsub() :D
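The sub-then-neg pair works because scalar - x == -(x - scalar). A plain-Python sketch of the identity, with hypothetical list-based stand-ins for the foreach ops:

```python
def rsub_via_sub_neg(xs, scalar=1.0):
    # scalar - x == -(x - scalar): this is what the torch._foreach_sub +
    # torch._foreach_neg_ pair above computes across the whole tensor
    # list, one fused kernel launch per op instead of one per tensor.
    out = [x - scalar for x in xs]  # like _foreach_sub(xs, scalar)
    return [-y for y in out]        # like _foreach_neg_(out)
```

For mu_product values of 0.25 and 0.5 this yields 1 - mu_product, i.e. [0.75, 0.5].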

@janeyx99 janeyx99 added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 7, 2023
Collaborator

@albanD albanD left a comment

All sounds good!

@janeyx99
Contributor Author

janeyx99 commented Aug 7, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@huydhn
Contributor

huydhn commented Aug 9, 2023

@janeyx99 It looks like test_optim.py::TestOptim::test_multi_tensor_optimizers_with_varying_tensors is failing on multigpu with Tensor-likes are not close error after this change https://hud.pytorch.org/pytorch/pytorch/commit/0208574db95720a2569004114d323e922f46716d. Could you help take a look?

@janeyx99
Contributor Author

janeyx99 commented Aug 9, 2023

Yes, working on a fix right now.

pytorchmergebot pushed a commit that referenced this pull request Aug 10, 2023
Forward fixes #106615 by increasing tolerance in the test.

The capturable implementation for foreach simply varies due to a different order of operations when updating params. I had also attempted to compare against fp64 but that introduced more disparity in the other optimizer configs. It is worth trying the fp64 comparison at a later point, but let's get the test passing first.

Pull Request resolved: #106887
Approved by: https://github.com/izaitsevfb
@facebook-github-bot facebook-github-bot deleted the gh/janeyx99/78/head branch August 11, 2023 14:17
Cyril-Anto pushed a commit to Cyril-Anto/pytorch that referenced this pull request Aug 17, 2023
[NAdam] Add capturable API and tests + fix differentiable (pytorch#106615)

This PR:
- adds a capturable API for NAdam similar to Adam(W)
- adds tests accordingly
- discovered and fixed bugs in the differentiable implementation (now tested through the capturable codepath).

Pull Request resolved: pytorch#106615
Approved by: https://github.com/albanD
Labels

ciflow/trunk · Merged · release notes: optimizer · topic: bug fixes · topic: new features
