
MTA AdamWOptimizer #11506

Merged
pengwa merged 34 commits into master from pengwa/adamv2 on May 27, 2022

Conversation

@pengwa
Contributor

@pengwa pengwa commented May 12, 2022

Description: MTA (Multi-Tensor Apply) AdamW Optimizer Implementation.

  • This is added intentionally to support internal customers; it can also be used for general training.
  • The implementation leverages `Seq` to manage groups of parameters/gradients/momentums, instead of using fixed-length variadic inputs (as we did previously for Lamb).
  • Multi-tensor apply is used.
  • AdamW equivalents of Torch AdamW and HF AdamW are provided, to make it easier to migrate models trained with external libraries.



@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label May 12, 2022
@pengwa pengwa dismissed stale reviews from baijumeswani and askhade via f8b86d2 May 24, 2022 03:56
baijumeswani previously approved these changes May 25, 2022
// > Note: this differs slightly from Apex's implementation, which stores the
// > corrections in double and casts them to float only when passing them to the kernels.
// > std::pow(float, int) returns double since C++11, so we cast the result back to float.
alpha_correction = 1.f - static_cast<float>(std::pow(alpha, update_count));
Contributor

@ashbhandare ashbhandare May 25, 2022

Will we get better precision by casting to float after the subtraction?
std::pow(alpha, update_count) is a small number < 1, so a precision loss will affect it much more than it affects (1 - <a small number>); same for the beta correction. Please correct me if I'm wrong.

Contributor Author

Yeah, it might be. But as the comment says, this is meant to match what we see in other frameworks.
I originally used double to calculate these, then changed from double to float to avoid rare cases where it might introduce differences between our runs and Torch/HF runs. Does that make sense?

Contributor

nit: The comment doesn't make clear whether we are matching Apex or diverging from it, so if you have other things to fix, please make the comment a bit clearer too.

@pengwa pengwa changed the title from MTA Adam Optimizer to MTA AdamWOptimizer May 26, 2022
Contributor

@ashbhandare ashbhandare left a comment

LGTM

@pengwa pengwa merged commit 44f7b1b into master May 27, 2022
@pengwa pengwa deleted the pengwa/adamv2 branch May 27, 2022 11:52

5 participants