
StepLR, MultiStepLR, ExponentialLR and CosineAnnealingLR scheduler wrong lr value #20527

@jasam-sheja

Description

When the StepLR, MultiStepLR, ExponentialLR, or CosineAnnealingLR scheduler is called repeatedly with the same epoch parameter, the optimizer's learning rate is reduced further on every call, even though the epoch has not changed.

Sample code:

import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters()) 
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

Now, if we call lr_scheduler.step(epoch=2) multiple times:

for _ in range(10):
    lr_scheduler.step(2)
    print(optimizer.param_groups[0]['lr'])

Output:

>>> 0.0001
>>> 1e-05
>>> 1.0000000000000002e-06
>>> 1.0000000000000002e-07
>>> 1.0000000000000004e-08
>>> 1.0000000000000005e-09
>>> 1.0000000000000006e-10
>>> 1.0000000000000006e-11
>>> 1.0000000000000006e-12
>>> 1.0000000000000007e-13
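
For reference, with Adam's default lr of 0.001 and StepLR's default gamma of 0.1, the expected learning rate at epoch 2 is 0.001 * 0.1 ** (2 // 2) = 1e-04, and repeated calls with the same epoch should leave it there. Only the first printed value is correct; every subsequent call decays it again.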

Even if such a use case is unusual, this behavior is extremely unexpected.
This happens with PyTorch 1.1.0 but not with 1.0.1, because PR #14010 redefined StepLR, MultiStepLR, ExponentialLR, and CosineAnnealingLR to compute the new learning rate directly from the optimizer's current learning rate rather than from the base_lr values stored by _LRScheduler. This was done to support multiple simultaneous schedulers (#13022).
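
To illustrate the mechanism, here is a minimal sketch (not the actual PyTorch source) of the two strategies; gamma, step_size, and base_lr mirror the StepLR arguments from the example above:

gamma, step_size, base_lr = 0.1, 2, 0.001

# 1.0.1-style: closed form from the stored base_lr. Calling this repeatedly
# with the same epoch always returns the same value.
def lr_from_base(epoch):
    return base_lr * gamma ** (epoch // step_size)

# 1.1.0-style (after #14010): chain off the optimizer's current lr. Every
# call that lands on a step boundary multiplies by gamma again, so repeated
# calls with the same epoch keep shrinking the learning rate.
def lr_chained(current_lr, epoch):
    if epoch != 0 and epoch % step_size == 0:
        return current_lr * gamma
    return current_lr

lr = base_lr
for _ in range(3):
    lr = lr_chained(lr, epoch=2)
    print(lr_from_base(2), lr)  # left column stays at 1e-04; right keeps decaying

Because the chained version has no memory of base_lr, it cannot tell a repeated call at the same epoch apart from a genuine step, which is exactly the behavior reported above.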

Labels

high priority, module: optimizer, triaged
