Skip to content

Conversation

@jon-chuang
Copy link
Collaborator

@jon-chuang jon-chuang commented Oct 5, 2023

Adam part of: #110506

TODO:

  • If this approach is validated as a good one, it an also be applied to all other optimizers which convert complex via list comprehensions

Results:

NUM_PARAMS=200, foreach=True

  • main: dynamo: 43s, inductor: 31s, total: 74s
  • this PR: dynamo: 3.5s, inductor: 30s, total: 34s (dynamo speedup: 12.3x, overall speedup: 34s, 2.1x)

NUM_PARAMS=1000, foreach=True, has_complex shortcut:

<class 'torch.optim.adam.Adam'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics:
Function                              Runtimes (s)
------------------------------------  -------------------------------
_compile.<locals>.compile_inner       0.0329, 50.0806, 0.0041
OutputGraph.call_user_compiler        44.9924

NUM_PARAMS=1000, foreach=True, without has_complex shortcut:

<class 'torch.optim.adam.Adam'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics:
Function                              Runtimes (s)
------------------------------------  -------------------------------
_compile.<locals>.compile_inner       0.0389, 58.6069, 0.0043
OutputGraph.call_user_compiler        44.1425

Discussion

  • has_complex shortcut provides additional 2x dynamo speedup. It is not necessary to achieve a significant overall speedup.

CC: @janeyx99 @mlazos

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 5, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110607

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit c63daf9 with merge base 11b3210 (image):

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jon-chuang jon-chuang changed the title perf(inductor): improve Adam compile times drastically by changing list comprehensions into for loops perf(inductor): improve Adam compile times by shortcutting for loops (has_complex and is_complex for all params) Oct 5, 2023
@jon-chuang jon-chuang changed the title perf(inductor): improve Adam compile times by shortcutting for loops (has_complex and is_complex for all params) perf(inductor): improve Adam compile times by shortcutting for loops (has_complex and single is_complex) Oct 5, 2023
Copy link
Contributor

@janeyx99 janeyx99 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd want to know the numbers detangling the has_complex heuristic + shortcutting the loops. Your current numbers combines both, yes?

Or is it like this PR and the numbers refer to only the shortcut change?

@jon-chuang
Copy link
Collaborator Author

jon-chuang commented Oct 5, 2023

I'd want to know the numbers detangling the has_complex heuristic + shortcutting the loops.

The disentangled numbers are in the PR description here:
image

As we can see, with has_complex flag, dynamo part of the compilation is > 2x faster

The middle number in the first line is the total time, the user_compiler stuff is inductor.

@albanD albanD removed their request for review October 5, 2023 18:23
@janeyx99
Copy link
Contributor

janeyx99 commented Oct 5, 2023

Ah I wasn't sure if the "has_complex" referred to the combination or just the has_complex flag, but I should also get better at reading 😅 A 2X dynamo speedup for num_params=1k seems worth doing, especially when I expect the majority of use cases to not use complex numbers.

Copy link
Contributor

@janeyx99 janeyx99 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to land after the shortcut lands

@jon-chuang jon-chuang changed the title perf(inductor): improve Adam compile times by shortcutting for loops (has_complex and single is_complex) perf(inductor): improve Adam compile times by shortcutting for loops (via has_complex) Oct 5, 2023
pytorchmergebot pushed a commit that referenced this pull request Oct 5, 2023
… against list comprehensions (e.g. complex conversion) (#110613)

Fully fixes: #110506

Depends: #110607
Potential merge conflicts:
- #110339
- #110345
- #110454

Related:
- #110606 (we can apply the improvements here orthogonally to the complex support)

### Results

Benchmark: 100 params.

Breakdowns (float32, dynamo):
```
Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s
```

Notes:
1. Adagrad is still slow due to `_get_value` list comprehension. Can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing capturable path
2. Adamax is not actually compiled (it is currently disabled).
3. Inductor compile time is quite variable. We calculate dynamo by subtracting `call_user_compiler` from `compile_inner` timing.

<details>

This PR:
```
Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s
```

Main
```
Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s
```

Seems that it doesn't help the complex case by much (but that's not the majority case). torch.float32 is generally positive, when it does not show drastic improvement / regresses, it is due to inductor variance (by manually inspecting the logs).

</details>

### Benchmark Script
```python
import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

NUM_PARAMS = 100
kwargs = { "lr": 0.01, "foreach": True }
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )

        model(input).sum().abs().backward()
        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        with torch.set_grad_enabled(False):
            start_time = time.time()
            compiled_step()
            summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)
```

CC: @janeyx99 @mlazos
Pull Request resolved: #110613
Approved by: https://github.com/janeyx99
@jon-chuang
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 6, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants