This repository was archived by the owner on Mar 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscenario2.py
More file actions
84 lines (60 loc) · 2.86 KB
/
scenario2.py
File metadata and controls
84 lines (60 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Scenario:
2. Updated default Hyper-params (BC-breaking)
Description:
The default hyper-param of a class/model/layer needs to be changed (BC-breaking) because it causes issues to the
users. All other code remains the same. New models need to be constructed with the updated value, old pre-trained
models must continue using the original value.
Example:
https://github.com/pytorch/vision/issues/2599
https://github.com/pytorch/vision/pull/2933
https://github.com/pytorch/vision/pull/2940
The original value of eps of the `FrozenBatchNorm2d` was `0.0` and was causing training stability problems.
We considered it a bug and thus we BC-broke by updating the default value to `1e-5` in the class. Nevertheless
previously trained models had to continue using `0.0`. To resolve it we introduced the method
`torchvision.models.detection._utils.overwrite_eps()` to overwrite the epsilon values of all FrozenBN layers
after they have been created.
Here we propose an alternative mechanism which allows overwriting the default values during object construction
using Context Managers.
"""
from torch import nn, Tensor
from typing import Optional
from dapi_lib.models._api import register, ContextParams, Weights, WeightEntry
# Import a few things that we plan to keep as-is to avoid copy-pasting
from torchvision.ops.misc import FrozenBatchNorm2d
__all__ = ['Dummy']
# The only reason why we inherit instead of making the changes directly to FrozenBatchNorm2d is to avoid copy-pasting
# a lot of code from TorchVision. The changes below should happen on the parent class.
class MyFrozenBN(FrozenBatchNorm2d):
    """FrozenBatchNorm2d whose default ``eps`` can be overridden at construction time.

    An active ``ContextParams`` scope registered for this class replaces the
    caller-supplied (or default) ``eps`` value; otherwise behavior is identical
    to the parent class.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        # Resolve eps through ContextParams so a surrounding context manager can
        # transparently substitute a different value (e.g. the legacy 0.0).
        effective_eps = ContextParams.get(self, 'eps', eps)
        super().__init__(num_features, eps=effective_eps)
class Dummy(nn.Module):
    """Toy conv -> frozen-BN -> ReLU model used to demonstrate the override mechanism."""

    def __init__(self) -> None:
        super().__init__()
        # 1x1 convolution from 32 to 64 channels, then the context-overridable
        # frozen batch norm, then a plain ReLU.
        self.conv = nn.Conv2d(32, 64, 1)
        self.bn = MyFrozenBN(64)
        self.act = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
class DummyWeights(Weights):
    # Single pre-trained weights entry for the Dummy model. Arguments are passed
    # positionally to WeightEntry (url, then three fields whose names are not
    # visible from this file -- presumably transforms/meta/flags; confirm against
    # WeightEntry's definition in dapi_lib.models._api).
    # NOTE(review): the URL is intentionally fake; dummy() skips loading when
    # 'fake' appears in the url.
    DUMMY = WeightEntry(
        'https://fake/models/dummy_weights.pth',
        None,
        {},
        True
    )
@register
def dummy(weights: Optional[DummyWeights] = None) -> Dummy:
    """Build a Dummy model, optionally initialized from pre-trained weights.

    When ``weights`` is provided, a ``ContextParams`` scope forces the legacy
    ``eps=0.0`` on every ``MyFrozenBN`` layer constructed inside it, so old
    checkpoints keep their original numerics; new models get the updated default.
    """
    DummyWeights.check_type(weights)
    wants_pretrained = weights is not None
    # The context is only armed when pre-trained weights were requested.
    with ContextParams(MyFrozenBN, wants_pretrained, eps=0.0):
        model = Dummy()
    # Skip state-dict loading for this demo's placeholder ('fake') URL.
    if wants_pretrained and 'fake' not in weights.url:
        model.load_state_dict(weights.state_dict(progress=False))
    return model
if __name__ == "__main__":
m = dummy(weights=DummyWeights.DUMMY)
assert m.bn.eps == 0.0
m = dummy()
assert m.bn.eps == 1e-5