Skip to content

Commit 99526f0

Browse files
authored
Fix for Issue deepinv#647 – UNet uses zero-padding even when it shouldn't (deepinv#653)
* fix bug unet padding mode * black unet file * change log
1 parent 3569ab3 commit 99526f0

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Changed
1616
Fixed
1717
^^^^^
1818
- Fix memory leak in `deepinv.physics.tomography` when using autograd (:gh:`651` by `Minh Hai Nguyen`_)
19-
19+
- Fix the circular padded UNet (:gh:`653` by `Victor Sechaud`_)
2020

2121
v0.3.2
2222
------

deepinv/models/unet.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,13 @@ def conv_block(ch_in, ch_out):
116116
),
117117
nn.ReLU(inplace=True),
118118
nn.Conv2d(
119-
ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
119+
ch_out,
120+
ch_out,
121+
kernel_size=3,
122+
stride=1,
123+
padding=1,
124+
bias=bias,
125+
padding_mode="circular" if circular_padding else "zeros",
120126
),
121127
(
122128
BFBatchNorm2d(ch_out, use_bias=bias)
@@ -138,7 +144,13 @@ def conv_block(ch_in, ch_out):
138144
),
139145
nn.ReLU(inplace=True),
140146
nn.Conv2d(
141-
ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
147+
ch_out,
148+
ch_out,
149+
kernel_size=3,
150+
stride=1,
151+
padding=1,
152+
bias=bias,
153+
padding_mode="circular" if circular_padding else "zeros",
142154
),
143155
nn.ReLU(inplace=True),
144156
)
@@ -148,7 +160,13 @@ def up_conv(ch_in, ch_out):
148160
return nn.Sequential(
149161
nn.Upsample(scale_factor=2),
150162
nn.Conv2d(
151-
ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
163+
ch_in,
164+
ch_out,
165+
kernel_size=3,
166+
stride=1,
167+
padding=1,
168+
bias=bias,
169+
padding_mode="circular" if circular_padding else "zeros",
152170
),
153171
(
154172
BFBatchNorm2d(ch_out, use_bias=bias)
@@ -161,7 +179,13 @@ def up_conv(ch_in, ch_out):
161179
return nn.Sequential(
162180
nn.Upsample(scale_factor=2),
163181
nn.Conv2d(
164-
ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
182+
ch_in,
183+
ch_out,
184+
kernel_size=3,
185+
stride=1,
186+
padding=1,
187+
bias=bias,
188+
padding_mode="circular" if circular_padding else "zeros",
165189
),
166190
nn.ReLU(inplace=True),
167191
)

0 commit comments

Comments
 (0)