Closed
Labels
module: optimizer (Related to torch.optim)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
I think there may be a bug with some of the LR schedulers, e.g. StepLR. Essentially, if the epoch passed to step() is a multiple of the step size, the LR keeps multiplying by gamma on every call, even if the epoch never increases:
import torch

# opt is not shown in the report; assuming SGD over a single dummy parameter with lr=1.0
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0)

sched = torch.optim.lr_scheduler.StepLR(opt, 2, 2.)  # step_size = 2, gamma = 2
for i in range(5):
    opt.step()
    opt.step()
    sched.step(epoch=2)  # fix the epoch
    print(opt.state_dict()['param_groups'][0]['lr'])
2.0
4.0
8.0
16.0
32.0

versus if epoch % step_size != 0:
# same assumed setup, re-created so the LR starts at 1.0 again
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0)

sched = torch.optim.lr_scheduler.StepLR(opt, 2, 2.)  # step_size = 2, gamma = 2
for i in range(5):
    opt.step()
    opt.step()
    sched.step(epoch=3)  # fix the epoch
    print(opt.state_dict()['param_groups'][0]['lr'])
1.0
1.0
1.0
1.0
1.0

IIUC, calling sched.step(epoch=i) multiple times should be idempotent. Note that this seems to have been introduced in 1.1; 1.0 works as intended.
PyTorch 1.1.0
Python 3.6.5
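For context on the idempotency claim: StepLR's documented behavior (decay the LR by gamma every step_size epochs) corresponds to the closed form lr = initial_lr * gamma ** (epoch // step_size), which depends only on the epoch, not on how many times step() has been called. A minimal sketch of that expectation, using a hypothetical helper expected_lr and assuming the same initial LR of 1.0 as above:

def expected_lr(initial_lr, gamma, step_size, epoch):
    # Closed-form StepLR schedule; a function of the epoch alone, so
    # re-evaluating it with the same epoch cannot compound gamma.
    return initial_lr * gamma ** (epoch // step_size)

for i in range(5):
    # Should print 2.0 every time for a fixed epoch of 2, not 2, 4, 8, 16, 32.
    print(expected_lr(1.0, 2.0, 2, epoch=2))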