
recompilation of AdamW w/ OneCycleLR #133898

@patrick-botco

🐛 Describe the bug

i'd like to compile my optimizer but am hitting recompilation issues. i wrap my LR in a tensor, but it seems like beta1/beta2 may need similar treatment (per the type annotations, betas is (Tuple[float, float], optional)); however, wrapping the default beta values in tensors the same way as the LR breaks. (see the sketch below for what i mean.)
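for context, this is roughly the tensor-betas variant i tried (a sketch of the call that breaks for me; AdamW documents betas as plain floats, so tensor betas may simply be unsupported here):

beta1 = torch.tensor(0.9, device="cuda")   # analogous to wrapping lr in a tensor
beta2 = torch.tensor(0.999, device="cuda")
optimizer = torch.optim.AdamW([param], lr=lr, betas=(beta1, beta2))  # breaks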

a repro:

import torch

torch._logging.set_logs(recompiles_verbose=True)

# a single parameter with a dummy gradient so optimizer.step() has work to do
param = torch.rand(2, 3, dtype=torch.float, device="cuda", requires_grad=True)
param.grad = torch.rand_like(param)

# lr wrapped in a tensor so the scheduler's LR updates don't invalidate a
# guard on a python float
lr = torch.tensor(0.001, device="cuda")
total_steps = 10000
optimizer = torch.optim.AdamW([param], lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=lr, total_steps=total_steps
)

@torch.compile()
def step():
    optimizer.step()
    scheduler.step()


for _ in range(total_steps):
    step()

recompilation logs:

V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     guard 0 failures:
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     - L['self'].param_groups[0]['betas'][0] == 0.9499993141552187 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose] 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     guard 1 failures:
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     - L['self'].param_groups[0]['betas'][0] == 0.9499995610589786 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose] 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     guard 2 failures:
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     - L['self'].param_groups[0]['betas'][0] == 0.9499997530955174 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose] 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     guard 3 failures:
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     - L['self'].param_groups[0]['betas'][0] == 0.9499998902646242 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose] 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     guard 4 failures:
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     - L['self'].param_groups[0]['betas'][0] == 0.9499999725661485 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose] 
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     guard 5 failures:
V0819 21:07:52.605000 139637715699520 torch/_dynamo/guards.py:1423] [__recompiles_verbose]     - L['self'].param_groups[0]['betas'][0] == 0.95
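
the cycling values above look like OneCycleLR's momentum annealing: with cycle_momentum=True (the default), the scheduler writes a fresh python float into param_groups[0]['betas'][0] on every step (annealing between base_momentum=0.85 and max_momentum=0.95, which matches the ~0.95 values in the guards), and dynamo guards on those floats, so each new value forces a recompile. one possible workaround (a sketch, and it changes training behavior since momentum is no longer annealed) is to disable the momentum cycling:

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    total_steps=total_steps,
    cycle_momentum=False,  # leave betas untouched so the guard on betas[0] holds
)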

Versions

PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.4
Libc version: glibc-2.31

Python version: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1027-oracle-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 550.90.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   48 bits physical, 48 bits virtual
CPU(s):                          256
On-line CPU(s) list:             0-254
Off-line CPU(s) list:            255
Thread(s) per core:              1
Core(s) per socket:              64
Socket(s):                       2
NUMA node(s):                    8
Vendor ID:                       AuthenticAMD
CPU family:                      25
Model:                           1
Model name:                      AMD EPYC 7J13 64-Core Processor
Stepping:                        1
Frequency boost:                 enabled
CPU MHz:                         2550.000
CPU max MHz:                     3673.0950
CPU min MHz:                     1500.0000
BogoMIPS:                        4900.16
Virtualization:                  AMD-V
L1d cache:                       2 MiB
L1i cache:                       2 MiB
L2 cache:                        32 MiB
L3 cache:                        256 MiB
NUMA node0 CPU(s):               0-15,128-143
NUMA node1 CPU(s):               16-31,144-159
NUMA node2 CPU(s):               32-47,160-175
NUMA node3 CPU(s):               48-63,176-191
NUMA node4 CPU(s):               64-79,192-207
NUMA node5 CPU(s):               80-95,208-223
NUMA node6 CPU(s):               96-111,224-239
NUMA node7 CPU(s):               112-127,240-254
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] numpy-stl==3.1.1
[pip3] torch==2.3.1
[pip3] torchmetrics==1.3.2
[pip3] torchvision==0.18.1
[pip3] triton==2.3.1
[conda] numpy                     1.24.4                   pypi_0    pypi
[conda] numpy-stl                 3.1.1                    pypi_0    pypi
[conda] torch                     2.3.1                    pypi_0    pypi
[conda] torchmetrics              1.3.2                    pypi_0    pypi
[conda] torchvision               0.18.1                   pypi_0    pypi
[conda] triton                    2.3.1                    pypi_0    pypi

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar @ezyang @chauhang @penguinwu @mlazos


Labels

module: optimizer (Related to torch.optim) · module: pt2 optimizer (Relating to torch.compile'd optim) · oncall: pt2 · triaged
