Skip to content

Commit 0853fe8

Browse files
committed
comments
1 parent f815630 commit 0853fe8

File tree

3 files changed

+50
-51
lines changed

3 files changed

+50
-51
lines changed

deepinv/physics/wrappers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ class PhysicsMultiScaler(Physics):
1717
1818
where :math:`U_{scale}` is the upsampling operator for the given scale and :math:`A_{base}` is the base physics operator.
1919
20-
By default, we assume that the factors for the different scales are [2, 4, 8]. The 0th scale corresponds to no upsampling,
21-
the 1st scale corresponds to upsampling by a factor of 2, the 2nd scale corresponds to upsampling by a factor of 4, and so on.
20+
By default, we assume that the factors for the different scales are [2, 4, 8].
21+
The 1st scale corresponds to upsampling by a factor of 2, the 2nd scale corresponds to upsampling by a factor of 4, and so on.
22+
The 0th scale corresponds to the base physics operator without upsampling.
2223
2324
:param deepinv.physics.Physics physics: base physics operator.
2425
:param tuple img_shape: shape of the input image (C, H, W).
@@ -123,7 +124,7 @@ class PhysicsCropper(LinearPhysics):
123124
The adjoint operator is defined as :math:`\tilde{A}^{\top} = C^{\top} \circ A^{\top}` and :math:`C^{\top}` is a padding operator that pads the input tensor to the original size.
124125
125126
:param deepinv.physics.LinearPhysics physics: base linear physics operator.
126-
:param tuple pad: padding to apply to the input tensor, e.g., (pad_height, pad_width).
127+
:param tuple crop: padding to apply to the input tensor, e.g., (pad_height, pad_width).
127128
"""
128129

129130
def __init__(self, physics, crop, dtype=None):

deepinv/tests/test_physics.py

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -423,11 +423,11 @@ def find_operator(name, device, imsize=None, get_physics_param=False):
423423

424424
# Reshape to [nb_points x 2]
425425
uv = uv.view(-1, 2)
426-
uv = uv.to(device, dtype=dtype)
426+
uv = uv.to(device)
427427

428428
if "weighted" in name:
429429
dataWeight = torch.linspace(
430-
0.01, 0.99, uv.shape[0], device=device, dtype=dtype
430+
0.01, 0.99, uv.shape[0], device=device
431431
) # take a non-trivial weight
432432
else:
433433
dataWeight = torch.tensor(
@@ -616,65 +616,63 @@ def test_operators_adjointness(name, device, rng):
616616
assert error2 < 1e-3
617617

618618

619-
def test_upsampling(device, rng):
619+
LIST_DOWN_OP = [
620+
"down_resolution_circular",
621+
"down_resolution_reflect",
622+
"down_resolution_replicate",
623+
"down_resolution_constant",
624+
]
625+
626+
627+
@pytest.mark.parametrize("name", LIST_DOWN_OP)
628+
@pytest.mark.parametrize("kernel", ["bilinear", "bicubic", "sinc", "gaussian"])
629+
def test_upsampling(device, rng, name, kernel):
620630
r"""
621631
This function tests that the Upsampling and Downsampling operators are effectively adjoint to each other.
622632
623633
Note that the test does not hold when the padding is not 'valid', as the Upsampling operator
624634
does not support 'valid' padding.
625635
"""
636+
padding = name.split("_")[-1] # get padding type from name
637+
physics, imsize, _, dtype = find_operator(name, device)
638+
physics_adjoint, _, _, dtype = find_operator(
639+
"super_resolution_" + padding, device, imsize=imsize
640+
)
626641

627-
list_ops = [
628-
"down_resolution_circular",
629-
"down_resolution_reflect",
630-
"down_resolution_replicate",
631-
"down_resolution_constant",
632-
]
633-
634-
for kernel in ["bilinear", "bicubic", "sinc", "gaussian"]:
635-
for name in list_ops:
636-
padding = name.split("_")[-1] # get padding type from name
637-
physics, imsize, _, dtype = find_operator(name, device)
638-
physics_adjoint, _, _, dtype = find_operator(
639-
"super_resolution_" + padding, device, imsize=imsize
640-
)
641-
642-
# physics.register_buffer("filter", None)
643-
physics.update_parameters(filter=kernel)
642+
# physics.register_buffer("filter", None)
643+
physics.update_parameters(filter=kernel)
644644

645-
# physics_adjoint.register_buffer("filter", None)
646-
physics_adjoint.update_parameters(filter=kernel)
645+
# physics_adjoint.register_buffer("filter", None)
646+
physics_adjoint.update_parameters(filter=kernel)
647647

648-
factor = physics.factor
648+
factor = physics.factor
649649

650-
x = torch.randn(
651-
(1, imsize[0], imsize[1], imsize[2]),
652-
device=device,
653-
dtype=dtype,
654-
generator=rng,
655-
)
650+
x = torch.randn(
651+
(1, imsize[0], imsize[1], imsize[2]),
652+
device=device,
653+
dtype=dtype,
654+
generator=rng,
655+
)
656656

657-
out = physics(x)
658-
assert out.shape == (1, imsize[0], imsize[1] * factor, imsize[2] * factor)
657+
out = physics(x)
658+
assert out.shape == (1, imsize[0], imsize[1] * factor, imsize[2] * factor)
659659

660-
y = physics(x)
661-
err1 = (physics.A_adjoint(y) - physics_adjoint(y)).flatten().mean().abs()
662-
assert err1 < 1e-6
660+
y = physics(x)
661+
err1 = (physics.A_adjoint(y) - physics_adjoint(y)).flatten().mean().abs()
662+
assert err1 < 1e-6
663663

664-
imsize_new = (*imsize[:1], imsize[1] * factor, imsize[2] * factor)
665-
physics_adjoint, _, _, dtype = find_operator(
666-
"super_resolution_" + padding, device, imsize=imsize_new
667-
) # we need to redefine the adjoint operator with the new image size
664+
imsize_new = (*imsize[:1], imsize[1] * factor, imsize[2] * factor)
665+
physics_adjoint, _, _, dtype = find_operator(
666+
"super_resolution_" + padding, device, imsize=imsize_new
667+
) # we need to redefine the adjoint operator with the new image size
668668

669-
# physics_adjoint.register_buffer("filter", None)
670-
physics_adjoint.update_parameters(filter=kernel)
669+
# physics_adjoint.register_buffer("filter", None)
670+
physics_adjoint.update_parameters(filter=kernel)
671671

672-
x = torch.randn(
673-
imsize_new, device=device, dtype=dtype, generator=rng
674-
).unsqueeze(0)
675-
y = physics_adjoint(x)
676-
err2 = (physics.A(y) - physics_adjoint.A_adjoint(y)).flatten().mean().abs()
677-
assert err2 < 1e-6
672+
x = torch.randn(imsize_new, device=device, dtype=dtype, generator=rng).unsqueeze(0)
673+
y = physics_adjoint(x)
674+
err2 = (physics.A(y) - physics_adjoint.A_adjoint(y)).flatten().mean().abs()
675+
assert err2 < 1e-6
678676

679677

680678
@pytest.mark.parametrize("name", OPERATORS)

docs/source/user_guide/physics/physics.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,12 @@ Wrappers are operators that can be used to adapt existing operators to a new pro
123123
* - **Family**
124124
- **Operators**
125125

126-
* - Multicale
126+
* - Multiscale
127127
-
128128
| :class:`deepinv.physics.PhysicsMultiScaler`
129129
| :class:`deepinv.physics.LinearPhysicsMultiScaler`
130130
131-
* - Padding
131+
* - Padding/Cropping
132132
-
133133
| :class:`deepinv.physics.PhysicsCropper`
134134

0 commit comments

Comments
 (0)