|
45 | 45 | "super_resolution_reflect", |
46 | 46 | "super_resolution_replicate", |
47 | 47 | "super_resolution_constant", |
| 48 | + "down_resolution_circular", |
| 49 | + "down_resolution_reflect", |
| 50 | + "down_resolution_replicate", |
| 51 | + "down_resolution_constant", |
| 52 | + "down_resolution_valid", |
48 | 53 | "aliased_super_resolution", |
49 | 54 | "fast_singlepixel", |
50 | 55 | "fast_singlepixel_cake_cutting", |
@@ -375,6 +380,19 @@ def find_operator(name, device, imsize=None, get_physics_param=False, wrapper=No |
375 | 380 | dtype=dtype, |
376 | 381 | ) |
377 | 382 | params = ["filter"] |
| 383 | + elif name.startswith("down_resolution"): |
| 384 | + img_size = (1, 32, 32) if imsize is None else imsize |
| 385 | + factor = 2 |
| 386 | + norm = 1.0 / factor**2 |
| 387 | + p = dinv.physics.Upsampling( |
| 388 | + img_size=(img_size[0], img_size[1]*factor, img_size[2]*factor), |
| 389 | + factor=factor, |
| 390 | + padding=padding, |
| 391 | + device=device, |
| 392 | + filter="bilinear", |
| 393 | + dtype=dtype, |
| 394 | + ) |
| 395 | + params = ["filter"] |
378 | 396 | elif name == "complex_compressed_sensing": |
379 | 397 | img_size = (1, 8, 8) if imsize is None else imsize |
380 | 398 | m = 50 |
@@ -601,6 +619,54 @@ def test_operators_adjointness(name, device, rng): |
601 | 619 |
|
602 | 620 | assert error2 < 1e-3 |
603 | 621 |
|
| 622 | +def test_upsampling(device, rng): |
| 623 | + r""" |
| 624 | + This function tests that the Upsampling and Downsampling operators are effectively adjoint to each other. |
| 625 | +
|
| 626 | + Note that the test does not hold when the padding is not 'valid', as the Upsampling operator |
| 627 | + does not support 'valid' padding. |
| 628 | + """ |
| 629 | + |
| 630 | + list_ops = ["down_resolution_circular", |
| 631 | + "down_resolution_reflect", |
| 632 | + "down_resolution_replicate", |
| 633 | + "down_resolution_constant"] |
| 634 | + |
| 635 | + for kernel in ["bilinear", "bicubic", "sinc", "gaussian"]: |
| 636 | + for name in list_ops: |
| 637 | + padding = name.split("_")[-1] # get padding type from name |
| 638 | + physics, imsize, _, dtype = find_operator(name, device) |
| 639 | + physics_adjoint, _, _, dtype = find_operator("super_resolution_"+padding, device, imsize=imsize) |
| 640 | + |
| 641 | + |
| 642 | + # physics.register_buffer("filter", None) |
| 643 | + physics.update_parameters(filter=kernel) |
| 644 | + |
| 645 | + # physics_adjoint.register_buffer("filter", None) |
| 646 | + physics_adjoint.update_parameters(filter=kernel) |
| 647 | + |
| 648 | + factor = physics.factor |
| 649 | + |
| 650 | + x = torch.randn((1, imsize[0], imsize[1], imsize[2]), device=device, dtype=dtype, generator=rng) |
| 651 | + |
| 652 | + out = physics(x) |
| 653 | + assert out.shape == (1, imsize[0], imsize[1] * factor, imsize[2] * factor) |
| 654 | + |
| 655 | + y = physics(x) |
| 656 | + err1 = (physics.A_adjoint(y) - physics_adjoint(y)).flatten().mean().abs() |
| 657 | + assert err1 < 1e-6 |
| 658 | + |
| 659 | + imsize_new = (*imsize[:1], imsize[1] * factor, imsize[2] * factor) |
| 660 | + physics_adjoint, _, _, dtype = find_operator("super_resolution_"+padding, device, imsize=imsize_new) # we need to redefine the adjoint operator with the new image size |
| 661 | + |
| 662 | + # physics_adjoint.register_buffer("filter", None) |
| 663 | + physics_adjoint.update_parameters(filter=kernel) |
| 664 | + |
| 665 | + x = torch.randn(imsize_new, device=device, dtype=dtype, generator=rng).unsqueeze(0) |
| 666 | + y = physics_adjoint(x) |
| 667 | + err2 = (physics.A(y) - physics_adjoint.A_adjoint(y)).flatten().mean().abs() |
| 668 | + assert err2 < 1e-6 |
| 669 | + |
604 | 670 |
|
605 | 671 | @pytest.mark.parametrize("name", OPERATORS) |
606 | 672 | @pytest.mark.parametrize("wrapper", WRAPPERS) |
|
0 commit comments