Skip to content

Commit ce80523

Browse files
committed
Fix the chainability at epoch zero for some schedulers
1 parent 4a390a5 commit ce80523

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

test/test_optim.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,8 @@ def test_adam(self):
440440
)
441441
self._test_basic_cases(
442442
lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
443-
[lambda opt: ExponentialLR(opt, gamma=0.9),
444-
lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
443+
[lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant"),
444+
lambda opt: ExponentialLR(opt, gamma=0.9)]
445445
)
446446
self._test_basic_cases(
447447
lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
@@ -1294,8 +1294,8 @@ def test_compound_exp_and_linear_warmup_lr(self):
12941294
for i in range(iters):
12951295
single_targets[i] *= factor + i / iters * (1 - factor)
12961296
targets = [single_targets, [x * epochs for x in single_targets]]
1297-
schedulers[0] = ExponentialLR(self.opt, gamma=0.9)
1298-
schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
1297+
schedulers[0] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
1298+
schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
12991299
self._test(schedulers, targets, epochs)
13001300

13011301
def test_compound_step_and_constant_warmup(self):
@@ -1361,8 +1361,8 @@ def test_compound_cosanneal_and_linear_warmup_lr(self):
13611361
for i in range(iters):
13621362
single_targets[i] *= factor + i / iters * (1 - factor)
13631363
targets = [single_targets, [x * epochs for x in single_targets]]
1364-
schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
1365-
schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
1364+
schedulers[0] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
1365+
schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
13661366
self._test(schedulers, targets, epochs)
13671367

13681368
def test_compound_cosanneal_and_exp_lr(self):

torch/optim/lr_scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def get_lr(self):
328328
return [group['lr'] * lmbda(self.last_epoch)
329329
for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
330330
else:
331-
return list(self.base_lrs)
331+
return [group['lr'] for group in self.optimizer.param_groups]
332332

333333

334334
class StepLR(_LRScheduler):
@@ -526,7 +526,7 @@ def get_lr(self):
526526
"please use `get_last_lr()`.", UserWarning)
527527

528528
if self.last_epoch == 0:
529-
return self.base_lrs
529+
return [group['lr'] for group in self.optimizer.param_groups]
530530
return [group['lr'] * self.gamma
531531
for group in self.optimizer.param_groups]
532532

@@ -586,7 +586,7 @@ def get_lr(self):
586586
"please use `get_last_lr()`.", UserWarning)
587587

588588
if self.last_epoch == 0:
589-
return self.base_lrs
589+
return [group['lr'] for group in self.optimizer.param_groups]
590590
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
591591
return [group['lr'] + (base_lr - self.eta_min) *
592592
(1 - math.cos(math.pi / self.T_max)) / 2

0 commit comments

Comments (0)