
Conversation

@kashif (Contributor) commented Jan 1, 2018

I have added the adamw and sgdw fixes as flags on the appropriate optimizers, rather than as their own optimizers, for issue #3790.

Instead of defining new optimizers as in PR #3740, I am fixing the weight decay in the appropriate optimizers.

My only issue is that now the comparison tests between the older legacy optimizer and this one fail.

@HaraldKorneliussen commented Jan 4, 2018

It looks like it's just the flake test complaining about extra whitespace.
And the way to fix the failing "equivalent to lua torch" test, I guess, is to fix it in lua torch first? :)

Or just remove the weight decay test in Adam from https://github.com/pytorch/pytorch/blob/master/test/optim/tests.json if the devs are OK with giving up lua torch compatibility on that point.

@kashif (Contributor, Author) commented Jan 4, 2018

@HaraldKorneliussen thanks! Yes, that is the question: should I also change the legacy code or remove the test?

@kashif (Contributor, Author) left a comment

ok I removed the unnecessary clones

@kashif changed the title from "added adamw and sgdw" to "fixed weight decay in optimizers (added adamw and sgdw among others)" Jan 6, 2018
@lucasb-eyer (Contributor)

I would like to see this merged, as I think that paper's result was an important one. I also prefer "fixing" the existing optimizers as in this PR over the huge copy-paste in the other PR.

As it is currently implemented, I have never encountered a case where weight decay has helped me in any of the non-SGD/Momentum optimizers, nor has anyone I have talked to.

That being said, we might want to keep the original behaviour by default, at the very least for SGD/Momentum. Imagine all the training scripts where the performance changes (for the better or the worse) after an update, and how near-impossible it would be to find out that this PR is why! One would then have to read the paper and fix the factors accordingly. So I'd rather see a new flag for it, something like separate_weight_decay=False by default.

Or maybe print a (one-time per session) warning whenever the optimizer is created with non-zero weight_decay for the next handful of versions.

@apaszke (Contributor) commented Jan 20, 2018

I'm a bit concerned about this PR, because it completely changes how optimizers work, and may hurt reproducibility. Any opinions @soumith?

@kashif (Contributor, Author) commented Jan 20, 2018

I agree @apaszke, it will cause issues with reproducibility. What is the PyTorch design principle? To stay backwards compatible? I could add a flag that switches from the default to this behaviour, as @lucasb-eyer suggests...

@lucasb-eyer (Contributor)

@apaszke you are right in that it would be different from most (all?) popular frameworks. I haven't seen enough confirmation of its effectiveness in the wild (have you, @kashif?), so maybe the conservative separate_weight_decay=False is the better thing to do here. But either way, I'd really like to have this in PyTorch.

Or be the bold Apple of DL and move to Doing It Right(TM), just with a big fat warning as I mentioned here 😄

@kashif (Contributor, Author) commented Jan 21, 2018

@lucasb-eyer as far as I am aware, only Chainer's HEAD currently has this implemented, for the Adam optimizer by default. Their other optimizers do not have weight decay, so they have not added it to them as I did here...

In terms of experiments, the ones we ran for chainer are here:

@jpeg729 commented Mar 24, 2018

This implementation doesn't follow the paper Fixing Weight Decay Regularization in Adam. This pull request applies weight decay using

p -= weight_decay * p

whereas the paper proposes applying weight decay using

p -= lr * weight_decay * p

cf. Algorithm 2 on page 3.
Is there a strong reason for not following the paper's proposal to the letter?

@kashif the problem with weight decay was that it was applied before the smoothing stuff, not that weight_decay was getting mixed up with the learning rate.
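For readers following along, here is a minimal sketch (illustrative only, not this PR's code) of the two update forms being discussed, written for a plain SGD-style step:

```python
import torch

# p and d_p are assumed to be parameter and gradient tensors; lr and
# weight_decay are plain floats. All names are illustrative.
def l2_coupled_step(p, d_p, lr, weight_decay):
    # "L2" style: the decay term is folded into the gradient, so it is scaled
    # by lr (and, in Adam-like methods, by the adaptive denominator as well).
    d_p = d_p + weight_decay * p
    return p - lr * d_p

def decoupled_step(p, d_p, lr, weight_decay):
    # Decoupled style: the decay is applied to the weights directly,
    # independently of the gradient-based step.
    return p - weight_decay * p - lr * d_p

p, d_p = torch.randn(3), torch.randn(3)
print(l2_coupled_step(p, d_p, lr=0.1, weight_decay=1e-2))
print(decoupled_step(p, d_p, lr=0.1, weight_decay=1e-2))
```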

@kashif (Contributor, Author) commented Mar 26, 2018

@jpeg729 the PyTorch API makes it seem like I am not multiplying by the learning rate, especially the in-place APIs, but if you look at the API call below that, you will see that the learning rate gets multiplied etc. as in the paper... Thanks and hope it helps!

@jpeg729 commented Mar 26, 2018

@kashif I beg to differ.
For example, here is your modified code for SGD; I have added comments according to my understanding of the PyTorch in-place APIs.

            if weight_decay != 0:
                p.data.add_(-weight_decay, p.data) # p.data = p.data + weight_decay * p.data

            p.data.add_(-group['lr'], d_p) # p.data = p.data + lr * d_p

Only d_p gets multiplied by the learning rate, not the weight_decay update.
As far as I can see, none of the optimizers multiply the weight_decay by the learning rate.



@kashif (Contributor, Author) commented Mar 26, 2018

@jpeg729 ok so I fixed the negatives:

if weight_decay != 0:
    p.data.add_(-weight_decay, p.data) # p.data = p.data - weight_decay * p.data

p.data.add_(-group['lr'], d_p) # p.data = p.data - lr * d_p  (combined with the line above: p.data - weight_decay * p.data - lr * d_p)

which is essentially line 9 in Algorithm 1 of the paper. Could it be that you are thinking of the eta_t in the paper as the learning rate? eta_t in the paper is the learning rate schedule multiplier, which here is just 1.

@jpeg729 commented Mar 26, 2018

@kashif You are right, eta is the scheduler.
But, I think that there is a problem here that requires some thought.
In the paper at time t the effective learning rate is eta * alpha, and the effective weight_decay is eta * w.

So if someone uses a learning rate scheduler with these optimisers, then they have to manually apply a similar scheduler to the weight decay in order to replicate the methods in the paper. Otherwise the grad updates would decay but not the weight decay updates, and eventually the grad updates would be swamped by the weight decay updates.

One fix would be to add param_group['eta'] to each optimiser and change the schedulers to modify param_group['eta'] instead of param_group['lr'].
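A rough sketch of the manual workaround this implies, assuming an optimizer whose weight_decay is applied in the decoupled way proposed in this PR (the stock SGD folds it into the gradient instead); the model, schedule and hyperparameters below are illustrative:

```python
import torch

model = torch.nn.Linear(10, 1)
base_lr, base_wd = 0.1, 1e-4
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, weight_decay=base_wd)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(90):
    # ... run one epoch of training here ...
    scheduler.step()
    # Keep the weight decay on the same schedule as the learning rate,
    # i.e. scale both by the same eta_t = lr_t / base_lr as in the paper.
    for group in optimizer.param_groups:
        eta_t = group['lr'] / base_lr
        group['weight_decay'] = base_wd * eta_t
```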

@kashif (Contributor, Author) commented Mar 26, 2018

@jpeg729 yes, you are right, at the moment the scheduler will only change the learning rate and not the weight decay parameter... let me see how hard it will be to add that... but as you can read from the comments above, I don't think these changes are going to be accepted, so I'm not sure how much it's worth investing time in this... I might close this PR soon.

@lucasb-eyer (Contributor)

@kashif if this PR does not get merged (likely), do you plan to create a separate repo that tracks optim but with these changes? Something like optimw, or will you just let it die and not use these changes anymore yourself?

@kashif changed the title from "fixed weight decay in optimizers (added adamw and sgdw among others)" to "Decoupled Weight Decay Regularization in optimizers (added adamw and sgdw among others)" Jan 17, 2019
weight_decay (float, optional): weight decay factor (default: 0)
l2_reg (boolean, optional): whether to use the original L2
    weight regularization or the weight decay method from the paper
    `Decoupled Weight Decay Regularization` (default: True)
A Contributor left a review comment on this docstring:

It's not obviously clear from the text which option is which. Maybe state it explicitly, like:

            l2_reg (boolean, optional): whether to use the original L2
            weight regularization (l2_reg=True) or the weight decay method from the paper
            `Decoupled Weight Decay Regularization` (l2_reg=False) (default: True)

@bhack (Contributor) commented Mar 30, 2019

/cc @loshchil

@ezyang removed the "ready for review" label Apr 3, 2019
@loshchil

> @jpeg729 As a workaround using the current code here, will it correctly apply scheduling on the weight decay if I set
>
> weight_decay = actual_weight_decay / initial_lr
>
> and correspondingly, in the optimizer,
>
> p.data.add_(-group['weight_decay'] * group['lr'], p.data)
>
> such that the initial learning rate gets cancelled out and we are left with eta?

I didn't try it, but I think that you are right @hkchengrex: this should work and implement eta as described in the paper. Otherwise, the current code has the issue described above by @jpeg729.
At first, it might seem unusual that weight decay is scheduled. However, a closer look at the original SGD with L2-based weight decay will reveal that weight decay is scheduled there as well, because the weight decay part is first combined with the loss-based gradient and then multiplied by the learning rate, which is usually scheduled.
Please note that the optimal value of weight decay in SGDW is likely to differ by a factor of lr from the one of SGD with L2; see Section 2 of https://arxiv.org/pdf/1711.05101.pdf.
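A small numeric check of this workaround (values are illustrative): rescaling weight_decay by 1/initial_lr and multiplying it by the scheduled lr inside the optimizer leaves exactly actual_weight_decay * eta_t per step.

```python
initial_lr = 0.1
actual_weight_decay = 1e-4

# Store the rescaled value in the param group ...
weight_decay = actual_weight_decay / initial_lr

# ... and apply p.data.add_(-weight_decay * lr_t, p.data) inside the optimizer.
# With a scheduled learning rate lr_t = initial_lr * eta_t:
for eta_t in (1.0, 0.5, 0.1):
    lr_t = initial_lr * eta_t
    effective_decay = weight_decay * lr_t
    # initial_lr cancels out, leaving actual_weight_decay * eta_t as in the paper.
    assert abs(effective_decay - actual_weight_decay * eta_t) < 1e-12
```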

@soumith (Contributor) commented Jun 20, 2019

cc: @vincentqb can you review this and get it to completion if it makes sense. Use the guidelines that I shared with you separately.

@kashif I want to apologize for never getting a review to completion, but we now have Vincent, who has a lot of bandwidth and is a math and optimization expert. He will help with the review.

@kashif (Contributor, Author) commented Jun 20, 2019

thanks @soumith I understand the delay. I will try to work with @vincentqb to get it ready for another review!

@ezyang removed their request for review June 20, 2019 14:05
facebook-github-bot pushed a commit that referenced this pull request Jul 2, 2019
Summary:
# What is this?
This is an implementation of the AdamW optimizer as implemented in [the fastai library](https://github.com/fastai/fastai/blob/803894051bef32304ceea0c8ea5e04db64ff26b8/fastai/callback.py) and as initially introduced in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). It decouples the weight decay regularization step from the optimization step during training.

There have already been several abortive attempts to push this into pytorch in some form or fashion: #17468, #10866, #3740, #4429. Hopefully this one goes through.
# Why is this important?
Via a simple reparameterization, it can be shown that L2 regularization has a weight decay effect in the case of SGD optimization. Because of this, L2 regularization became synonymous with the concept of weight decay. However, it can be shown that the equivalence of L2 regularization and weight decay breaks down for more complex adaptive optimization schemes. It was shown in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) that this is the reason why models trained with SGD achieve better generalization than those trained with Adam. Weight decay is a very effective regularizer. L2 regularization, in and of itself, is much less effective. By explicitly decaying the weights, we can achieve state-of-the-art results while also taking advantage of the quick convergence properties that adaptive optimization schemes have.
# How was this tested?
There were test cases added to `test_optim.py` and I also ran a [little experiment](https://gist.github.com/mjacar/0c9809b96513daff84fe3d9938f08638) to validate that this implementation is equivalent to the fastai implementation.
Pull Request resolved: #21250

Differential Revision: D16060339

Pulled By: vincentqb

fbshipit-source-id: ded7cc9cfd3fde81f655b9ffb3e3d6b3543a4709
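For reference, the optimizer merged via #21250 is exposed as torch.optim.AdamW; a minimal usage sketch (the model, data, and hyperparameters below are illustrative):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2)
# weight_decay here is the decoupled decay coefficient: it is applied to the
# parameters directly rather than being added to the gradient as an L2 term.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

inputs, targets = torch.randn(4, 10), torch.randn(4, 2)
optimizer.zero_grad()
loss = F.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
```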
@vincentqb mentioned this pull request Jul 2, 2019
xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
@kashif (Contributor, Author) commented Jul 6, 2019

@soumith @vincentqb I see you have merged versions of this optimizer without reviewing my PR, which has been pending for more than a year, despite your request 17 days ago in the comment above. I find that disrespectful. Even if I understand the limited bandwidth you have, to then go ahead and accept copies of this work is unfair.

@kashif closed this Jul 6, 2019
@loshchil commented Jul 6, 2019

@soumith @vincentqb
It seems that the new code does not take into account the implementation issue discussed in the current PR by @kashif (e.g., see my comment above and also here). It might be necessary to make users aware that the implemented algorithm only partially follows the paper: it follows the main contribution of the paper, namely that the weight decay update is taken out of the adaptive gradient update, but the 'lr' and 'weight_decay' hyperparameters are still coupled, in contrast to the paper where a scheduler is used.

I would like to thank you all and especially @kashif for investing your time to disseminate this work.

@soumith (Contributor) commented Jul 9, 2019

@kashif it was disrespectful, I fully agree -- sorry about that. It's also not a situation we wanted to be in -- where multiple PRs were pending for the same algorithm, and we had no maintainers for optim for like forever. If I could undo it one way or the other I would. I'm banking on @vincentqb being a maintainer who will over time do better (he is still ramping up, and I'm sure he doesn't have full context yet).

@vincentqb can you look at what @loshchil mentioned.

@kashif (Contributor, Author) commented Jul 9, 2019

@soumith thank you for your kind message. I accept your apology and apologise for my outburst; I too will try to be more understanding. Sleeping over it, I realised there was something else bothering me that I needed to work through. Anyway, let me know if I can help in any way.

@vincentqb (Contributor) commented Jul 9, 2019

@kashif Thank you for your understanding. I appreciate your contributions -- and everyone else's in this pull request -- and I also want to apologize. I was in the process of going through this pull request.

The two key contributions in AdamW and SGDW seem to be:

  1. Weight decay done through the parameter theta update directly (instead of through the parameter gradient g: line 6 vs 9 in algorithm 1, line 6 vs 12 in algorithm 2),
  2. The definition of a new ScheduleMultiplier eta(t) replacing LRScheduler lr(t). In that case, the implied relation between ScheduleMultiplier and LRScheduler is lr(t) = lambda*alpha*eta(t), with lambda and alpha two constants.

Based on the discussions above, my understanding is also that we could update lambda and alpha independently instead of eta, thus effectively having two separate schedulers, one for the momentum update and one for the parameter update. This would be why 2 above is referred to as "decoupled weight decay". Is this what you mean @loshchil?

If that is the case, my initial plan was the following: given that we do not yet have a separate weight scheduler concept (#22343) for the momentum update, bring in AdamW and SGDW with "coupled" learning rate and weight decay (and thus using the same scheduler). However, based on the discussion above, the names AdamW and SGDW imply "decoupling" to the user, and not having the decoupling would bring confusion.

Thoughts?
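A rough sketch (not an existing PyTorch API) of what driving the two hyperparameters with independent, hand-rolled schedules could look like; note that, as @loshchil points out above, the merged code still couples lr and weight_decay internally, so a truly decoupled schedule would need something like the separate scheduler concept from #22343:

```python
import torch

model = torch.nn.Linear(10, 1)
base_lr, base_wd = 1e-3, 1e-2
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=base_wd)

def lr_multiplier(epoch):   # schedule for the gradient step
    return 0.95 ** epoch

def wd_multiplier(epoch):   # independent schedule for the decay step
    return 1.0              # e.g. keep the decay constant

for epoch in range(10):
    for group in optimizer.param_groups:
        group['lr'] = base_lr * lr_multiplier(epoch)
        group['weight_decay'] = base_wd * wd_multiplier(epoch)
    # ... run one epoch of training here ...
```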

@vincentqb added the "module: optimizer" and "triaged" labels Feb 6, 2020