🐛 Bug
I really have to say: the new scheduler design in PyTorch 1.1 is quite immature, and many bugs surface when using it. Related issue: #21623 (although I think the usage in that issue is really weird).
In my case, with the MultiStepLR scheduler, the get_lr function fails to report the correct learning rate once the step function has been called: you get a learning rate that is one step further decayed than it should be. This problem exists for almost all the schedulers in torch 1.1.
To Reproduce
import torch
from torchvision.models import resnet18
net = resnet18()
optimizer = torch.optim.SGD(net.parameters(), 0.1)
# Either scheduler reproduces the problem; the second assignment is the one actually used.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)
for i in range(10):
    print(i, scheduler.get_lr())
    scheduler.step()
Expected behavior
Running the snippet above actually prints:
0 [0.1]
1 [0.1]
2 [0.1]
3 [0.0010000000000000002]
4 [0.010000000000000002]
5 [0.010000000000000002]
6 [0.00010000000000000003]
7 [0.0010000000000000002]
8 [0.0010000000000000002]
9 [1.0000000000000004e-05]
Apparently, the learning rates produced by scheduler.get_lr() when i is in {3, 6, 9} are wrong: they are 10 times smaller than they should be, i.e. "one step further decayed"...
Environment
- PyTorch Version: 1.1.0
- The rest of the environment does not matter; this is a code design bug, which I describe briefly in the next section.
Analysis of the bug
From the source code, we can see the reason easily.
In torch.optim.lr_scheduler._LRScheduler, we can find the following function:
def step(self, epoch=None):
    if epoch is None:
        epoch = self.last_epoch + 1
    self.last_epoch = epoch
    for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
        param_group['lr'] = lr
And in torch.optim.lr_scheduler.MultiStepLR, we can find the following function:
def get_lr(self):
    if self.last_epoch not in self.milestones:
        return [group['lr'] for group in self.optimizer.param_groups]
    return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
            for group in self.optimizer.param_groups]
So let's say the lr needs to be decayed at epoch 3. After step() is called at the beginning of epoch 3, every lr in optimizer.param_groups has already been decayed. We then call get_lr() on the scheduler to get the learning rate for log printing, algorithm design, or whatever. Oops: get_lr() returns a decayed lr computed from the already-decayed lr in optimizer.param_groups.
That is why, at those epochs, get_lr() returns a wrong lr that is "one step further decayed".
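As a quick sanity check, the lr that step() actually installed can be read straight from optimizer.param_groups and compared against get_lr(). A minimal sketch, assuming the 1.1.0 behavior quoted above (a dummy one-parameter model is used instead of resnet18 just to keep it short):

import torch

# Minimal setup: one parameter, base lr 0.1, decay at epochs 3, 6, 9.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)

for _ in range(3):
    scheduler.step()                      # advance to last_epoch == 3 (a milestone)

print(optimizer.param_groups[0]['lr'])    # ~0.01  -- the lr that step() actually installed
print(scheduler.get_lr())                 # ~[0.001] -- the same decay applied a second time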
Of course, there is one simple solution: just make sure scheduler.step() is the last call you make in each epoch. This avoids the conflict, but it seems an ugly constraint, and the scheduler's design actually expects scheduler.step() to be called at the very beginning of each epoch.
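For what it's worth, a sketch of that workaround as I understand it: log the lr that is actually in effect (read from optimizer.param_groups rather than get_lr()) and keep scheduler.step() as the last call of the epoch. The train_one_epoch call is just a hypothetical placeholder for the real training code:

import torch
from torchvision.models import resnet18

net = resnet18()
optimizer = torch.optim.SGD(net.parameters(), 0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)

for epoch in range(10):
    # Log the lr the optimizer will actually use this epoch.
    print(epoch, [group['lr'] for group in optimizer.param_groups])
    # train_one_epoch(net, optimizer)   # placeholder for the actual training loop
    scheduler.step()                    # last call of the epoch, as described above

Assuming the 1.1.0 behavior described in this report, this prints 0.1 for epochs 0-2, 0.01 for epochs 3-5, and so on, which is what one would expect.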
So this solution is merely an emergency plan, and this is a design bug that needs to be fixed. My recommendation is to go back to computing the learning rate from self.base_lrs, as in torch 0.4.1.
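For illustration, a rough sketch of what I mean by a base_lrs-based computation, along the lines of the old MultiStepLR.get_lr() (written from memory, so treat it as an approximation of the 0.4.1 code rather than a verbatim copy; sorted(self.milestones) is used because the 1.1 milestones container is not necessarily a plain sorted list):

from bisect import bisect_right

def get_lr(self):
    # Derive the lr for the current epoch directly from the initial lrs,
    # so that calling get_lr() repeatedly can never compound the decay.
    return [base_lr * self.gamma ** bisect_right(sorted(self.milestones), self.last_epoch)
            for base_lr in self.base_lrs]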
P.S. There is good news too: this bug does not affect the vanilla training process; it merely produces wrong lr values in the log file.