Conversation

Collaborator

@jon-chuang jon-chuang commented Oct 1, 2023

Originally disabled in #105438

However, compilation times look plenty speedy to me:

[screenshot: measured compilation times]

They may also have benefited from recent perf changes to the scheduler, such as the cycle-detection optimizations.
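For reference, a minimal sketch of the user-facing path this re-enables; the model size, device fallback, and backend choice below are illustrative and not taken from the PR:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 10, device=device)
opt = torch.optim.SGD(model.parameters(), lr=0.01)

# Populate .grad so the optimizer step has work to do.
model(torch.ones(10, 10, device=device)).sum().backward()

@torch.compile(backend="inductor")
def step():
    opt.step()

step()  # the first call pays the dynamo tracing + inductor codegen cost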

CC @mlazos as original code author

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler


pytorch-bot bot commented Oct 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110353

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit ebe7daf with merge base cf1b494:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jon-chuang jon-chuang changed the title feat(inductor): Add SGD back to inductor feat(inductor): Add SGD Optimizer back to Inductor Oct 1, 2023
@jon-chuang
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label (topic category) Oct 1, 2023
Collaborator Author

jon-chuang commented Oct 2, 2023

Hmm, dlrm fails with:

Details
loading model: 0it [00:00, ?it/s]
loading model: 0it [00:08, ?it/s]
cuda train dlrm                               
ERROR:common:backend='compiler_fn' raised:
NotImplementedError: Cannot access storage of SparseTensorImpl

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/benchmarks/dynamo/common.py", line 2315, in check_accuracy
    new_result = optimized_model_iter_fn(model_copy, example_inputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 401, in _fn
    return fn(*args, **kwargs)
  File "/var/lib/jenkins/workspace/benchmarks/dynamo/common.py", line 2079, in run_n_iterations
    self.model_iter_fn(mod, inputs, collect_outputs=False)
  File "/var/lib/jenkins/workspace/benchmarks/dynamo/torchbench.py", line 509, in forward_and_backward_pass
    cloned_inputs = clone_inputs(inputs)
  File "/var/lib/jenkins/workspace/benchmarks/dynamo/torchbench.py", line 510, in <resume in forward_and_backward_pass>
    self.optimizer_zero_grad(mod)
  File "/var/lib/jenkins/workspace/benchmarks/dynamo/torchbench.py", line 513, in <resume in forward_and_backward_pass>
    loss = self.compute_loss(pred)
  File "/var/lib/jenkins/workspace/benchmarks/dynamo/torchbench.py", line 514, in <resume in forward_and_backward_pass>
    self.grad_scaler.scale(loss).backward()
  File "/var/lib/jenkins/workspace/benchmarks/dynamo/torchbench.py", line 515, in <resume in forward_and_backward_pass>
    self.optimizer_step()
  File "/var/lib/jenkins/workspace/benchmarks/dynamo/common.py", line 2090, in optimizer_step
    self.optimizer.step()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 549, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 632, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 559, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 190, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 481, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2103, in run
    super().run()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in run
    and self.step()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 706, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2213, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 883, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 985, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 190, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1052, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1037, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3894, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 190, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3432, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2215, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line [2390](https://github.com/pytorch/pytorch/actions/runs/6371187608/job/17293390761?pr=110353#step:15:2391), in aot_wrapper_synthetic_base
    flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1813, in merge_view_inputs
    storage_ref = StorageWeakRef(inpt.untyped_storage())
torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
NotImplementedError: Cannot access storage of SparseTensorImpl

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

TorchDynamo optimized model failed to run because of following error
fail_to_run
TIMING: entire_frame_compile:1.55729 backend_compile:0.6531
STATS: call_* op count: 80 | FakeTensor.__torch_dispatch__:440 | FakeTensorMode.__torch_dispatch__:3155 | ProxyTorchDispatchMode.__torch_dispatch__:1160
Dynamo produced 2 graphs covering 80 ops with 7 graph breaks (5 unique)

It seems that accessing sparse tensor storage generally has poor support: #108667, #106837. The issue does seem to be cropping up more frequently lately (3 occurrences in the last 2 months, including this PR).
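A minimal sketch of the underlying limitation, separate from this PR (the exact error text may vary by build): per the traceback above, the sparse gradients reach AOTAutograd's merge_view_inputs, which takes a StorageWeakRef of every input, and sparse tensors cannot provide one.

import torch

# Hypothetical standalone repro of the storage limitation, not the dlrm benchmark itself.
indices = torch.tensor([[0, 1], [0, 1]])
values = torch.tensor([1.0, 2.0])
sparse = torch.sparse_coo_tensor(indices, values, (2, 2))

try:
    sparse.untyped_storage()  # sparse tensors have no dense storage to hand out
except NotImplementedError as err:
    print(err)  # expected: something like "Cannot access storage of SparseTensorImpl"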

@colesbury colesbury requested a review from mlazos October 3, 2023 23:01
@colesbury colesbury added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Oct 3, 2023
Contributor

mlazos commented Oct 4, 2023

@jon-chuang Typically, optimizers operate on hundreds of parameters in our benchmarks. I've been benchmarking them on models with ~1k parameters to provide good data (this number is more representative of what we have in our OSS benchmarks). It currently takes about 2 minutes to compile an optimizer with 1k parameters. This is on par with JAX, but there are a few improvements in flight which should bring this down.

Collaborator Author

jon-chuang commented Oct 4, 2023

typically optimizers operate on hundreds of parameters in our benchmarks

Would you be able to share the benchmark results? I would like to investigate what might be causing the slow compile times.

Even a simple repro script would be really great (I could also try to come up with a repro).

It currently takes about 2 minutes to compile an optimizer with 1k parameters.

I guess this is slower in the single-tensor (non-foreach) case?
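For context, the foreach flavor replaces the per-parameter update loop with list-level fused ops. A rough illustration, not the actual optimizer code:

import torch

params = [torch.randn(10, 10) for _ in range(4)]
grads = [torch.randn(10, 10) for _ in range(4)]
lr = 0.01

# Single-tensor flavor: one in-place update per parameter.
for p, g in zip(params, grads):
    p.add_(g, alpha=-lr)

# foreach flavor: a single list-level op covering all parameters at once.
torch._foreach_add_(params, grads, alpha=-lr)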

Collaborator Author

jon-chuang commented Oct 4, 2023

Hmm, I investigated locally with the following script:

import time

import torch
import torch._dynamo
from torch.optim import Adam, SGD  # swap optim_cls below to benchmark a different optimizer

def compile_opt(opt_compiled):
    torch._dynamo.eval_frame.TorchPatcher.patch()

    # Unwrap the optimizer's step() wrapper so we compile the underlying function.
    step_fn = opt_compiled.step.__wrapped__
    def fn():
        step_fn(opt_compiled)

    return torch.compile(fn, backend="inductor", fullgraph=True)

optim_cls = SGD
NUM_PARAMS = 1000
kwargs = {"lr": 0.01, "foreach": True}

torch._dynamo.reset()
# torch._inductor.metrics.reset()
model = torch.nn.Sequential(
    *[torch.nn.Linear(10, 10, device="cuda:0") for _ in range(NUM_PARAMS)]
)

input = torch.ones([10, 10], device="cuda:0")
model(input).sum().backward()  # populate .grad for every parameter
opt_compiled = optim_cls(model.parameters(), **kwargs)
compiled_step = compile_opt(opt_compiled)

with torch.set_grad_enabled(False):
    start_time = time.time()
    compiled_step()  # the first call includes dynamo tracing + inductor compilation
    print("compile opt took: %s seconds" % (time.time() - start_time))

Results

NUM_PARAMS = 200

  • Adam: 109 seconds, only 35 seconds was spent on inductor. It seems that dynamo symbolic_trace was the bottleneck here.
  • Adam foreach: 105 seconds, 35 seconds on inductor
  • SGD: 14 seconds, 7 seconds spent on inductor.
  • SGD foreach: 10 seconds, 4.5 seconds on inductor

NUM_PARAMS = 1000

  • SGD: 186 seconds, 24 seconds spent on inductor.

Scaling Hypothesis

Scaling hypothesis for SGD: the dynamo time goes from ~7 s at 200 params to ~162 s at 1000 params, i.e. ~24x for a 5x increase in parameter count (5^2 = 25). So we expect the dynamo cost to scale quadratically with num_params.

This is likely due to the expensive Python for loop executed when tracing "step".

This is the loop in question:

for p in group['params']:
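To make the scaling concrete, here is a self-contained toy with the same shape as the single-tensor SGD update (a paraphrase for illustration, not the actual code in torch/optim/sgd.py): the traced graph emits a small bundle of ops per parameter, so graph size grows linearly with num_params, and any per-item overhead inside dynamo on top of that shows up superlinearly in wall-clock time.

import torch

def toy_single_tensor_sgd(params, grads, lr=0.01, weight_decay=0.0):
    # One small group of ops per parameter, mimicking the single-tensor SGD path.
    for i, param in enumerate(params):
        d_p = grads[i]
        if weight_decay != 0.0:
            d_p = d_p.add(param, alpha=weight_decay)
        param.add_(d_p, alpha=-lr)

params = [torch.randn(10, 10) for _ in range(8)]
grads = [torch.randn(10, 10) for _ in range(8)]

compiled = torch.compile(toy_single_tensor_sgd, fullgraph=True)
compiled(params, grads)  # trace length (and tracing time) scales with len(params)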

Conclusion

So in terms of compile times, Adam is actually much slower than SGD. The bottleneck appears to be dynamo rather than inductor.

@jon-chuang
Collaborator Author

@mlazos, I am unable to reproduce your results.

Of the optimizers, SGD compiles the fastest.

I applied more optimizations in this branch to make it faster.

foreach=True, N_PARAMS=1000

<class 'torch.optim.sgd.SGD'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics:
Function                              Runtimes (s)
------------------------------------  -------------------------------
_compile.<locals>.compile_inner       0.0401, 28.1882, 0.0048
OutputGraph.call_user_compiler        21.5787

foreach=False, N_PARAMS=1000

<class 'torch.optim.sgd.SGD'> {'lr': 0.01, 'foreach': False} torch.float32 TorchDynamo compilation metrics:
Function                              Runtimes (s)
------------------------------------  -------------------------------
_compile.<locals>.compile_inner       0.0415, 76.9907, 0.0057
OutputGraph.call_user_compiler        23.8749

Which case are you seeing in the benchmarks? (foreach=True or False?)
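For anyone reproducing these numbers: the tables above come from dynamo's built-in compile-time accounting, which (at least on recent builds) can be printed with something like the following after running the compiled step once:

import torch
import torch._dynamo.utils as dynamo_utils

@torch.compile
def f(x):
    return x.sin() + 1

f(torch.randn(8))  # populate dynamo's compile-time counters

# Prints a per-function table of the same shape as the ones quoted above
# ("TorchDynamo compilation metrics: ...").
print(dynamo_utils.compile_times(repr="str"))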

@jon-chuang jon-chuang changed the title feat(inductor): Add SGD Optimizer back to Inductor feat(inductor): Improve compilation speed and add SGD Optimizer back to Inductor Oct 5, 2023
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values():
-    device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
+    device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)
Contributor


If you move this out to a separate PR, I can approve that as the discussion to re-enable continues here.

Separately, I would love to get optimizer compile time back into the PT2 benchmarks, so I am in support of getting SGD down at the very least. If general optimizers take too long everywhere, the other option I was thinking about is to try all compilable optimizers for just one model.

Collaborator Author


Standalone: #110648

disabled_multi_tensor_opt_modules = {
    adamax,
    radam,  # data-dependent control flow
    sgd,  # for now, until we can speed up compilation (this affects the benchmarks)
Contributor


It would be good to understand the runtime impact on the benchmarks of this line of change. To my knowledge, a few hours for all the benchmarks is acceptable.

@albanD albanD removed their request for review October 9, 2023 17:38
@janeyx99
Contributor

I manually kicked off some benchmarks to see the difference it'd make in #111341 and got these results on the torchinductor HUD page:
[screenshot: torchinductor HUD benchmark results]

I'm not entirely well-versed in interpreting these results yet, AND the manual run only covered a subset of configs, but it looks like the speedups are not worth it for SGD given the increase in compile time. This is not news: vanilla SGD is just one instruction, so fusing doesn't do much.

Thus, if the e2e benchmarks on this page are meant to show the "most performant" feature combination, it may not be worth adding optimizer compilation to the benchmarks, at least not for SGD. We can revisit adding a compiled optimizer to the benchmarks later.

In the meantime, I do think something that is quite useful to do is pick a subset of models (we could just start with resnet50) and then run all our supported flavors of compiled optimizers. We would want to measure E2E compilation + run time, with breakdowns of each (e.g., dynamo, inductor for compile time). I see at least two ways this could be done:

  • add a separate section in test_perf_for_dashboard()
  • add a way to swap out the optimizer and rerun a model (not sure if this will allow us to reuse more of the same infra; a rough sketch of this option is below)
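A rough, hypothetical sketch of the second option (the model choice, optimizer set, and timing approach are illustrative; this is not the benchmark harness's actual API and assumes torchvision is installed):

import time

import torch
import torch._dynamo
import torchvision

OPTIMIZERS = [torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW]
device = "cuda" if torch.cuda.is_available() else "cpu"

for optim_cls in OPTIMIZERS:
    torch._dynamo.reset()  # start each optimizer from a clean compile cache

    model = torchvision.models.resnet50().to(device)
    model(torch.randn(2, 3, 224, 224, device=device)).sum().backward()
    opt = optim_cls(model.parameters(), lr=0.01)

    @torch.compile(backend="inductor")
    def step():
        opt.step()

    t0 = time.time()
    step()  # the first call pays the compile cost
    compile_s = time.time() - t0

    t0 = time.time()
    step()  # subsequent calls approximate steady-state run time
    run_s = time.time() - t0

    print(f"{optim_cls.__name__}: compile ~{compile_s:.1f}s, step ~{run_s * 1e3:.1f}ms")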

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Dec 18, 2023
@github-actions github-actions bot closed this Jan 17, 2024