Skip to content

Commit 211ad23

Browse files
iramazanlifacebook-github-bot
authored andcommitted
To add state_dict and load_state_dict to SequentialLR (#65035)
Summary: To add state_dict() and load_state_dict() methods to SequentialLR Pull Request resolved: #65035 Reviewed By: prabhat00155, nateanl Differential Revision: D30958204 Pulled By: datumbox fbshipit-source-id: 65114e1b07146526ae2680233f5cd42b2534d67a
1 parent 8a652e0 commit 211ad23

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

torch/optim/lr_scheduler.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,37 @@ def step(self):
632632
else:
633633
self._schedulers[idx].step()
634634

635+
def state_dict(self):
636+
"""Returns the state of the scheduler as a :class:`dict`.
637+
638+
It contains an entry for every variable in self.__dict__ which
639+
is not the optimizer.
640+
The wrapped scheduler states will also be saved.
641+
"""
642+
state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
643+
state_dict['_schedulers'] = [None] * len(self._schedulers)
644+
645+
for idx, s in enumerate(self._schedulers):
646+
state_dict['_schedulers'][idx] = s.state_dict()
647+
648+
return state_dict
649+
650+
def load_state_dict(self, state_dict):
651+
"""Loads the schedulers state.
652+
653+
Args:
654+
state_dict (dict): scheduler state. Should be an object returned
655+
from a call to :meth:`state_dict`.
656+
"""
657+
_schedulers = state_dict.pop('_schedulers')
658+
self.__dict__.update(state_dict)
659+
# Restore state_dict keys in order to prevent side effects
660+
# https://github.com/pytorch/pytorch/issues/32756
661+
state_dict['_schedulers'] = _schedulers
662+
663+
for idx, s in enumerate(_schedulers):
664+
self._schedulers[idx].load_state_dict(s)
665+
635666

636667
class CosineAnnealingLR(_LRScheduler):
637668
r"""Set the learning rate of each parameter group using a cosine annealing

0 commit comments

Comments
 (0)