Skip to content

Commit 338ddb5

Browse files
committed
add tests upsampling
1 parent 1a15f57 commit 338ddb5

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

deepinv/physics/blur.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,33 @@ def update_parameters(self, filter=None, factor=None, **kwargs):
203203
class Upsampling(Downsampling):
204204
r"""
205205
Upsampling operator.
206+
207+
This operator performs the operation
208+
209+
.. math::
210+
y = h^T * S^T (x)
211+
212+
where :math:`S^T` is the adjoint of the subsampling operator and :math:`h` is a low-pass filter.
213+
214+
:param torch.Tensor, str, None filter: Upsampling filter. It can be ``'gaussian'``, ``'bilinear'``, ``'bicubic'``
215+
, ``'sinc'`` or a custom ``torch.Tensor`` filter. If ``None``, no filtering is applied.
216+
:param tuple[int] img_size: size of the output image
217+
:param int factor: upsampling factor
218+
:param str padding: options are ``'circular'``, ``'replicate'`` and ``'reflect'``.
219+
:param str device: cpu or cuda
206220
"""
221+
def __init__(self, img_size, filter=None, factor=2, padding="circular", device="cpu", **kwargs):
222+
223+
assert padding != "valid", "Padding 'valid' is not supported for Upsampling operator."
224+
225+
super().__init__(
226+
img_size=img_size,
227+
filter=filter,
228+
factor=factor,
229+
padding=padding,
230+
device=device,
231+
**kwargs,
232+
)
207233

208234
def A(self, x, **kwargs):
209235
return super().A_adjoint(x, **kwargs)

deepinv/tests/test_physics.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
"super_resolution_reflect",
4646
"super_resolution_replicate",
4747
"super_resolution_constant",
48+
"down_resolution_circular",
49+
"down_resolution_reflect",
50+
"down_resolution_replicate",
51+
"down_resolution_constant",
52+
"down_resolution_valid",
4853
"aliased_super_resolution",
4954
"fast_singlepixel",
5055
"fast_singlepixel_cake_cutting",
@@ -375,6 +380,19 @@ def find_operator(name, device, imsize=None, get_physics_param=False, wrapper=No
375380
dtype=dtype,
376381
)
377382
params = ["filter"]
383+
elif name.startswith("down_resolution"):
384+
img_size = (1, 32, 32) if imsize is None else imsize
385+
factor = 2
386+
norm = 1.0 / factor**2
387+
p = dinv.physics.Upsampling(
388+
img_size=(img_size[0], img_size[1]*factor, img_size[2]*factor),
389+
factor=factor,
390+
padding=padding,
391+
device=device,
392+
filter="bilinear",
393+
dtype=dtype,
394+
)
395+
params = ["filter"]
378396
elif name == "complex_compressed_sensing":
379397
img_size = (1, 8, 8) if imsize is None else imsize
380398
m = 50
@@ -601,6 +619,54 @@ def test_operators_adjointness(name, device, rng):
601619

602620
assert error2 < 1e-3
603621

622+
def test_upsampling(device, rng):
623+
r"""
624+
This function tests that the Upsampling and Downsampling operators are effectively adjoint to each other.
625+
626+
Note that the test does not hold when the padding is not 'valid', as the Upsampling operator
627+
does not support 'valid' padding.
628+
"""
629+
630+
list_ops = ["down_resolution_circular",
631+
"down_resolution_reflect",
632+
"down_resolution_replicate",
633+
"down_resolution_constant"]
634+
635+
for kernel in ["bilinear", "bicubic", "sinc", "gaussian"]:
636+
for name in list_ops:
637+
padding = name.split("_")[-1] # get padding type from name
638+
physics, imsize, _, dtype = find_operator(name, device)
639+
physics_adjoint, _, _, dtype = find_operator("super_resolution_"+padding, device, imsize=imsize)
640+
641+
642+
# physics.register_buffer("filter", None)
643+
physics.update_parameters(filter=kernel)
644+
645+
# physics_adjoint.register_buffer("filter", None)
646+
physics_adjoint.update_parameters(filter=kernel)
647+
648+
factor = physics.factor
649+
650+
x = torch.randn((1, imsize[0], imsize[1], imsize[2]), device=device, dtype=dtype, generator=rng)
651+
652+
out = physics(x)
653+
assert out.shape == (1, imsize[0], imsize[1] * factor, imsize[2] * factor)
654+
655+
y = physics(x)
656+
err1 = (physics.A_adjoint(y) - physics_adjoint(y)).flatten().mean().abs()
657+
assert err1 < 1e-6
658+
659+
imsize_new = (*imsize[:1], imsize[1] * factor, imsize[2] * factor)
660+
physics_adjoint, _, _, dtype = find_operator("super_resolution_"+padding, device, imsize=imsize_new) # we need to redefine the adjoint operator with the new image size
661+
662+
# physics_adjoint.register_buffer("filter", None)
663+
physics_adjoint.update_parameters(filter=kernel)
664+
665+
x = torch.randn(imsize_new, device=device, dtype=dtype, generator=rng).unsqueeze(0)
666+
y = physics_adjoint(x)
667+
err2 = (physics.A(y) - physics_adjoint.A_adjoint(y)).flatten().mean().abs()
668+
assert err2 < 1e-6
669+
604670

605671
@pytest.mark.parametrize("name", OPERATORS)
606672
@pytest.mark.parametrize("wrapper", WRAPPERS)

0 commit comments

Comments
 (0)