Description
🐛 Describe the bug
According to the documentation, decay is a number in the [0, 1] range:
"Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999."
Neither the SWA paper (e.g., https://arxiv.org/pdf/2310.04415) nor other related papers (e.g., https://arxiv.org/abs/1803.05407) consider negative values, and such values make little sense. It would be beneficial to add a check for such invalid values so that unexpected behavior is not possible. I will open a PR that links to this issue.
Examples of unchecked decay values are in https://github.com/pytorch/pytorch/blob/main/torch/optim/swa_utils.py.
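To illustrate the kind of check proposed here, below is a minimal pure-Python sketch. The helper names `validate_decay` and `ema_update` are hypothetical and are not part of `torch.optim.swa_utils`. The scalar update assumes the usual EMA form, `decay * averaged + (1 - decay) * current`, and the actual PR may implement the check differently.

```python
def validate_decay(decay: float) -> float:
    """Hypothetical helper: reject decay values outside the documented [0, 1] range."""
    # The chained comparison is also False for NaN, so NaN is rejected as well.
    if not 0.0 <= decay <= 1.0:
        raise ValueError(f"Invalid decay value {decay}; expected a value in [0, 1]")
    return decay


def ema_update(averaged: float, current: float, decay: float) -> float:
    """Scalar EMA step (assumed form): decay * averaged + (1 - decay) * current."""
    return decay * averaged + (1.0 - decay) * current


if __name__ == "__main__":
    # A valid decay keeps the new average between the old average and the current value.
    print(ema_update(0.0, 1.0, validate_decay(0.999)))

    # Without a check, a negative decay overshoots past the current value.
    print(ema_update(0.0, 1.0, -0.5))

    # With the check, the invalid value is rejected up front.
    try:
        validate_decay(-0.5)
    except ValueError as err:
        print(err)
```

With a decay in [0, 1] the update is a convex combination of the two inputs. Outside that range it extrapolates instead, which is the unexpected behavior the check is meant to prevent.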
Versions
This issue relates to the formulation of the SWA algorithm in the current master branch (Aug 17th, 8pm CEST).