[MPS] Fused Adam & AdamW #127242
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127242
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure — as of commit 3eedd99 with merge base 6e43897, the following job has failed: …
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 8fbb72f to b6be4ea.
janeyx99 left a comment:
Letting @kulinseth review the kernels; I left a review on the testing portion.
Thanks for taking this on!
kulinseth left a comment:
Overall looks good to me; let's just make the Metal functions precise.
```rst
{_differentiable_doc}
{_fused_doc}
.. Note::
    A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
```
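To make the note concrete, here is a minimal usage sketch (not part of the PR; it assumes an MPS-capable machine, and the model and hyperparameters are arbitrary):

```python
import torch

# Hedged sketch: fused=True opts into the fused MPS implementation; per the
# note above, torch.float32 and torch.float16 parameters are supported.
model = torch.nn.Linear(16, 16, device="mps", dtype=torch.float32)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

loss = model(torch.randn(8, 16, device="mps")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```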
@janeyx99 do we usually document which backends support fused?
Taking it one step further, should we add an overview table to the optimizer landing page in the docs, listing which implementations we have and their stability levels, so that users can see at a glance what their options are?
A centralized landing page is a good idea.
Let's do it in a follow-up, or would you prefer doing it in this PR?
A follow-up is fine!
I'm working on a PR for the table. I'd like to discuss the criteria for regarding an implementation as stable. If there is no existing standard, my initial proposal would be a six-month period after the commit lands. Under that rule, Adam and AdamW on CUDA would be considered stable, while the rest would be beta. @janeyx99 What do you think?
```objc
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
// …
float lr_lv = lr;
```
Is downcasting everything from double to float here ok?
This can lead to measurable differences in general, and I would argue it can be catastrophic for the eps below, which might become an actual 0 if we were handed a double-range epsilon.
cc @kulinseth how is this handled in other places?
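To make the concern concrete, a minimal sketch (not from the PR) of how a double-range epsilon vanishes under the downcast:

```python
import torch

# A Python float is a double, so eps=1e-50 is perfectly representable on the
# way in; downcast to float32 it underflows to exactly 0.0, and the
# denominator guard in  denom = sqrt(v) + eps  silently disappears.
eps_double = 1e-50
eps_float32 = torch.tensor(eps_double, dtype=torch.float32)
print(eps_float32.item())  # 0.0
```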
I believe we handle it the same way throughout the MPS codebase, since MPS does not support double.
For example:
pytorch/aten/src/ATen/native/mps/operations/Normalization.mm
Lines 271 to 278 in 773ae81
```objc
MPSGraphTensor* outputTensor = [mpsGraph normalizationWithTensor:inputTensor
                                                       meanTensor:saveMeanTensor
                                                   varianceTensor:varTensor
                                                      gammaTensor:weightTensor
                                                       betaTensor:biasTensor
                                                          epsilon:(float)epsilon
                                                             name:nil];
```
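For context, a minimal sketch (not from the thread; assumes an MPS-capable machine) of the constraint behind these casts:

```python
import torch

# The MPS backend has no float64 support, so double-precision tensors cannot
# even be created on the device; scalar hyperparameters are therefore
# downcast to float32 before they reach a Metal kernel.
if torch.backends.mps.is_available():
    try:
        torch.zeros(1, dtype=torch.float64, device="mps")
    except Exception as err:  # catch broadly; the exact exception type may vary
        print(err)
```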
```objc
MetalShaderLibrary(const std::string& src)
    : shaderSource(src), nparams(0), compile_options(nullptr) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_)
    : shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_)
    : shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
```
This overload allows callers to specify compile options.
@pytorchbot merge -r strict
❌ 🤖 pytorchbot command failed: Try …
@pytorchbot merge -r viable/strict
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Rebase failed due to Command …
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. The first few are: … Dig deeper by viewing the failures on hud.
macos-py3-arm64 / build in trunk keeps getting stuck at the build step for an unknown reason.
@pytorchbot merge -f "failure unrelated"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use … Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```cpp
#include <ATen/mps/MPSProfiler.h>
#include <Aten/native/mps/operations/FusedOptimizerOps.h>
```
Those should have been caps; a quick fix is coming: #129474
…140725)" This reverts commit 9bc9d4c. Reverted #140725 on behalf of https://github.com/malfet due to It causes deadlocks when I try to run something benchmark from #127242 ([comment](#140725 (comment)))
For macOS 14+. Running the following script (adapted from one mentioned in #127242):

```python
import torch
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark
import itertools

def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
    fn(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        foreach=False,
        capturable=False,
        fused=fused,
        amsgrad=amsgrad,
        beta1=0.9,
        beta2=0.99,
        lr=1e-3,
        weight_decay=.0,
        eps=1e-5,
        maximize=False,
        grad_scale=None,
        found_inf=None,
    )
    torch.mps.synchronize()

device, dtype = "mps", torch.bfloat16
results = []

for num_tensors, numel, adamWflag, amsgrad in itertools.product(
        [10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]):
    print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
    params, grads, exp_avgs, exp_avg_sqs = [
        [torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)]
        for _ in range(4)
    ]
    max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else []
    state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)]
    fn = adamw.adamw if adamWflag else adam.adam

    for fused in [True, False]:
        t = benchmark.Timer(
            stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
            label=f'Fused Adam on {device} using {dtype}',
            sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
            globals=locals(),
            description=f"Fused: {fused}",
        ).blocked_autorange(min_run_time=5)
        results.append(t)

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```

produces the following results on an M4 Pro running macOS 15:

```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
                                                                  |  Fused: True  |  Fused: False
1 threads: ----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10       |    283   |    2810
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10      |    277   |    2430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10      |    285   |    2400
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10     |    278   |    2250
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10      |    504   |    2700
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10     |    478   |    2600
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10     |    506   |    2500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10    |    482   |    2300
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10    |   2089   |    4190
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10   |   1940   |    3800
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10   |   2100   |    3770
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10  |   1950   |    3600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50       |    842   |   14000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50      |    835   |   11800
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50      |    845   |   11700
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50     |    855   |   11000
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50      |   1410   |   14000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50     |   1350   |   12000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50     |   1400   |   12000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50    |   1340   |   11000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50    |   9767   |   20400
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50   |   8991   |   18600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50   |   9803   |   18300
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50  |   9070   |   17600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100      |   1600   |   27000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100     |   1600   |   24100
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100     |   1600   |   23500
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100    |   1600   |   21800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100     |   2740   |   26000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100    |   2580   |   24000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100    |   2730   |   25000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100   |   2600   |   23000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100   |  19350   |   39000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100  |  17780   |   37300
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100  |  19400   |   37000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 |  17900   |   35500

Times are in microseconds (us).
```

Pull Request resolved: #141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #141089, #141090, #141092, #141103
…ytorch#140725)" This reverts commit 9bc9d4c. Reverted pytorch#140725 on behalf of https://github.com/malfet due to It causes deadlocks when I try to run something benchmark from pytorch#127242 ([comment](pytorch#140725 (comment)))
For MacOS14+ Running following script (adapted from one mentioned in pytorch#127242 ) ```python import torch from torch.optim import adam, adamw import torch.utils.benchmark as benchmark import itertools def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused): fn( params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach=False, capturable=False, fused=fused, amsgrad=amsgrad, beta1=0.9, beta2=0.99, lr=1e-3, weight_decay=.0, eps=1e-5, maximize=False, grad_scale=None, found_inf=None, ) torch.mps.synchronize() device, dtype = "mps", torch.bfloat16 results = [] for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]): print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}") params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)] max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else [] state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)] fn = adamw.adamw if adamWflag else adam.adam for fused in [True, False]: t = benchmark.Timer( stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)', label=f'Fused Adam on {device} using {dtype}', sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}", globals=locals(), description= f"Fused: {fused}", ).blocked_autorange(min_run_time=5) results.append(t) compare = benchmark.Compare(results) compare.trim_significant_figures() compare.colorize(rowwise=True) compare.print() ``` Produces following results on M4Pro running MacOS 15 ``` [-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------] | Fused: True | Fused: False 1 threads: ---------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10 | 283 | 2810 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10 | 277 | 2430 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10 | 285 | 2400 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10 | 278 | 2250 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10 | 504 | 2700 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10 | 478 | 2600 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10 | 506 | 2500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10 | 482 | 2300 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10 | 2089 | 4190 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10 | 1940 | 3800 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10 | 2100 | 3770 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10 | 1950 | 3600 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50 | 842 | 14000 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50 | 835 | 11800 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50 | 845 | 11700 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50 | 855 | 11000 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50 | 1410 | 14000 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50 | 1350 | 12000 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50 | 1400 | 12000 amsgrad: 
False, adamWflag: False, numel: 65536, num_tensors: 50 | 1340 | 11000 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50 | 9767 | 20400 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50 | 8991 | 18600 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50 | 9803 | 18300 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50 | 9070 | 17600 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 27000 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 1600 | 24100 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 23500 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 1600 | 21800 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 2740 | 26000 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 2580 | 24000 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 2730 | 25000 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 2600 | 23000 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 19350 | 39000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 17780 | 37300 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 19400 | 37000 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 17900 | 35500 Times are in microseconds (us). ``` Pull Request resolved: pytorch#141104 Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007 ghstack dependencies: pytorch#141089, pytorch#141090, pytorch#141092, pytorch#141103
Summary:
This PR adds fused Adam and AdamW implementations for the MPS backend.
Benchmark on a MacBook Pro with an M1 Max chip and 64 GB of unified memory:
Fast math enabled:
Fast math disabled:
cc @kulinseth @albanD @malfet