Skip to content

Commit 7116f70

Browse files
jscanvictomMoralAndrewwango
authored
Remove ignored arguments (deepinv#640)
* CLN warn user about unused parameters in Physics * CLN move kwargs to physics * TST check that the warning is raised * FIX linting * black * get rid of ignored arguments * get rid of ignored arguments * get rid of ignored arguments * get rid of no-ops --------- Co-authored-by: tommoral <[email protected]> Co-authored-by: Andrewwango <[email protected]>
1 parent 8d08d86 commit 7116f70

File tree

9 files changed

+16
-36
lines changed

9 files changed

+16
-36
lines changed

deepinv/physics/blur.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,7 @@ def __init__(
447447
**kwargs,
448448
):
449449
super().__init__(**kwargs)
450-
self.method = "product_convolution2d"
451-
if self.method == "product_convolution2d":
452-
self.update_parameters(filters, multipliers, padding, **kwargs)
450+
self.update_parameters(filters, multipliers, padding, **kwargs)
453451
self.to(device)
454452

455453
def A(
@@ -468,14 +466,8 @@ def A(
468466
otherwise the blurred output has the same size as the image.
469467
:param str device: cpu or cuda
470468
"""
471-
if self.method == "product_convolution2d":
472-
self.update_parameters(filters, multipliers, padding, **kwargs)
473-
474-
return product_convolution2d(
475-
x, self.multipliers, self.filters, self.padding
476-
)
477-
else:
478-
raise NotImplementedError("Method not implemented in product-convolution")
469+
self.update_parameters(filters, multipliers, padding, **kwargs)
470+
return product_convolution2d(x, self.multipliers, self.filters, self.padding)
479471

480472
def A_adjoint(
481473
self, y: Tensor, filters=None, multipliers=None, padding=None, **kwargs
@@ -493,16 +485,12 @@ def A_adjoint(
493485
otherwise the blurred output has the same size as the image.
494486
:param str device: cpu or cuda
495487
"""
496-
if self.method == "product_convolution2d":
497-
self.update_parameters(
498-
filters=filters, multipliers=multipliers, padding=padding, **kwargs
499-
)
500-
501-
return product_convolution2d_adjoint(
502-
y, self.multipliers, self.filters, self.padding
503-
)
504-
else:
505-
raise NotImplementedError("Method not implemented in product-convolution")
488+
self.update_parameters(
489+
filters=filters, multipliers=multipliers, padding=padding, **kwargs
490+
)
491+
return product_convolution2d_adjoint(
492+
y, self.multipliers, self.filters, self.padding
493+
)
506494

507495
def update_parameters(
508496
self,

deepinv/physics/phase_retrieval.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,6 @@ def __init__(
172172
img_size=img_size,
173173
fast=False,
174174
channelwise=channelwise,
175-
unitary=unitary,
176-
compute_inverse=compute_inverse,
177175
dtype=dtype,
178176
device=device,
179177
rng=self.rng,
@@ -291,7 +289,6 @@ def __init__(
291289
B = StructuredRandom(
292290
img_size=self.img_size,
293291
output_size=self.output_size,
294-
mode=self.mode,
295292
n_layers=self.n_layers,
296293
transform_func=transform_func,
297294
transform_func_inv=transform_func_inv,

deepinv/tests/test_deprecated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_deprecated_physics_image_size():
2727
# CS: img_shape is changed to img_size
2828
with pytest.warns(DeprecationWarning, match="img_shape.*deprecated"):
2929
p = dinv.physics.CompressedSensing(
30-
m=m, img_shape=img_size, device="cpu", compute_inverse=True, rng=rng
30+
m=m, img_shape=img_size, device="cpu", rng=rng
3131
)
3232
assert p.img_size == img_size
3333

deepinv/tests/test_physics.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def find_operator(name, device, get_physics_param=False):
112112
if name == "CS":
113113
m = 30
114114
p = dinv.physics.CompressedSensing(
115-
m=m, img_size=img_size, device=device, compute_inverse=True, rng=rng
115+
m=m, img_size=img_size, device=device, rng=rng
116116
)
117117
norm = (
118118
1 + np.sqrt(np.prod(img_size) / m)
@@ -355,7 +355,6 @@ def find_operator(name, device, get_physics_param=False):
355355
padding=padding,
356356
device=device,
357357
filter="bilinear",
358-
dtype=dtype,
359358
)
360359
params = ["filter"]
361360
elif name == "complex_compressed_sensing":
@@ -366,7 +365,6 @@ def find_operator(name, device, get_physics_param=False):
366365
img_size=img_size,
367366
dtype=torch.cdouble,
368367
device=device,
369-
compute_inverse=True,
370368
rng=rng,
371369
)
372370
dtype = p.dtype
@@ -407,7 +405,6 @@ def find_operator(name, device, get_physics_param=False):
407405
samples_loc=uv.permute((1, 0)),
408406
dataWeight=dataWeight,
409407
real_projection=False,
410-
dtype=torch.float,
411408
device=device,
412409
noise_model=dinv.physics.GaussianNoise(0.0, rng=rng),
413410
)
@@ -942,7 +939,7 @@ def test_noise(device, noise_type):
942939
r"""
943940
Tests noise models.
944941
"""
945-
physics = dinv.physics.DecomposablePhysics(device=device)
942+
physics = dinv.physics.DecomposablePhysics()
946943
physics.noise_model = choose_noise(noise_type, device)
947944
x = torch.ones((1, 3, 2), device=device).unsqueeze(0)
948945

@@ -988,7 +985,6 @@ def test_blur(device):
988985
h = torch.ones((1, 1, 5, 5)) / 25.0
989986

990987
physics_blur = dinv.physics.Blur(
991-
img_size=(1, x.shape[-2], x.shape[-1]),
992988
filter=h,
993989
device=device,
994990
padding="circular",

deepinv/tests/test_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_algo(name_algo, device):
112112
test_sample = torch.ones((1, 3, 64, 64), device=device)
113113

114114
sigma = 1
115-
physics = dinv.physics.Denoising(device=device)
115+
physics = dinv.physics.Denoising()
116116
physics.noise_model = dinv.physics.GaussianNoise(sigma)
117117
y = physics(test_sample)
118118

examples/basics/demo_blur_tour.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@
292292
)
293293
params_pc = pc_generator.step(batch_size)
294294

295-
physics = SpaceVaryingBlur(method="product_convolution2d", **params_pc)
295+
physics = SpaceVaryingBlur(**params_pc)
296296

297297
dirac_comb = torch.zeros(img_size)[None, None]
298298
dirac_comb[0, 0, ::delta, ::delta] = 1

examples/basics/demo_physics_tour.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@
9393
fast=False,
9494
channelwise=True,
9595
img_size=img_size,
96-
compute_inverse=True,
9796
device=device,
9897
)
9998

examples/external-libraries/demo_ri_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def to_logimage(im, rescale=False, dr=5000):
206206
physics = RadioInterferometry(
207207
img_size=image_gdth.shape[-2:],
208208
samples_loc=uv.permute((1, 0)),
209-
real=True,
209+
real_projection=True,
210210
device=device,
211211
)
212212

examples/patch-priors/demo_epll.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
sigma = 0.1
3838
noise_model = GaussianNoise(sigma)
39-
physics = Denoising(device=device, noise_model=noise_model)
39+
physics = Denoising(noise_model=noise_model)
4040
observation = physics(test_img)
4141

4242
# %%

0 commit comments

Comments
 (0)