
Commit 2108097

Yanghan Wang authored and facebook-github-bot committed
update to use torch.optim.lr_scheduler.LRScheduler
Summary: Pull Request resolved: #4709

pytorch/pytorch#88503 introduces the public `LRScheduler` class; as a result, `isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler)` no longer works, because of https://github.com/pytorch/pytorch/blob/1ea11ecb2eea99eb552603b7cf5fbfc59659832d/torch/optim/lr_scheduler.py#L166-L169. Making this backward compatible for torch versions <= 1.13 is a bit tricky. V1 of this diff uses a try/except block to import `LRScheduler` (falling back to `_LRScheduler`) and makes it available in `detectron2.solver`, so the whole of D2 then uses this one version of `LRScheduler`. There are two drawbacks, though:

- It adds a little mental burden to figure out what D2's `LRScheduler` is; previously it was clear that `LRScheduler`/`_LRScheduler` came from `torch`.
- It has a name collision with `hooks.LRScheduler`; e.g., in `hooks.py` the import has to be `LRScheduler as _LRScheduler`.

But I couldn't find a better solution; maybe use a try/except block in every file?

Reviewed By: sstsai-adl

Differential Revision: D42111273

fbshipit-source-id: 0269127de1ba3ef90225c5dfe085bb209f6cf4d1
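For context, the compatibility shim this diff adds to `detectron2/solver/lr_scheduler.py` can be exercised as below. The try/except import is the one the diff adds (comments added here); the demo lines after it are illustrative only, not part of the diff, and assume a stock `StepLR`, which subclasses the public `LRScheduler` on torch >= 2.0 and `_LRScheduler` on older releases:

import torch

try:
    # torch >= 2.0 (pytorch/pytorch#88503) exposes the public base class.
    from torch.optim.lr_scheduler import LRScheduler
except ImportError:
    # torch <= 1.13 only has the private base class; alias it.
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

# Illustrative check: a stock scheduler passes the isinstance test against
# the alias on both old and new torch versions.
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)
assert isinstance(sched, LRScheduler)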
1 parent cd79703 · commit 2108097

File tree

6 files changed, +25 -16 lines

detectron2/engine/hooks.py

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,7 @@
 import detectron2.utils.comm as comm
 from detectron2.evaluation.testing import flatten_results_dict
 from detectron2.solver import LRMultiplier
+from detectron2.solver import LRScheduler as _LRScheduler
 from detectron2.utils.events import EventStorage, EventWriter
 from detectron2.utils.file_io import PathManager

@@ -362,12 +363,12 @@ def scheduler(self):
         return self._scheduler or self.trainer.scheduler

     def state_dict(self):
-        if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler):
+        if isinstance(self.scheduler, _LRScheduler):
             return self.scheduler.state_dict()
         return {}

     def load_state_dict(self, state_dict):
-        if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler):
+        if isinstance(self.scheduler, _LRScheduler):
             logger = logging.getLogger(__name__)
             logger.info("Loading scheduler from state_dict ...")
             self.scheduler.load_state_dict(state_dict)
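A hedged usage sketch of the hook this file defines, showing why the isinstance change matters for checkpointing. The `optimizer`/`scheduler` keyword arguments of `hooks.LRScheduler` and the `WarmupMultiStepLR` call below reflect detectron2 at this revision and may differ in other versions:

import torch
from detectron2.engine.hooks import LRScheduler as LRSchedulerHook
from detectron2.solver import WarmupMultiStepLR

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = WarmupMultiStepLR(opt, milestones=[30, 60])
hook = LRSchedulerHook(optimizer=opt, scheduler=sched)

# With the aliased base class, the isinstance checks above succeed on both
# old and new torch, so scheduler state is actually saved and restored.
state = hook.state_dict()
hook.load_state_dict(state)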

detectron2/solver/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params
-from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR, LRMultiplier, WarmupParamScheduler
+from .lr_scheduler import (
+    LRMultiplier,
+    LRScheduler,
+    WarmupCosineLR,
+    WarmupMultiStepLR,
+    WarmupParamScheduler,
+)

 __all__ = [k for k in globals().keys() if not k.startswith("_")]

detectron2/solver/build.py

Lines changed: 2 additions & 4 deletions
@@ -15,7 +15,7 @@
 from detectron2.config import CfgNode
 from detectron2.utils.env import TORCH_VERSION

-from .lr_scheduler import LRMultiplier, WarmupParamScheduler
+from .lr_scheduler import LRMultiplier, LRScheduler, WarmupParamScheduler

 _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
 _GradientClipper = Callable[[_GradientClipperInput], None]

@@ -267,9 +267,7 @@ def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
     return ret


-def build_lr_scheduler(
-    cfg: CfgNode, optimizer: torch.optim.Optimizer
-) -> torch.optim.lr_scheduler._LRScheduler:
+def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
     """
     Build a LR scheduler from config.
     """

detectron2/solver/lr_scheduler.py

Lines changed: 9 additions & 4 deletions
@@ -11,6 +11,11 @@
     ParamScheduler,
 )

+try:
+    from torch.optim.lr_scheduler import LRScheduler
+except ImportError:
+    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
 logger = logging.getLogger(__name__)


@@ -52,7 +57,7 @@ def __init__(
         )


-class LRMultiplier(torch.optim.lr_scheduler._LRScheduler):
+class LRMultiplier(LRScheduler):
     """
     A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the
     learning rate of each param in the optimizer.

@@ -95,7 +100,7 @@ def __init__(
     ):
         """
         Args:
-            optimizer, last_iter: See ``torch.optim.lr_scheduler._LRScheduler``.
+            optimizer, last_iter: See ``torch.optim.lr_scheduler.LRScheduler``.
                 ``last_iter`` is the same as ``last_epoch``.
             multiplier: a fvcore ParamScheduler that defines the multiplier on
                 every LR of the optimizer

@@ -132,7 +137,7 @@ def get_lr(self) -> List[float]:
 # MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it.


-class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
+class WarmupMultiStepLR(LRScheduler):
     def __init__(
         self,
         optimizer: torch.optim.Optimizer,

@@ -171,7 +176,7 @@ def _compute_values(self) -> List[float]:
         return self.get_lr()


-class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler):
+class WarmupCosineLR(LRScheduler):
     def __init__(
         self,
         optimizer: torch.optim.Optimizer,
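A minimal sketch of what rebasing these classes buys, assuming fvcore's `CosineParamScheduler`; `LRMultiplier`'s `(optimizer, multiplier, max_iter)` signature is taken from this file:

import torch
from fvcore.common.param_scheduler import CosineParamScheduler
from detectron2.solver import LRMultiplier, LRScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
# Cosine-decay the base LR from 1x down to 0x over 300 iterations.
sched = LRMultiplier(opt, CosineParamScheduler(1.0, 0.0), max_iter=300)

# The point of the change: D2 schedulers now derive from the aliased
# LRScheduler, so this check holds on torch <= 1.13 and torch >= 2.0 alike.
assert isinstance(sched, LRScheduler)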

projects/DeepLab/deeplab/build_solver.py

Lines changed: 2 additions & 3 deletions
@@ -2,14 +2,13 @@
 import torch

 from detectron2.config import CfgNode
+from detectron2.solver import LRScheduler
 from detectron2.solver import build_lr_scheduler as build_d2_lr_scheduler

 from .lr_scheduler import WarmupPolyLR


-def build_lr_scheduler(
-    cfg: CfgNode, optimizer: torch.optim.Optimizer
-) -> torch.optim.lr_scheduler._LRScheduler:
+def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler:
     """
     Build a LR scheduler from config.
     """

projects/DeepLab/deeplab/lr_scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from typing import List
 import torch

-from detectron2.solver.lr_scheduler import _get_warmup_factor_at_iter
+from detectron2.solver.lr_scheduler import LRScheduler, _get_warmup_factor_at_iter

 # NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
 # only on epoch boundaries. We typically use iteration based schedules instead.

@@ -14,7 +14,7 @@
 # MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it.


-class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
+class WarmupPolyLR(LRScheduler):
     """
     Poly learning rate schedule used to train DeepLab.
     Paper: DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
