@vincentqb vincentqb commented Sep 18, 2019

Enable chainable schedulers as requested in #13022 by implementing the following changes:

  • Changing the behavior of schedulers to the chainable formula when available (e.g. CosineAnnealingWarmRestarts doesn't have a chainable form).
  • Using the closed form, with a deprecation warning, whenever epoch is different from None, until the next release.
  • Making get_last_lr the supported way of obtaining the last learning rate computed by the scheduler.
  • Returning a warning referring to get_last_lr when get_lr is invoked outside of step.

Note that get_lr was used for two purposes prior to chainable schedulers: (Purpose 1) compute the new learning rate value (formal goal), (Purpose 2) read the current learning rate value (informal goal). For chainable schedulers, the two purposes will disagree, and so two distinct functions are needed.

Here is an example showing how to chain learning rate schedulers and print the learning rate at each iteration:

import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR, StepLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)

scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(10):

    print(epoch, scheduler2.get_last_lr()[0])

    optimizer.step()
    scheduler1.step()
    scheduler2.step()
0 0.1
1 0.09000000000000001
2 0.08100000000000002
3 0.07290000000000002
4 0.06561000000000002
5 0.005904900000000002
6 0.005314410000000002
7 0.004782969000000002
8 0.004304672100000002
9 0.003874204890000002
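
The numbers above follow from each scheduler multiplying the optimizer's current learning rate in place. A minimal pure-Python sketch of that arithmetic, with no PyTorch involved: the function name `chained_lrs` is hypothetical, and the rule that the step decay fires whenever the step count reaches a multiple of `step_size` is an assumption modeled on the output above.

```python
def chained_lrs(base_lr=0.1, exp_gamma=0.9, step_size=5, step_gamma=0.1, epochs=10):
    """Return the learning rate seen at the start of each epoch."""
    lr = base_lr
    seen = []
    for epoch in range(epochs):
        seen.append(lr)
        # Stand-in for scheduler1.step(): exponential decay of the current value
        lr *= exp_gamma
        # Stand-in for scheduler2.step(): extra decay every `step_size` steps
        if (epoch + 1) % step_size == 0:
            lr *= step_gamma
    return seen


for epoch, lr in enumerate(chained_lrs()):
    print(epoch, lr)
```

Because both updates act on the same running value, the order in which the two schedulers are stepped does not change the result in this toy version.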

BC Breaking: Printing learning rate

# Setup for codes below

import warnings
warnings.simplefilter('once', DeprecationWarning) 

import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = StepLR(optimizer, 2)

As discussed in #22107, the user wants to print the learning rate at each iteration.

Version 1.3.1:

for epoch in range(10):
    print(epoch, scheduler.get_lr()[0])
    optimizer.step()
    scheduler.step()
0 0.1
1 0.1
2 0.010000000000000002
3 0.010000000000000002
4 0.0010000000000000002
5 0.0010000000000000002
6 0.00010000000000000003
7 0.00010000000000000003
8 1.0000000000000003e-05
9 1.0000000000000003e-05

Version 1.4.0:

for epoch in range(10):
    print(epoch, scheduler.get_lr()[0])
    optimizer.step()
    scheduler.step()
torch/optim/lr_scheduler.py:342: DeprecationWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  "please use `get_last_lr()`.", DeprecationWarning)
0 0.1
1 0.1
2 0.0010000000000000002
3 0.010000000000000002
4 0.00010000000000000003
5 0.0010000000000000002
6 1.0000000000000004e-05
7 0.00010000000000000003
8 1.0000000000000004e-06
9 1.0000000000000004e-05

Version 1.4.0:

for epoch in range(10):
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
0 0.1
1 0.1
2 0.010000000000000002
3 0.010000000000000002
4 0.0010000000000000002
5 0.0010000000000000002
6 0.00010000000000000003
7 0.00010000000000000003
8 1.0000000000000004e-05
9 1.0000000000000004e-05

When more than one scheduler is chained, a user may want to know the most recent learning rate computed by a specific scheduler. This is what scheduler.get_last_lr() does. Note that optimizer.param_groups[0]['lr'] was in version 1.3.1, and remains in 1.4.0, a way of getting the current learning rate used by the optimizer.
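
To see why get_lr and get_last_lr disagree for a chainable scheduler, here is a toy sketch in plain Python. `ToyStepLR` is a hypothetical stand-in, not the PyTorch class; its update rule is an assumption chosen to mimic a step scheduler that derives each new value from the current one.

```python
class ToyStepLR:
    """Toy stand-in for a chainable step scheduler (not the PyTorch class)."""

    def __init__(self, lr, step_size, gamma=0.1):
        self.lr = lr              # plays the role of optimizer.param_groups[0]['lr']
        self.step_size = step_size
        self.gamma = gamma
        self.last_epoch = 0
        self._last_lr = lr

    def get_lr(self):
        # Chainable form: derive the new value from the *current* one.
        if self.last_epoch == 0 or self.last_epoch % self.step_size != 0:
            return self.lr
        return self.lr * self.gamma

    def step(self):
        self.last_epoch += 1
        self.lr = self.get_lr()
        self._last_lr = self.lr

    def get_last_lr(self):
        # Read-only: returns what the last step() stored.
        return self._last_lr


sched = ToyStepLR(lr=0.1, step_size=2)
sched.step()
sched.step()                       # last_epoch == 2, decay applied once
applied = sched.get_last_lr()      # value actually in use
recomputed = sched.get_lr()        # applies the decay a second time
print(applied, recomputed)
```

Calling get_lr outside of step right after a decay epoch applies the factor again, which matches the unexpected values at epochs 2, 4, 6 and 8 in the get_lr listing above.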

Deprecation: Epoch parameter

# Setup for codes below

import warnings
warnings.simplefilter('once', DeprecationWarning) 

import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = StepLR(optimizer, 2)

For schedulers that can be chained or that don't require epoch, the epoch parameter from scheduler.step(epoch) is being deprecated in favor of scheduler.step().

Version 1.3.1:

for epoch in range(10):
    optimizer.step()
    scheduler.step(epoch)
    print(epoch, optimizer.param_groups[0]['lr'])
0 0.1
1 0.1
2 0.010000000000000002
3 0.010000000000000002
4 0.0010000000000000002
5 0.0010000000000000002
6 0.00010000000000000003
7 0.00010000000000000003
8 1.0000000000000003e-05
9 1.0000000000000003e-05

Version 1.4.0:

for epoch in range(10):
    optimizer.step()
    scheduler.step(epoch)
    print(epoch, optimizer.param_groups[0]['lr'])
torch/optim/lr_scheduler.py:143: DeprecationWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, DeprecationWarning)
0 0.1
1 0.1
2 0.010000000000000002
3 0.010000000000000002
4 0.0010000000000000002
5 0.0010000000000000002
6 0.00010000000000000003
7 0.00010000000000000003
8 1.0000000000000003e-05
9 1.0000000000000003e-05

Version 1.4.0:

for epoch in range(10):
    print(epoch, optimizer.param_groups[0]['lr'])
    optimizer.step()
    scheduler.step()
0 0.1
1 0.1
2 0.010000000000000002
3 0.010000000000000002
4 0.0010000000000000002
5 0.0010000000000000002
6 0.00010000000000000003
7 0.00010000000000000003
8 1.0000000000000003e-05
9 1.0000000000000003e-05

This last code would also work as expected in version 1.3.1.
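
When step is called exactly once per epoch, the closed form and the chainable form are expected to produce the same schedule. A small pure-Python sketch of that equivalence for a step decay; the helper names `closed_form` and `chainable` are hypothetical and the formulas are assumptions written to match the outputs above.

```python
def closed_form(base_lr, gamma, step_size, epoch):
    # Learning rate as a direct function of the epoch number.
    return base_lr * gamma ** (epoch // step_size)


def chainable(base_lr, gamma, step_size, epochs):
    # Learning rate obtained by updating the previous value one step at a time.
    lr, out = base_lr, []
    for epoch in range(epochs):
        if epoch > 0 and epoch % step_size == 0:
            lr *= gamma
        out.append(lr)
    return out


recursive = chainable(0.1, 0.1, 2, 10)
direct = [closed_form(0.1, 0.1, 2, e) for e in range(10)]
for e, (r, d) in enumerate(zip(recursive, direct)):
    print(e, r, d)
```

The two columns may differ in the last floating-point digit, since repeated multiplication and exponentiation round differently.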

If the epoch parameter was used to control the scheduler's flow, this logic will have to be moved to the training loop. For instance, as discussed in #20527:

Version 1.3.1:

for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:
    optimizer.step()
    scheduler.step(epoch)
    print(epoch, optimizer.param_groups[0]['lr'])
0 0.1
0 0.1
1 0.1
1 0.1
2 0.010000000000000002
2 0.010000000000000002
3 0.010000000000000002
3 0.010000000000000002

Version 1.4.0:

for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:
    optimizer.step()
    scheduler.step(epoch)
    print(epoch, optimizer.param_groups[0]['lr'])
torch/optim/lr_scheduler.py:143: DeprecationWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, DeprecationWarning)
0 0.1
0 0.1
1 0.1
1 0.1
2 0.010000000000000002
2 0.010000000000000002
3 0.010000000000000002
3 0.010000000000000002

Version 1.4.0:

for epoch in range(8):
    print(epoch, optimizer.param_groups[0]['lr'])

    optimizer.step()

    # Step at every other epoch
    if epoch % 2:
        scheduler.step()
0 0.1
1 0.1
2 0.1
3 0.1
4 0.010000000000000002
5 0.010000000000000002
6 0.010000000000000002
7 0.010000000000000002

Changes from #24352

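This new PR is based on #24352, which was reverted. The problem was that renaming get_lr introduced a BC-breaking change, which broke users who had built their own schedulers by overriding get_lr. The differences from #24352:

  • In this PR, (Purpose 1) remains get_lr, and (Purpose 2) is now get_last_lr.
  • In #24352, (Purpose 1) was _compute values, and (Purpose 2) get_last_computed_values.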

ghstack

Stack from ghstack:

This reverts commit 1c477b7.

Differential Revision: D17460427

@pytorchbot pytorchbot added the module: optimizer Related to torch.optim label Sep 18, 2019
@vincentqb vincentqb changed the title Change schedulers to chainable form (#24352) [WIP] Change schedulers to chainable form (#24352) Sep 18, 2019
@vincentqb vincentqb changed the title [WIP] Change schedulers to chainable form (#24352) [WIP] Change schedulers to chainable form Sep 18, 2019
Enable chainable schedulers as requested in #13022 by implementing the changes mentioned below from [comment](#21800 (comment)).

* Changing the behavior of schedulers to the chainable formula when available
* Using the closed form whenever epoch is different from None until the next release with a deprecation warning
* Making `get_computed_values` the supported way of obtaining the last computed learning rate by the scheduler (see [comment](#21800 (comment)) for new syntax)
* Returning a deprecation warning when invoking the undocumented get_lr function (see [comment](#21800 (comment))) referring to `get_computed_values`, and deprecating it in the next release.
* `CosineAnnealingWarmRestart` still takes an epoch parameter as it is the only one with a mechanic relying on fractional epoch
* `MultiplicativeLR` consumes a function providing the multiplicative factor at each epoch. It mimics `LambdaLR` in its syntax.

# #20527

### Before

The user calls scheduler with a constant epoch either across loops or in the same loop.
```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters()) 
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

# Scheduler with sometimes-constant epoch number
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:
  lr_scheduler.step(epoch)
  print(optimizer.param_groups[0]['lr'])
```

### After

If the user wants to step
```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters()) 
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

last_epoch = -1
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:

  # Check if epoch number has changed manually
  if epoch-last_epoch > 0:
    lr_scheduler.step()
  last_epoch = epoch

  print(epoch, lr_scheduler.get_computed_values())
```

# #22107

### Before

```
import torch
from torchvision.models import resnet18
net = resnet18()

optimizer = torch.optim.SGD(net.parameters(), 0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
  # Scheduler computes and returns new learning rate, leading to unexpected behavior
  print(i, scheduler.get_lr())
  scheduler.step()
```

### After

```
import torch
from torchvision.models import resnet18

net = resnet18()
optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
    # Returns last computed learning rate by scheduler
    print(i, lr_scheduler.get_computed_values())
    lr_scheduler.step()
```

# ghstack

This contains the changes from #24352. Opening again since they were reverted.


This reverts commit 1c477b7.

Differential Revision: [D17460427](https://our.internmc.facebook.com/intern/diff/D17460427)
vincentqb added a commit that referenced this pull request Sep 18, 2019
This reverts commit 1c477b7.

ghstack-source-id: ada4310
Pull Request resolved: #26423
@vadimkantorov
Contributor

optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

this looks like magic. it seems that schedulers are attached to the optimizer, but this is not explicit. is this what's happening?

@vincentqb
Contributor Author

optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

this looks like magic. it seems that schedulers are attached to the optimizer, but this is not explicit. is this what's happening?

optimizer is the first parameter passed to MultiStepLR and StepLR, is this what you refer to?

@vadimkantorov
Contributor

I think I misunderstood that example, sorry for the bother... I thought that somehow the two created schedulers interact.

vincentqb added a commit that referenced this pull request Sep 23, 2019
This reverts commit 1c477b7.

ghstack-source-id: 2c8233c
Pull Request resolved: #26423
@vincentqb vincentqb changed the title [WIP] Change schedulers to chainable form Change schedulers to chainable form Sep 24, 2019
@vincentqb vincentqb requested a review from fmassa September 24, 2019 17:11
@vincentqb
Contributor Author

This PR (unlike #24352) preserves the name of get_lr but throws a warning if the users call it outside of step, and refers to the new interface to get the most recent learning rate, get_last_lr.

@ezyang
Contributor

ezyang commented Sep 27, 2019

How did you test that this no longer breaks the projects that were affected last revert? On V3 of the diff internally, it does not look like you triggered deferred tests.

self._get_lr_called_within_step = True
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
    param_group['lr'] = lr
self._get_lr_called_within_step = False
Contributor

Wouldn't a more normal way to implement this be to just have get_latest_lr throw no warning (and get_lr point at it with a warning), so you don't have to muck around with this attribute?

Contributor Author

@vincentqb vincentqb Sep 27, 2019

If I understand correctly, you are suggesting to just make get_lr throw an error when called? get_lr is always called within step, and we don't want an error thrown in this case. This attribute is meant to check that get_lr is called within step and avoid throwing an error in this case. Is that what you meant?

Contributor

If I understand correctly, you are suggesting to just make get_lr throw an error when called?

If by "throw an error" you mean "raise a warning", yes, I am suggesting that. However, I am also suggesting that you change the line:

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):

To be:

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr_nowarn()):

where get_lr_nowarn is a version of get_lr that has the same semantics, but doesn't do a warning.

Member

I think another way of doing this would be with a context manager?

with self.enable_get_lr_call():
   for param_group, lr in zip(..., self.get_lr()):
        ...

which basically does what you are doing internally, but handles the case where there is a failure in get_lr and the _get_lr_called_within_step is not set back to False.
But I think that this is a corner case, and is not really necessary

Contributor Author

@vincentqb vincentqb Oct 1, 2019

However, I am also suggesting that you change the line:

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):

To be:

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr_nowarn()):

where get_lr_nowarn is a version of get_lr that has the same semantics, but doesn't do a warning.

I see, so you are proposing (1) step uses get_lr_no_warning which computes learning rate without warning, (2) get_lr raises a warning and calls get_lr_no_warning.

However, users already override get_lr with their own, and expect step to invoke it without modification. What I described just now would then break their code. Thoughts?
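
The concern about user overrides can be illustrated with a toy sketch. Everything here is hypothetical: `BaseScheduler`, `UserScheduler`, and the halving rule are made up for illustration, and `get_lr_nowarn` is the name proposed in the review, not an existing API.

```python
import warnings


class BaseScheduler:
    """Hypothetical base class following the get_lr_nowarn proposal."""

    def __init__(self, lr):
        self.lr = lr

    def get_lr_nowarn(self):
        return self.lr * 0.5          # built-in rule (arbitrary for this sketch)

    def get_lr(self):
        warnings.warn("use get_last_lr()", DeprecationWarning)
        return self.get_lr_nowarn()

    def step(self):
        # step() bypasses get_lr, so an overridden get_lr is never consulted.
        self.lr = self.get_lr_nowarn()


class UserScheduler(BaseScheduler):
    def get_lr(self):                 # pre-existing user override
        return self.lr * 0.9


sched = UserScheduler(1.0)
sched.step()
print(sched.lr)
```

In this sketch the user's rule (a factor of 0.9) is silently ignored and the built-in factor of 0.5 is applied, which is the breakage described above.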

I think another way of doing this would be with a context manager?

Fair :) I'll use that.

Contributor

Fair :) I'll use that.

Context manager is a strict improvement because it will handle exceptions correctly.
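
A minimal sketch of the context-manager approach, assuming the `enable_get_lr_call` name suggested above; the class `FlagScheduler` and the simulated failure are hypothetical and show only the flag handling, not a real scheduler.

```python
from contextlib import contextmanager


class FlagScheduler:
    """Hypothetical scheduler fragment; only the flag handling is shown."""

    def __init__(self):
        self._get_lr_called_within_step = False

    @contextmanager
    def enable_get_lr_call(self):
        self._get_lr_called_within_step = True
        try:
            yield self
        finally:
            # Runs on normal exit and when the body raises.
            self._get_lr_called_within_step = False

    def get_lr(self):
        raise RuntimeError("simulated failure inside a user-defined get_lr")

    def step(self):
        with self.enable_get_lr_call():
            self.get_lr()


sched = FlagScheduler()
try:
    sched.step()
except RuntimeError:
    pass
print(sched._get_lr_called_within_step)
```

With the plain assignment version, the same failure would leave the flag stuck at True; here the `finally` clause resets it.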

Contributor

However, users already override get_lr with their own, and expect step to invoke it without modification. What I described just now would then break their code. Thoughts?

Err, if this is true, you're not going to get warnings in this case, because they've overridden your warnings? It is getting a bit hard to keep track of all the cases. Is there a clear description of how these changes interact with user defined schedulers?

Contributor Author

@vincentqb vincentqb Oct 2, 2019

Err, if this is true, you're not going to get warnings in this case, because they've overridden your warnings? It is getting a bit hard to keep track of all the cases. Is there a clear description of how these changes interact with user defined schedulers?

I don't need to provide a warning in this case. My goal is just to alert users to the changes that happened.

  • Users should be alerted that some existing schedulers changed from closed form to chainable form.
  • User-modified schedulers should continue to work as expected.
  • Users should be discouraged from providing epoch when not needed.
  • Users should be encouraged to use get_last_lr to get the last computed value instead of get_lr (e.g. by the warning discussed here), since the modified get_lr wouldn't return the last lr for chainable schedulers.

Since we haven't changed how step uses get_lr, and have only added a new get_last_lr, we do not need to give a warning for user-modified get_lr in schedulers.

Context manager is a strict improvement because it will handle exceptions correctly.

Agree.

@vincentqb
Contributor Author

How did you test that this no longer breaks the projects that were affected last revert? On V3 of the diff internally, it does not look like you triggered deferred tests.

The evidence that projects no longer break is that all internal tests are passing. See T53988247 (internal) for internal tests that were failing on #24352.

The deferred internal tests are now running.

@vincentqb
Contributor Author

How did you test that this no longer breaks the projects that were affected last revert? On V3 of the diff internally, it does not look like you triggered deferred tests.

@ezyang -- I ran the test listed in T53988247, and it passed :D

epochs = 35
with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")  # allow any warning to be raised
    warnings.filterwarnings("ignore", category=DeprecationWarning)
Contributor

Is the warning filtering here new? Why are these tests raising deprecation warnings? Shouldn't our test code be using the non-deprecated codepath?

Contributor Author

Thanks for catching this. I'll remove them.
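
Where a test does need to exercise a deprecated path, one option is to record the warning and assert on it instead of filtering it out. A sketch using only the standard library; `legacy_call` is a hypothetical stand-in for a deprecated code path and is not taken from the test suite.

```python
import warnings


def legacy_call():
    # Hypothetical stand-in for a deprecated code path.
    warnings.warn("please use get_last_lr()", DeprecationWarning)
    return 0.1


with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")   # record every warning, including repeats
    value = legacy_call()

deprecations = [w for w in ws if issubclass(w.category, DeprecationWarning)]
print(len(deprecations), value)
```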

Member

@fmassa fmassa left a comment

This looks good to me, thanks a lot for all the work @vincentqb !

I have made a few minor comments, none of which IMO blocks this PR from being merged.

        return (end - start) * pct + start

    def get_lr(self):
        if not self._get_lr_called_within_step:
Member

nit: this is a bit repetitive (although very explicit).

I'm not sure if it would be a good idea to add another private method _check_get_lr_called_within_step, which raises the warning for you?

Don't need to act on this, it's more a question.

Contributor Author

@vincentqb vincentqb Oct 1, 2019

Yes, I agree with you. However, "repetition" has been the strategy used in other parts of schedulers to keep them readable, see each block in step that has been repeated.

schedulers = [None] * 2
schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
                                  mode='min', threshold=0.1)
schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12])
Member

I wonder if it would make sense to create a CombineScheduler class in a follow-up PR, which takes an iterable of schedulers and applies them in order.

Basic gist

class CombinedScheduler(object):
    def __init__(self, schedulers):
        # assert all schedulers are for the same optimizer?
        self.schedulers = list(schedulers)

    def step(self):
        # Step every wrapped scheduler, in the order given
        for scheduler in self.schedulers:
            scheduler.step()

Instead of a list of schedulers, we could even pass a map of schedulers, so that at each step call we could dynamically decide, by name, which schedulers to step.

But this is out of the scope of this PR; we should probably open an issue after this PR gets merged to discuss it.

Contributor Author

Interesting idea :) but we can already change the schedulers dynamically using a plain Python conditional, without a new CombineScheduler, no?

Is the CombineScheduler you propose different from what we discussed before?

Contributor

I also think a more explicit form of chaining is more understandable than creating multiple schedulers and having the chaining happen implicitly.

for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]


class MultiplicativeLR(_LRScheduler):
Member

This could live in a separate PR, right?

Contributor Author

It could, yes. The reason I added it in is that LambdaLR felt a little out of place alone since it was not chainable. Do you recommend doing so?

Contributor Author

MultiplicativeLR moved to #27254

"""
self.__dict__.update(state_dict)

def get_last_lr(self):
Member

I'm ok with it being a method called get_last_lr, but I wonder whether a @property called lr wouldn't be more natural to users? Like

print(i, scheduler.lr)

Contributor Author

@vincentqb vincentqb Oct 1, 2019

The reason I went for get_last_lr was to also emphasize that get_lr does not return what the user would expect from it, the last computed lr. If there's an lr property, then it looks like get_lr should return the learning rate.

I originally wanted to change the name get_lr to something like compute_values to

  • make it clear the goal is to compute the new learning rates, and
  • remove the reference to "lr" so as to be more general later (so that we can have weight schedulers with identical syntax). I'll leave the latter for a separate PR.

That led to many BC-breaking changes internally, since get_lr is used to compute the new learning rate in step, so I kept the name get_lr for computing the new learning rate and introduced get_last_lr to get the most recent learning rate.

Thoughts?

@vadimkantorov
Contributor

Is the current idea still to have a scheduler API specific to learning rate scheduling (instead of more generic lr / weight decay schedulers)?

@vincentqb
Contributor Author

Is the current idea still to have a scheduler API specific to learning rate scheduling (instead of more generic lr / weight decay schedulers)?

The change to chainable schedulers is BC breaking, and so is the change of names (for get_lr, get_last_lr) if we go the route of a generic scheduler. This PR only addresses the first issue of chainability, and does not take a position on the latter.

Enable chainable schedulers as requested in #13022 by implementing the following changes:

* Changing the behavior of schedulers to the chainable formula when available (e.g. `CosineAnnealingWarmRestart` doesn't have a chainable form).
* Using the closed form whenever epoch is different from None until the next release with a deprecation warning.
* Making `get_last_lr` the supported way of obtaining the last computed learning rate by the scheduler.
* Returning a warning referring to `get_last_lr` when invoking the `get_lr` function outside of `step`
* `MultiplicativeLR` consumes a function providing the multiplicative factor at each epoch. It mimics `LambdaLR` in its syntax. 

Note that `get_lr` was used for two purposes prior to chainable schedulers: (Purpose 1) compute the new learning rate value (formal goal), (Purpose 2) read the current learning rate value (informal goal). For chainable schedulers, the two  purposes will disagree, and so two distinct functions are needed. 

# Changes from #24352

This new PR is based on #24352 which was reverted. The problem was the introduction of a BC breaking change by changing the name of `get_lr`. This failed on users who built their own schedulers by modifying `get_lr`. The difference with #24352:

* In this PR, (Purpose 1) remains `get_lr`, and (Purpose 2) is now `get_last_lr`
* In #24352, (Purpose 1) was `_compute values`, and (Purpose 2) `get_last_computed_values`.

# #20527

### Before

The user calls scheduler with a constant epoch either across loops or in the same loop.
```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters()) 
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

# Scheduler with sometimes-constant epoch number
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:
  lr_scheduler.step(epoch)
  print(optimizer.param_groups[0]['lr'])
```

### After

If the user wants to step
```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters()) 
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

last_epoch = -1
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:

  # Check if epoch number has changed manually
  if epoch-last_epoch > 0:
    lr_scheduler.step()
  last_epoch = epoch

  print(epoch, lr_scheduler.get_last_lr())
```

# #22107

### Before

```
import torch
from torchvision.models import resnet18
net = resnet18()

optimizer = torch.optim.SGD(net.parameters(), 0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
  # Scheduler computes and returns new learning rate, leading to unexpected behavior
  print(i, scheduler.get_lr())
  scheduler.step()
```

### After

```
import torch
from torchvision.models import resnet18

net = resnet18()
optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
    # Returns last computed learning rate by scheduler
    print(i, lr_scheduler.get_last_lr())
    lr_scheduler.step()
```

# ghstack


This reverts commit 1c477b7.

Differential Revision: [D17460427](https://our.internmc.facebook.com/intern/diff/D17460427)
vincentqb added a commit that referenced this pull request Oct 1, 2019
This reverts commit 1c477b7.

ghstack-source-id: 6179ce4
Pull Request resolved: #26423
@facebook-github-bot facebook-github-bot deleted the gh/vincentqb/25/head branch October 28, 2019 22:21
pdlive215 pushed a commit to pdlive215/pytorch that referenced this pull request Nov 27, 2019
Summary:
Pull Request resolved: pytorch#26423

Enable chainable schedulers as requested in pytorch#13022 by implementing the changes mentioned below from [comment](pytorch#21800 (comment)).

* Changing the behavior of schedulers to the chainable formula when available
* Using the closed form whenever epoch is different from None until the next release with a deprecation warning
* Making `get_computed_values` the supported way of obtaining the last computed learning rate by the scheduler (see [comment](pytorch#21800 (comment)) for new syntax)
* Returning a deprecation warning when invoking the undocumented get_lr function (see [comment](pytorch#21800 (comment))) referring to `get_computed_values`, and deprecating it in the next release.
* `CosineAnnealingWarmRestarts` still takes an epoch parameter, as it is the only scheduler with a mechanism that relies on fractional epochs
* `MultiplicativeLR` consumes a function providing the multiplicative factor at each epoch. It mimics `LambdaLR` in its syntax.
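
The difference between the chainable formula and the closed form can be sketched in plain Python. The helper names below (`exponential_factor`, `step_factor`, `chained_schedule`, `closed_form`) are hypothetical and exist only for this illustration; the parameters mirror the `ExponentialLR(gamma=0.9)` and `StepLR(step_size=5, gamma=0.1)` example from the PR description, starting from a learning rate of 0.1.

```python
def exponential_factor(epoch, gamma=0.9):
    # Chainable form: multiply the current rate by gamma at every epoch > 0.
    return gamma if epoch > 0 else 1.0


def step_factor(epoch, step_size=5, gamma=0.1):
    # Chainable form: multiply by gamma only when epoch is a positive
    # multiple of step_size; otherwise leave the rate untouched.
    return gamma if epoch > 0 and epoch % step_size == 0 else 1.0


def chained_schedule(base_lr, num_epochs):
    # Each scheduler only rescales whatever rate is currently set,
    # so the two can be applied one after the other.
    lr, history = base_lr, []
    for epoch in range(num_epochs):
        lr *= exponential_factor(epoch) * step_factor(epoch)
        history.append(lr)
    return history


def closed_form(base_lr, epoch, gamma_exp=0.9, step_size=5, gamma_step=0.1):
    # Closed form: compute the rate directly from the epoch number.
    return base_lr * gamma_exp ** epoch * gamma_step ** (epoch // step_size)


lrs = chained_schedule(0.1, 10)
```

Under these assumptions the chained values agree with the closed form up to floating-point rounding, for example roughly 0.0059049 at epoch 5.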

# pytorch#20527

### Before

The user calls the scheduler's `step` with an explicit epoch that stays constant over several iterations, either across loops or within the same loop.
```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters())
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

# Scheduler with sometimes-constant epoch number
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:
  lr_scheduler.step(epoch)
  print(optimizer.param_groups[0]['lr'])
```

### After

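If the user wants to step only when the epoch number changes, they track the last epoch themselves and call `step()` without an argument: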
```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3,3,3)
optimizer = optim.Adam(conv.parameters())
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

last_epoch = -1
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:

  # Check if epoch number has changed manually
  if epoch-last_epoch > 0:
    lr_scheduler.step()
  last_epoch = epoch

  print(epoch, lr_scheduler.get_computed_values())
```

# pytorch#22107

### Before

```
import torch
from torchvision.models import resnet18
net = resnet18()

optimizer = torch.optim.SGD(net.parameters(), 0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
  # Scheduler computes and returns new learning rate, leading to unexpected behavior
  print(i, scheduler.get_lr())
  scheduler.step()
```

### After

```
import torch
from torchvision.models import resnet18

net = resnet18()
optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
    # Returns last computed learning rate by scheduler
    print(i, lr_scheduler.get_computed_values())
    lr_scheduler.step()
```

# ghstack

This contains the changes from pytorch#24352. Opening again since they were reverted.

This reverts commit 1c477b7.

Test Plan: Imported from OSS

Differential Revision: D17460427

Pulled By: vincentqb

fbshipit-source-id: 8c10f4e7246d6756ac91df734e8bed65bdef63c9
facebook-github-bot pushed a commit that referenced this pull request Dec 23, 2019
Summary:
Fixes #29697.

Raise warning for schedulers following chainable schedulers in #26423. See explanation for
* [new warning when load/save](#29697 (comment))
* [change from deprecation to user warning](#29697 (comment)).

gchanan -- This should go in the upcoming release following #26423.
Pull Request resolved: #31125

Differential Revision: D19143740

Pulled By: vincentqb

fbshipit-source-id: 35b55fe6c5b39ca5a68b1a6e19f14eb95b9a784e
@vincentqb
Contributor Author

Adding "When loading the state of a scheduler, the optimizer also needs to be restored." to BC-breaking, as discussed in #29697.

wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
@vincentqb vincentqb mentioned this pull request Jul 22, 2020
@franchesoni

Any clue on which schedulers can be concatenated with ReduceLROnPlateau?

If I use CosineAnnealingLR instead of CyclicLR then the code works, but there is the following missing functionality. How could I change the max learning rate in CyclicLR or in CosineAnnealingLR with the same criteria as implemented by ReduceLROnPlateau?

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)

# scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100)
scheduler2 = torch.optim.lr_scheduler.CyclicLR(optimizer, 0.1, 1, step_size_up=100)

values1, values2 = [], []
for epoch in range(1000):
    values1.append(scheduler2.get_last_lr()[0])
    optimizer.step()
    scheduler1.step(epoch)
    scheduler2.step()

import matplotlib.pyplot as plt
plt.plot(values1, label='values 1')
plt.legend()
plt.show()
```

facebook-github-bot pushed a commit that referenced this pull request Aug 26, 2021
Summary:
In this PR we introduce `ChainedScheduler`, which was initially proposed in the discussion #26423 (comment).

The idea is to provide a user-friendly way to chain schedulers, especially when many of them are involved and we want a clean, easy-to-read interface. This method will be even more crucial once CompositeSchedulers and schedulers for different types of parameters are involved.

The immediate application of `ChainedScheduler` is expected to be in the TorchVision library, to combine WarmUpLR and MultiStepLR https://github.com/pytorch/vision/blob/master/references/video_classification/scheduler.py#L5 . The method should also apply to many other use cases.

### Example
The usage is as simple as below:

```python
sched=ChainedScheduler([ExponentialLR(self.opt, gamma=0.9),
                        WarmUpLR(self.opt, warmup_factor=0.2, warmup_iters=4, warmup_method="constant"),
                        StepLR(self.opt, gamma=0.1, step_size=3)])
```

Then calling
```python
sched.step()
```
triggers the `step` function of all three schedulers consecutively.
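
The chaining idea itself can be sketched without PyTorch. The classes below (`ToyOptimizer`, `ToyMultiplicative`, `ToyChained`) are hypothetical stand-ins written for this illustration and are not the implementation added in this PR; they assume every scheduler rescales one shared learning rate, in the spirit of the chainable form.

```python
class ToyOptimizer:
    """Holds a single learning rate shared by all toy schedulers."""

    def __init__(self, lr):
        self.lr = lr


class ToyMultiplicative:
    """Multiplies the shared rate by factor(epoch) on every step."""

    def __init__(self, optimizer, factor):
        self.optimizer = optimizer
        self.factor = factor
        self.epoch = 0

    def step(self):
        self.epoch += 1
        self.optimizer.lr *= self.factor(self.epoch)


class ToyChained:
    """Calls step() on each wrapped scheduler, in the order given."""

    def __init__(self, schedulers):
        self.schedulers = list(schedulers)

    def step(self):
        for scheduler in self.schedulers:
            scheduler.step()


opt = ToyOptimizer(lr=1.0)
sched = ToyChained([
    ToyMultiplicative(opt, lambda epoch: 0.5),                             # halve every step
    ToyMultiplicative(opt, lambda epoch: 0.1 if epoch % 3 == 0 else 1.0),  # drop every 3rd step
])

history = []
for _ in range(4):
    sched.step()
    history.append(opt.lr)
```

A single `sched.step()` call replaces one `step()` call per scheduler in the training loop, which is the readability gain described above.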

Partially resolves pytorch/vision#4281

Pull Request resolved: #63491

Reviewed By: datumbox, mruberry

Differential Revision: D30576180

Pulled By: iramazanli

fbshipit-source-id: b43f0749f55faab25079641b7d91c21a891a87e4
Tony-Y added a commit to Tony-Y/pytorch_warmup that referenced this pull request Apr 7, 2022
"chaining" is supported for PyTorch 1.4 or above: 
pytorch/pytorch#26423

Labels

Merged module: bc-breaking Related to a BC-breaking change module: optimizer Related to torch.optim
