feat(optimizer): Adagrad will use device when capturable - True always when compiling with dynamo
#110339
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110339
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 359aed7 with merge base 428cbd7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Adagrad state param does not follow inputs' device
Force-pushed from 52123c3 to b25e2f2 (Compare)
Changed the title: "Adagrad state param does not follow inputs' device" → "Adagrad will use device only when capturable - True always when compiling with dynamo"
Changed the title: "Adagrad will use device only when capturable - True always when compiling with dynamo" → "Adagrad will use device when capturable - True always when compiling with dynamo"
CC @janeyx99 I applied the "capturable" flag to this PR to remove the CPU tensors from the tracing path, as opposed to hardcoding a device tensor. Questions:
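(For context, a rough sketch of the approach described above; `_init_step_state` is a hypothetical helper for illustration, not code from this PR.)

```python
import torch

def _init_step_state(p: torch.Tensor, capturable: bool) -> torch.Tensor:
    # Hypothetical illustration: when capturable (always the case under
    # torch.compile / dynamo), keep "step" on the parameter's device so no
    # CPU scalar tensors end up in the traced graph; otherwise keep the
    # cheaper CPU scalar used by the eager fast path.
    if capturable:
        return torch.zeros((), dtype=torch.float, device=p.device)
    return torch.tensor(0.0)

p = torch.randn(4, device="cuda" if torch.cuda.is_available() else "cpu")
step = _init_step_state(p, capturable=True)
assert step.device == p.device
```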
```python
)
super().__init__(params, defaults)

for group in self.param_groups:
```
Note to reviewer:
Thanks to this change (required for lazily sending "step" to device as needed), a previous error in test/nn/test_lazy_modules.py for lazy initialization is now avoided.
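A rough sketch of what "lazily" means here (hypothetical helper, not the exact diff), building on the hunk above:

```python
import torch

def _lazily_init_step(optimizer: torch.optim.Optimizer, p: torch.Tensor) -> dict:
    # Per-parameter state is created on first step() rather than in __init__,
    # so constructing the optimizer over an UninitializedParameter never needs
    # the parameter's not-yet-materialized storage.
    state = optimizer.state[p]
    if "step" not in state:
        state["step"] = torch.zeros((), dtype=torch.float, device=p.device)
    return state
```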
```python
module.register_parameter('test_param', UninitializedParameter())
if optim_cls is torch.optim.SGD:
    optim = optim_cls(module.parameters(), lr=0.0)
elif optim_cls is torch.optim.Adagrad:
```
Note to reviewer:
As mentioned below, since we now lazily initialize `step`, the previous error in test/nn/test_lazy_modules.py for lazy initialization is avoided.
… against list comprehensions (e.g. complex conversion) (#110613)

Fully fixes: #110506
Depends: #110607

Potential merge conflicts:
- #110339
- #110345
- #110454

Related:
- #110606 (we can apply the improvements here orthogonally to the complex support)

### Results

Benchmark: 100 params.

Breakdowns (float32, dynamo):
```
Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s
```

Notes:
1. Adagrad is still slow due to `_get_value` list comprehension. Can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing the capturable path.
2. Adamax is not actually compiled (it is currently disabled).
3. Inductor compile time is quite variable. We calculate dynamo by subtracting `call_user_compiler` from `compile_inner` timing.

<details>

This PR:
```
Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s
```

Main:
```
Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s
```

Seems that it doesn't help the complex case by much (but that's not the majority case). torch.float32 is generally positive; when it does not show drastic improvement or regresses, it is due to inductor variance (by manually inspecting the logs).

</details>

### Benchmark Script

```python
import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

NUM_PARAMS = 100
kwargs = {"lr": 0.01, "foreach": True}
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )
        model(input).sum().abs().backward()
        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        with torch.set_grad_enabled(False):
            start_time = time.time()
            compiled_step()
        summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)
```

CC: @janeyx99 @mlazos

Pull Request resolved: #110613
Approved by: https://github.com/janeyx99
janeyx99 left a comment:
Thanks for adding the capturable path. Adding capturable is heftier, as we should support it for the single-tensor implementation as well as add CUDA graphs testing. Feel free to follow the example in #106615 for what is expected. It might make sense to open general capturable support in another PR and have this one be the compiler stuff built on top of it.
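For what "CUDA graphs testing" typically looks like, a minimal sketch (this assumes a CUDA device and the `capturable` flag this PR adds; it is not the actual test from #106615):

```python
import torch

params = [torch.randn(4, device="cuda", requires_grad=True) for _ in range(2)]
for p in params:
    p.grad = torch.randn_like(p)

# Assumption: Adagrad accepts capturable=True once this PR's flag exists.
opt = torch.optim.Adagrad(params, lr=0.1, capturable=True)

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    opt.step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    opt.step()  # must not touch CPU tensors while capturing

g.replay()  # re-runs the captured optimizer step on the GPU
```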
```python
state = self.state[p]
if "step" not in state:
    state["step"] = (
        torch.zeros((), dtype=torch.float, device=p.device)
```
Let's use `p.new_zeros` to maintain the subclass information of `p`.
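A small illustration of the difference (toy subclass, not from this PR): factory calls that go through `p`, like `p.new_zeros`, return the same tensor subclass as `p`, whereas a bare `torch.zeros(..., device=p.device)` returns a plain `torch.Tensor`:

```python
import torch

class MyParam(torch.Tensor):
    pass  # toy tensor subclass, for illustration only

p = torch.randn(3).as_subclass(MyParam)

plain = torch.zeros((), dtype=torch.float, device=p.device)
kept = p.new_zeros((), dtype=torch.float)

print(type(plain))  # torch.Tensor
print(type(kept))   # MyParam -- subclass (and device) information preserved
```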
…=fusion` (#110415)

Fixes #110393

Example logs (for adagrad on main). In this case, it clearly identifies device mismatch as a potential red flag, which is indeed the obstacle to adagrad's successful fusion. (see: #110339)

```
[2023-10-03 21:50:24,084] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] ===== attempting fusion (1/10): 18 nodes =====
[2023-10-03 21:50:24,084] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,084] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] 13 possible fusions:
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7), SchedulerNode(name='buf8'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7), SchedulerNode(name='buf10'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), SchedulerNode(name='buf12'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), SchedulerNode(name='buf14'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7), SchedulerNode(name='buf9'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7), SchedulerNode(name='buf11'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), SchedulerNode(name='buf13'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), SchedulerNode(name='buf15'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (SchedulerNode(name='buf25'), SchedulerNode(name='buf33'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (SchedulerNode(name='buf43'), SchedulerNode(name='buf51'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (SchedulerNode(name='buf34'), SchedulerNode(name='buf42'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (SchedulerNode(name='buf16'), SchedulerNode(name='buf24'))
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] completed fusion round (1/10): fused 18 nodes into 5 nodes
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG]
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] ===== attempting fusion (2/10): 5 nodes =====
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] 0 possible fusions:
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] completed fusion round (2/10): fused 5 nodes into 5 nodes
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG]
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] ===== fusion complete (2 iterations) =====
```

CC @jansel @ngimel @mlazos @shunting314 @peterbell10 as code owners

Pull Request resolved: #110415
Approved by: https://github.com/mlazos
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as stale.
Partial fix: #107006
CC: @mlazos as issue creator
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler