Conversation
orttraining/orttraining/test/training_ops/cuda/optimizer/adam_test.cc
orttraining/orttraining/training_ops/cuda/optimizer/adam/adam.h
orttraining/orttraining/training_ops/cuda/optimizer/adam/adam.cc
onnxruntime/test/testdata/test_data_generator/adamw_test_data_generator.py
orttraining/orttraining/test/training_ops/cuda/optimizer/adamw_test.cc
orttraining/orttraining/training_ops/cuda/optimizer/adamw/adamw.cc
// > there is a minor difference compared with Apex's implementation,
// which uses double storing corrections before casting to float passing to kernels.
// > std::pow(float, int) return double since C++11, so we cast back to float.
alpha_correction = 1.f - static_cast<float>(std::pow(alpha, update_count));
Would we get better precision by casting to float after the subtraction?
std::pow(alpha, update_count) is a small number < 1, so a precision loss will affect it much more than it affects (1 - <a small number>); the same goes for the beta correction. Please correct me if I'm wrong.
Yeah, it might. But as the comment says, this is done to match what we see from other frameworks.
I originally used double to compute the corrections, then changed from double to float to avoid rare cases where the difference might cause our runs to diverge from torch/HF runs. Make sense?
nit: The comment doesn't make clear whether we are matching Apex's implementation or diverging from it, so if you have other things to fix, please make the comment a bit clearer too.
Description: MTA (Multiple Tensor Apply) Adam Optimizer Implementation.
Motivation and Context