
Conversation

@jon-chuang (Collaborator) commented Oct 5, 2023

Adam part of: #110506

TODO:

  • If this approach is validated as a good one, it can also be applied to all the other optimizers that convert complex tensors via list comprehensions
  • Unclear whether the strategy in this PR (move the conversion into _init_group) is better than the alternate (convert the list comprehensions into a for loop with a single branch); a rough sketch of the _init_group approach follows this list
    • I believe both are fine, but the _init_group approach seems more elegant, as it removes duplicated code between the single_tensor and multi_tensor cases. It's also faster.
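
A minimal standalone sketch of that _init_group-style conversion (illustrative only; the helper name _init_group_sketch and its body are not the PR's actual diff):

import torch

def _init_group_sketch(params, grads, exp_avgs):
    # Collect per-parameter state once, converting complex tensors to real
    # views up front so the per-step kernels never branch on is_complex.
    device_params, device_grads, device_exp_avgs = [], [], []
    for p, g, ea in zip(params, grads, exp_avgs):
        if torch.is_complex(p):
            p, g, ea = (torch.view_as_real(t) for t in (p, g, ea))
        device_params.append(p)
        device_grads.append(g)
        device_exp_avgs.append(ea)
    return device_params, device_grads, device_exp_avgs

params = [torch.randn(4, dtype=torch.cfloat), torch.randn(4)]
grads = [torch.randn_like(p) for p in params]
exp_avgs = [torch.zeros_like(p) for p in params]
print([p.shape for p in _init_group_sketch(params, grads, exp_avgs)[0]])

The alternate strategy keeps the conversion in the per-step functions instead, as the single-branch for loop shown under Alternate Strategy below.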

Results:

NUM_PARAMS=200, foreach=True

  • main: dynamo: 43s, inductor: 31s, total: 74s
  • this PR: dynamo: 3s, inductor: 31s, total: 34s (speedup: 40s, 2.17x)
  • alternate strategy: dynamo: 6s, inductor: 30s, total: 36s (speedup: 38s, 2.0x)

NUM_PARAMS=200, foreach=False

  • main: dynamo: 15s, inductor: 61s, total: 76s
  • this PR: dynamo: 16s, inductor: 61s, total: 77s

Results (with logs enabled):

NUM_PARAMS=200, foreach=True, TORCH_LOGS=+dynamo,schedule

  • main: dynamo: 48s, inductor: 32s, total: 80s
  • this PR: dynamo: 20s, inductor: 32s, total: 52s (speedup: 28s)

NUM_PARAMS=200, foreach=False, TORCH_LOGS=+dynamo,schedule

  • main: dynamo: 19s, inductor: 64s, total: 84s
  • this PR: dynamo: 20s, inductor: 64s, total: 84s

Benchmark script

import time
import torch
from torch.optim import Adam, SGD, Adagrad, NAdam  # swap optim_cls below to benchmark the others

optim_cls = Adam
NUM_PARAMS = 200
kwargs = {"lr": 0.01, "foreach": True}

torch._dynamo.reset()
# torch._inductor.metrics.reset()
input = torch.ones([10, 10], device="cuda:0")
model = torch.nn.Sequential(
    *[torch.nn.Linear(10, 10, device="cuda:0") for _ in range(NUM_PARAMS)]
)

# Run one forward/backward so every parameter has a .grad for the optimizer step.
model(input).sum().backward()
opt_compiled = optim_cls(model.parameters(), **kwargs)
compiled_step = torch.compile(opt_compiled.step)

with torch.set_grad_enabled(False):
    start_time = time.time()
    compiled_step()  # first call triggers dynamo tracing + inductor compilation
    print("compile opt took: %s seconds" % (time.time() - start_time))

print(torch._dynamo.utils.compile_times())

Alternate Strategy

Use for loops rather than list comprehensions

# Handle complex parameters in place by viewing each one (and its state) as a real tensor
for i in range(len(device_grads)):
    if torch.is_complex(device_params[i]):
        device_grads[i] = torch.view_as_real(device_grads[i])
        device_exp_avgs[i] = torch.view_as_real(device_exp_avgs[i])
        device_exp_avg_sqs[i] = torch.view_as_real(device_exp_avg_sqs[i])
        # max_exp_avg_sqs is only populated when amsgrad=True
        device_max_exp_avg_sqs[i] = torch.view_as_real(device_max_exp_avg_sqs[i])
        device_params[i] = torch.view_as_real(device_params[i])
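
For contrast, the list-comprehension pattern that both strategies replace looks roughly like the sketch below (a paraphrase for illustration, not the exact torch/optim source; the variable names only mirror the excerpt above):

import torch

device_params = [torch.randn(3, dtype=torch.cfloat), torch.randn(3)]
device_grads = [torch.randn_like(p) for p in device_params]
device_exp_avgs = [torch.zeros_like(p) for p in device_params]

# One comprehension per state list, each re-checking is_complex on every step.
device_grads = [
    torch.view_as_real(g) if torch.is_complex(p) else g
    for p, g in zip(device_params, device_grads)
]
device_exp_avgs = [
    torch.view_as_real(e) if torch.is_complex(p) else e
    for p, e in zip(device_params, device_exp_avgs)
]
device_params = [
    torch.view_as_real(p) if torch.is_complex(p) else p for p in device_params
]
print([p.shape for p in device_params])  # the complex parameter becomes a (..., 2) real view

Collapsing these per-list comprehensions into the single loop above, or hoisting the conversion into _init_group, is what cuts the dynamo time in the results reported above.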

CC: @janeyx99 @mlazos

pytorch-bot bot commented Oct 5, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110596

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 103fa72 with merge base cf1b494:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janeyx99 (Contributor) left a comment

:o wow

This looks pretty good to me. Could you run the numbers a few times to have more statistically significant evidence?

@mlazos are there concerns with moving more things into _init_group?

@jon-chuang (Collaborator, Author) commented:

> This looks pretty good to me. Could you run the numbers a few times to have more statistically significant evidence?

Yes, I've run it multiple times. The inductor times vary a fair amount between runs, but the dynamo times are quite stable and improve by the same amount each time.

@janeyx99 (Contributor) left a comment

Ah I've discovered a problem--people call our functional APIs (adam, single_tensor_adam, multi_tensor_adam) and expect complex tensors to work. In this PR, that would no longer work :/

It's great you identified these as slow bottlenecks though--we should work on getting dynamo to skip compiling this somehow.
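
For illustration only (not a commitment from this thread): one existing mechanism in that direction is torch._dynamo.disable, which makes dynamo skip tracing a function and run it eagerly. A minimal sketch, where the helper name is made up:

import torch

@torch._dynamo.disable
def _convert_complex(tensors):
    # Hypothetical helper: dynamo skips tracing this and runs it eagerly.
    return [torch.view_as_real(t) if torch.is_complex(t) else t for t in tensors]

print(_convert_complex([torch.randn(2, dtype=torch.cfloat), torch.randn(2)])[0].shape)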

@jon-chuang (Collaborator, Author) commented:

> Ah I've discovered a problem--people call our functional APIs (adam, single_tensor_adam, multi_tensor_adam) and expect complex tensors to work. In this PR, that would no longer work :/

Right, that makes sense. In that case, we can do the conversion in a loop using the alternate strategy, which is also faster.

@jon-chuang (Collaborator, Author) commented:

Superseded by: #110607
