
Conversation

@qqaatw (Collaborator) commented May 27, 2024

Summary:

This PR adds fused Adam and AdamW implementations for the MPS backend.
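
A minimal usage sketch, assuming the MPS fused path is selected the same way as the existing fused optimizers (pass `fused=True` to the constructor and keep parameters on the `mps` device):

```python
import torch

model = torch.nn.Linear(128, 128).to("mps")
# fused=True opts into the fused kernel; parameters and gradients live on MPS.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

loss = model(torch.randn(32, 128, device="mps")).sum()
loss.backward()
opt.step()
opt.zero_grad()
```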

Benchmark on a MacBook Pro with an M1 Max chip and 64 GB of unified memory:
Fast math enabled:

[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100    
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        89    
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        90    
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        83    
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       12      |        94    
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       11      |        88    
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       12      |        90    
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |       100    
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       27      |       100    
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       23      |       100    
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       27      |       100    
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       23      |        98    
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       82      |       480    
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       72      |       450    
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       82      |       450    
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       73      |       420    
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       91      |       500    
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       83      |       400    
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |       94      |       500    
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       78      |       400    
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      170      |       500    
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      140      |       600    
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      170      |       600    
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      140      |       500    
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      250      |       890    
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      220      |       850    
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      250      |       830    
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      220      |       770    
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      270      |       870    
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      230      |       840    
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      270      |       810    
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      240      |       800    
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      400      |      1000    
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      360      |      2000    
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      430      |      2000    
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      360      |      1300    

Times are in milliseconds (ms).

Fast math disabled:

[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100    
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        84    
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        84    
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        79    
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       11      |        93    
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       10      |        90    
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       11      |        91    
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |        81    
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       34      |       100    
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       31      |       100    
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       34      |        95    
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       31      |       100    
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       94      |       500    
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       82      |       430    
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       92      |       430    
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       81      |       390    
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       98      |       500    
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       88      |       430    
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |      100      |       500    
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       88      |       400    
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      210      |       500    
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      190      |       610    
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      210      |       510    
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      190      |       500    
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      300      |       900    
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      260      |       850    
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      295      |       900    
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      260      |       800    
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      320      |       910    
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      280      |       900    
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      320      |       900    
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      300      |       900    
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      500      |      2000    
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      480      |      2000    
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      540      |      1500    
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      480      |      1200    

Times are in milliseconds (ms).
def profile_fused_adam():
    import torch
    from torch.optim import adam, adamw
    import torch.utils.benchmark as benchmark

    import itertools


    def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
        fn(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            foreach=False,
            capturable=False,
            fused=fused,
            amsgrad=amsgrad,
            beta1=0.9,
            beta2=0.99,
            lr=1e-3,
            weight_decay=.0,
            eps=1e-5,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        torch.mps.synchronize()

    device = "mps"
    
    results = []

    for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
        print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
        params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
        max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
        state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
        if adamWflag:
            fn = adamw.adamw
        else:
            fn = adam.adam

        for fused in [True, False]:

            t = benchmark.Timer(
                    stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                    label='Fused Adam',
                    sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                    globals=locals(),
                    description= f"Fused: {fused}",
                ).blocked_autorange(min_run_time=5)
            results.append(t)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()

cc @kulinseth @albanD @malfet

pytorch-bot bot commented May 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127242

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 3eedd99 with merge base 6e43897:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/mps (Run MPS tests, subset of trunk) and release notes: mps (Release notes category) labels May 27, 2024
@qqaatw qqaatw added the ciflow/trunk (Trigger trunk jobs on your pull request) label May 27, 2024
@qqaatw qqaatw force-pushed the multi_tensor_apply_mm branch from 8fbb72f to b6be4ea Compare May 30, 2024 11:52
@qqaatw qqaatw marked this pull request as ready for review May 30, 2024 12:57
@zou3519 zou3519 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label May 30, 2024
@janeyx99 (Contributor) left a comment

Letting @kulinseth review the kernels; I left a review on the testing portion.

Thanks for taking this on!

@kulinseth (Collaborator) left a comment

Overall looks good to me, let's just make the Metal functions precise.

{_differentiable_doc}
{_fused_doc}
.. Note::
A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
Collaborator

@janeyx99 do we usually document which backends support fused?
Taking it one step further, should the optimizer landing page in the docs have an overview table of which implementations we have and their stability levels, so that users can see at a glance what their options are?

Contributor

A centralized landing page is a good idea.

Collaborator Author

Let's do it in a follow-up, or do you prefer doing it in this PR?

Contributor

A follow-up is fine!

Collaborator Author

I'm working on a PR for the table. I'd like to discuss the criteria by which we regard an implementation as stable. If there is no existing standard, my initial proposal would be a 6-month period after the commit lands. By that criterion, Adam and AdamW on CUDA would be considered stable, while the rest would be beta. @janeyx99 what do you think?

id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();

float lr_lv = lr;
Collaborator

Is downcasting everything from double to float here ok?
This can lead to measurable differences in general, but I would argue it can be catastrophic for the eps below, which might become an actual 0 if we had a double-range epsilon.

cc @kulinseth how is this handled in other places?
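
A quick sketch of the underflow concern (hypothetical value, not taken from this PR): an epsilon that only fits in double's range collapses to zero once narrowed to float32.

```python
import torch

eps = 1e-300  # representable in float64, far below float32's smallest subnormal (~1e-45)
print(torch.tensor(eps, dtype=torch.float64).item())  # 1e-300
print(torch.tensor(eps, dtype=torch.float32).item())  # 0.0 -- underflows when downcast
```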

Collaborator Author

I believe we handle it the same way throughout the MPS codebase, since MPS does not support double.

For example:

MPSGraphTensor* outputTensor = [mpsGraph normalizationWithTensor:inputTensor
                                                       meanTensor:saveMeanTensor
                                                   varianceTensor:varTensor
                                                      gammaTensor:weightTensor
                                                       betaTensor:biasTensor
                                                          epsilon:(float)epsilon
                                                             name:nil];

Comment on lines +339 to +341
MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
Collaborator Author

Allow users to specify compile options.

@qqaatw qqaatw requested a review from albanD June 14, 2024 10:08
@qqaatw (Collaborator Author) commented Jun 17, 2024

@pytorchbot merge -r strict

pytorch-bot bot commented Jun 17, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot merge: error: argument -r/--rebase: invalid choice: 'strict' (choose from 'viable/strict', 'main')

usage: @pytorchbot merge [-f MESSAGE | -i] [-ic] [-r [{viable/strict,main}]]

Try @pytorchbot --help for more info.

@qqaatw (Collaborator Author) commented Jun 17, 2024

@pytorchbot merge -r viable/strict

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/127242/head returned non-zero exit code 1

Rebasing (1/15)
Auto-merging aten/src/ATen/native/native_functions.yaml
Auto-merging test/test_optim.py
CONFLICT (content): Merge conflict in test/test_optim.py
Auto-merging torch/optim/adam.py
Auto-merging torch/optim/adamw.py
Auto-merging torch/optim/optimizer.py
Auto-merging torch/testing/_internal/common_optimizers.py
Auto-merging torch/utils/_foreach_utils.py
error: could not apply 79496ce336b... [MPS] Fused Adam & AdamW
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config advice.mergeConflict false"
Could not apply 79496ce336b... [MPS] Fused Adam & AdamW

Raised by https://github.com/pytorch/pytorch/actions/runs/9556033956

@qqaatw (Collaborator Author) commented Jun 18, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: Raised by workflow job

Failing merge rule: Core Maintainers

@qqaatw (Collaborator Author) commented Jun 18, 2024

The macos-py3-arm64 / build job in trunk keeps getting stuck at the build step for an unknown reason.

@qqaatw (Collaborator Author) commented Jun 18, 2024

@pytorchbot merge -f "failure unrelated"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Comment on lines +3 to +4
#include <ATen/mps/MPSProfiler.h>
#include <Aten/native/mps/operations/FusedOptimizerOps.h>
@malfet (Contributor) commented Jun 25, 2024

Those should have been caps, quick fix is coming: #129474

pytorchmergebot added a commit that referenced this pull request Nov 21, 2024
pytorchmergebot pushed a commit that referenced this pull request Nov 22, 2024
For macOS 14+

Running the following script (adapted from the one mentioned in #127242):
```python
import torch
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark
import itertools

def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
    fn(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        foreach=False,
        capturable=False,
        fused=fused,
        amsgrad=amsgrad,
        beta1=0.9,
        beta2=0.99,
        lr=1e-3,
        weight_decay=.0,
        eps=1e-5,
        maximize=False,
        grad_scale=None,
        found_inf=None,
    )
    torch.mps.synchronize()

device, dtype = "mps", torch.bfloat16

results = []

for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]):
    print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
    params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
    max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else []
    state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)]
    fn = adamw.adamw if adamWflag else adam.adam

    for fused in [True, False]:

        t = benchmark.Timer(
                stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                label=f'Fused Adam on {device} using {dtype}',
                sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                globals=locals(),
                description= f"Fused: {fused}",
            ).blocked_autorange(min_run_time=5)
        results.append(t)

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```

Produces the following results on an M4 Pro running macOS 15:
```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
                                                                          |  Fused: True  |  Fused: False
1 threads: ----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10        |       283     |      2810
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10       |       277     |      2430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10       |       285     |      2400
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10      |       278     |      2250
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10       |       504     |      2700
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10      |       478     |      2600
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10      |       506     |      2500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10     |       482     |      2300
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10     |      2089     |      4190
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10    |      1940     |      3800
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10    |      2100     |      3770
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10   |      1950     |      3600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50        |       842     |     14000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50       |       835     |     11800
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50       |       845     |     11700
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50      |       855     |     11000
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50       |      1410     |     14000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50      |      1350     |     12000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50      |      1400     |     12000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50     |      1340     |     11000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50     |      9767     |     20400
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50    |      8991     |     18600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50    |      9803     |     18300
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50   |      9070     |     17600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100       |      1600     |     27000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100      |      1600     |     24100
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100      |      1600     |     23500
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100     |      1600     |     21800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100      |      2740     |     26000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100     |      2580     |     24000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100     |      2730     |     25000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100    |      2600     |     23000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100    |     19350     |     39000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100   |     17780     |     37300
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100   |     19400     |     37000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100  |     17900     |     35500
Times are in microseconds (us).
```
Pull Request resolved: #141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #141089, #141090, #141092, #141103
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024