
base_lrs in torch.optim.lr_scheduler.CyclicLR gets overridden by parent class if parameter groups have 'initial_lr' set #21965

@janfreyberg

Description

🐛 Bug

One of the arguments of torch.optim.lr_scheduler.CyclicLR is base_lr. This is described in the documentation as:

Initial learning rate which is the lower boundary in the cycle for each parameter group.

When this class actually computes the learning rate at each step, it uses the attribute self.base_lrs, which is set in the parent class, torch.optim.lr_scheduler._LRScheduler.

However, because that implementation reads the value 'initial_lr' from each parameter group, it produces the wrong behaviour when the 'initial_lr' key is already set. In effect, the CyclicLR scheduler then cycles between whatever initial_lr value each parameter group already had and the max_lr, ignoring the base_lr argument.
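
For context, here is a minimal sketch of how the 1.1-era parent constructor appears to derive base_lrs from the parameter groups (my paraphrase as a standalone helper, not the actual PyTorch source; the function name is made up):

def derive_base_lrs(optimizer):
    # Paraphrased from the behaviour described above: setdefault is a no-op when a
    # previously constructed scheduler has already written "initial_lr" into the
    # group, so any stale value is kept and the base_lr passed to CyclicLR is ignored.
    for group in optimizer.param_groups:
        group.setdefault("initial_lr", group["lr"])
    return [group["initial_lr"] for group in optimizer.param_groups]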

This is a problem, since 'initial_lr' gets written into the parameter groups by any previously constructed scheduler, making it impossible to chain multiple schedulers.

To Reproduce

Steps to reproduce the behavior:

import torch
import matplotlib.pyplot as plt

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0.1)
lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.1, max_lr=0.3, step_size_up=1, step_size_down=3)

lrs = []

for i in range(40):
    # run the cosine-annealing scheduler for its first T_max steps,
    # then switch to the cyclic scheduler on the same optimizer
    if i <= lr_scheduler_1.T_max:
        lr_scheduler_1.step()
    else:
        lr_scheduler_2.step()
    lrs.append(optimizer.param_groups[0]["lr"])

plt.plot(lrs)
plt.show()

This produces the following plot:
[plot of the recorded learning rates, showing the incorrect cycling after the switch to CyclicLR]
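
A quicker way to see the problem, without plotting, is to inspect the second scheduler directly (the printed values are what I expect given the behaviour described above):

print(optimizer.param_groups[0]["initial_lr"])  # expected: 0.5, stored by the first scheduler's constructor
print(lr_scheduler_2.base_lrs)                  # expected: [0.5] -- it should be [0.1], the base_lr that was passed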

Expected behavior

The base_lr argument to CyclicLR should take precedence over the parent class's derivation of the learning rate. CyclicLR should always use the base_lr value it was given, and not rely on whatever 'initial_lr' entries are already present in the parameter groups. I would expect the following graph instead:

[plot of the expected learning rates, cycling between base_lr=0.1 and max_lr=0.3 after the switch to CyclicLR]
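
Until this is fixed, a possible workaround (an untested sketch that relies on the setdefault behaviour described above) is to drop the stale keys right before constructing the second scheduler, so that its base_lr is used when 'initial_lr' is re-populated:

# Remove the "initial_lr" left behind by the first scheduler; when CyclicLR is
# constructed it writes base_lr into each group's "lr", and the parent class then
# stores that value as the new "initial_lr" (assumption: 1.1 semantics).
for group in optimizer.param_groups:
    group.pop("initial_lr", None)
lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=0.1, max_lr=0.3, step_size_up=1, step_size_down=3
)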

Environment

Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.14.5
GCC version: Could not collect
CMake version: Could not collect

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] gpytorch==0.2.1
[pip3] msgpack-numpy==0.4.3.2
[pip3] numpy==1.16.2
[pip3] numpydoc==0.8.0
[pip3] pytorch-revgrad==0.0.1
[pip3] pytorch-swag==0.0.1
[pip3] torch==1.1.0
[pip3] torch-nightly==1.0.0.dev20190319
[pip3] torch-vision==0.1.6.dev0
[pip3] torchfile==0.1.0
[pip3] torchgeometry==0.1.2
[pip3] torchvision==0.2.3a0+a2e6b70
[pip3] torchvision-nightly==0.2.1
[conda] blas                      1.0                         mkl
[conda] gpytorch                  0.2.1                     <pip>
[conda] mkl                       2019.1                      144
[conda] mkl_fft                   1.0.10           py37h5e564d8_0
[conda] mkl_random                1.0.2            py37h27c97d8_0
[conda] pytorch                   1.0.1                   py3.7_2    pytorch
[conda] pytorch-nightly           1.0.0.dev20190319         py3.7_0    pytorch
[conda] pytorch-revgrad           0.0.1                     <pip>
[conda] pytorch-swag              0.0.1                     <pip>
[conda] torch                     1.1.0                     <pip>
[conda] torch                     1.0.1.post2               <pip>
[conda] torch-nightly             1.0.0.dev20190319           <pip>
[conda] torch-vision              0.1.6.dev0                <pip>
[conda] torchfile                 0.1.0                     <pip>
[conda] torchgeometry             0.1.2                     <pip>
[conda] torchvision               0.2.3a0+a2e6b70           <pip>
[conda] torchvision               0.2.2                      py_3    pytorch
[conda] torchvision-nightly       0.2.1                     <pip>

Metadata

Labels

module: optimizer (Related to torch.optim), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
