Description
🐛 Possible bug, definite opportunity to improve documentation
Let's say I want to load the state of an optimizer and CosineAnnealingLR scheduler from a checkpoint. Given an instantiated optimizer, I might naively do the following:
```python
optimizer.load_state_dict(optimizer_state_dict)
scheduler = CosineAnnealingLR(optimizer, T_max=1000)
scheduler.load_state_dict(scheduler_state_dict)
```
But `scheduler.get_lr()` will not return the expected learning rates in this case. This is because the learning rates in `optimizer.param_groups` are reset to their initial values when I construct the cosine scheduler, and are not updated when I load the scheduler state dict.
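As a rough illustration of that reset, here is a conceptual pure-Python sketch (this is not PyTorch's actual implementation, and all numbers are hypothetical):

```python
# Constructing a scheduler with the default last_epoch=-1 treats the
# optimizer's current LR as a fresh starting point: it performs an implicit
# epoch-0 step, which for cosine annealing writes the initial LR back into
# each param group.
param_groups = [{"lr": 0.0005, "initial_lr": 0.001}]  # restored from a checkpoint

for group in param_groups:
    # Effect of the constructor's implicit step() at epoch 0:
    group["lr"] = group["initial_lr"]

print(param_groups[0]["lr"])  # 0.001 -- the checkpointed 0.0005 is clobbered
```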
The reason the learning rates are reset is because I fail to pass a value for last_epoch to CosineAnnealingLR. However, it is not obvious that I must do this, since last_epoch will be overwritten with the correct value in the very next line when I load the state dict.
The documentation does say

> When last_epoch=-1, sets initial lr as lr.

but I think it needs to be far more explicit that failing to pass a value for `last_epoch` is not compensated for by immediately loading a state dict that contains the desired value. It is not intuitive that one needs to use the state dict twice:
```python
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=1000,
    last_epoch=scheduler_state_dict["last_epoch"],
)
scheduler.load_state_dict(scheduler_state_dict)
```
Alternatively, the documentation could instruct the user to load the optimizer state dict after instantiating the scheduler:
```python
scheduler = CosineAnnealingLR(optimizer, T_max=1000)
scheduler.optimizer.load_state_dict(optimizer_state_dict)
scheduler.load_state_dict(scheduler_state_dict)
```
Now that I've looked at the code, I can see this is probably not going to be classed as a bug. However, it really could be documented more clearly. It was unexpected for me because I was previously using CosineAnnealingWarmRestarts, which does not suffer from this problem, since its LR is not computed recursively from `scheduler.optimizer.param_groups[*]["lr"]`.
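For contrast, a closed-form schedule of the kind CosineAnnealingWarmRestarts uses can be sketched in pure Python (based on the closed-form cosine annealing formula in the docs; `eta_min` is assumed to be 0 here):

```python
import math

def closed_form_cosine(base_lr, last_epoch, T_max, eta_min=0.0):
    # Closed-form cosine annealing: the LR depends only on base_lr and the
    # epoch counter, never on the optimizer's current param-group LR, so a
    # reset of optimizer.param_groups[*]["lr"] cannot corrupt it.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * last_epoch / T_max)) / 2

# Same answer regardless of what the optimizer's param groups currently hold:
print(closed_form_cosine(0.001, 50, 100))  # ~0.0005
```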
Finally, even if the user is aware that CosineAnnealingLR involves a recursive calculation, this behaviour is still not intuitive; the state dict contains `_last_lr`, which could reasonably be expected to be the LR from which the next value is calculated (at least in my mind).
To Reproduce
Steps to reproduce the behavior:
```python
import torch


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.p = torch.nn.Parameter(torch.rand(1))


def test():
    optim = torch.optim.Adam(TestModule().parameters())
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=100)
    for _ in range(50):
        optim.step()
        sched.step()

    new_optim = torch.optim.Adam(TestModule().parameters())
    new_optim.load_state_dict(optim.state_dict())
    new_sched = torch.optim.lr_scheduler.CosineAnnealingLR(new_optim, T_max=100)
    new_sched.load_state_dict(sched.state_dict())

    res = f"""
sched.
    last_epoch \t\t\t\t {sched.last_epoch}
    base_lrs \t\t\t\t {sched.base_lrs}
    get_lr() \t\t\t\t {sched.get_lr()}
    get_last_lr() \t\t\t {sched.get_last_lr()}
    optimizer.param_groups[0]['lr'] \t {sched.optimizer.param_groups[0]['lr']}
new_sched.
    last_epoch \t\t\t\t {new_sched.last_epoch}
    base_lrs \t\t\t\t {new_sched.base_lrs}
    get_lr() \t\t\t\t {new_sched.get_lr()}
    get_last_lr() \t\t\t {new_sched.get_last_lr()}
    optimizer.param_groups[0]['lr'] \t {new_sched.optimizer.param_groups[0]['lr']}
"""

    new_optim.step()
    new_sched.step()

    res += f"""
<new_sched.step()>
new_sched.
    last_epoch \t\t\t\t {new_sched.last_epoch}
    get_lr() \t\t\t\t {new_sched.get_lr()}
    get_last_lr() \t\t\t {new_sched.get_last_lr()}
    optimizer.param_groups[0]['lr'] \t {new_sched.optimizer.param_groups[0]['lr']}
"""

    print(res)


test()
```
Results in:
```
sched.
    last_epoch                          50
    base_lrs                            [0.001]
    get_lr()                            [0.00048477291476666255]
    get_last_lr()                       [0.0005000000000000002]
    optimizer.param_groups[0]['lr']     0.0005000000000000002
new_sched.
    last_epoch                          50
    base_lrs                            [0.001]
    get_lr()                            [0.0009695458295333247]
    get_last_lr()                       [0.0005000000000000002]
    optimizer.param_groups[0]['lr']     0.001

<new_sched.step()>
new_sched.
    last_epoch                          51
    get_lr()                            [0.0009381651176296076]
    get_last_lr()                       [0.0009685892409218717]
    optimizer.param_groups[0]['lr']     0.0009685892409218717
```
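The discrepancy in `get_lr()` between `sched` and `new_sched` can be reproduced without torch at all, using the recursive (chained-form) update given in the CosineAnnealingLR documentation (a pure-Python sketch; `eta_min` is assumed to be 0):

```python
import math

def cosine_annealing_step(prev_lr, last_epoch, T_max, eta_min=0.0):
    # Recursive form: the LR at epoch `last_epoch` is computed FROM the
    # optimizer's current LR (prev_lr), not from base_lr. A wrong prev_lr
    # therefore produces a wrong next LR.
    return ((1 + math.cos(math.pi * last_epoch / T_max))
            / (1 + math.cos(math.pi * (last_epoch - 1) / T_max))
            * (prev_lr - eta_min) + eta_min)

# Starting from the true checkpointed LR (0.0005 at epoch 50):
print(cosine_annealing_step(0.0005, 50, 100))  # ~0.000485, as in sched.get_lr()
# Starting from the reset LR (0.001, written back at scheduler construction):
print(cosine_annealing_step(0.001, 50, 100))   # ~0.000970, as in new_sched.get_lr()
```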
Expected behavior
Given that loading a state dict for CosineAnnealingLR overwrites the last_epoch attribute, I would expect either:
(a) the following steps to work:
- Instantiate optimizer
- Load optimizer state dict
- Instantiate LR scheduler, passing in the loaded optimizer, but leaving `last_epoch` as default
- Load scheduler state dict (overriding `last_epoch`)

or (b) it to be clearly documented how to load a state dict for an optimizer and CosineAnnealingLR in a way that ensures the result is actually consistent with the loaded states. From what I gather, this would mean instructing the user to either:
- Pass `last_epoch` into the constructor for `CosineAnnealingLR`, or
- Load the optimizer state dict after using it to instantiate the scheduler
Environment
PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.11.0-34-generic-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.9.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] mkl 2021.3.0 h06a4308_520
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.0 py39h42c9631_2
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.20.3 py39hf144106_0
[conda] numpy-base 1.20.3 py39h74d4b33_0
[conda] pytorch 1.9.0 py3.9_cuda10.2_cudnn7.6.5_0 pytorch
cc @brianjo @mruberry @vincentqb @jbschlosser @albanD @iramazanli