Skip to content

Commit 3569ab3

Browse files
authored
Fixing tests on GPU (deepinv#569)
* create pr
* fix test_dataloader_formats
* fix test_trainer_physics_generator_params
* fix test_trainer_multidatasets
* fix test_trainer_identity
* black
* fix test_DEQ
* fix test_inpainting_generators
* fix test_metrics
* fix test_wavelet_decomposition
* black
* fix test_noise_model
* black
* fix None-related bug
* making rng states buffers again
* specify dtype as well
* revert changes made in test_wavelet_decomposition
* Revert "revert changes made in test_wavelet_decomposition"

  This reverts commit 9bc2a3b.
1 parent 55bc94f commit 3569ab3

File tree

8 files changed

+92
-21
lines changed

8 files changed

+92
-21
lines changed

deepinv/loss/metric/distortion.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -680,16 +680,28 @@ def metric(self, x_net: Tensor = None, x: Tensor = None, *args, **kwargs) -> Ten
680680
coeffs_deg_y = self._haar_wavelet_decompose(deg_y, n_scales)
681681
if is_color_image:
682682
coefficients_ref_i = torch.abs(
683-
self._convolve2d(ref_i, torch.ones((2, 2)) / 4.0)
683+
self._convolve2d(
684+
ref_i,
685+
torch.ones((2, 2), device=ref_i.device, dtype=ref_i.dtype) / 4.0,
686+
)
684687
)
685688
coefficients_deg_i = torch.abs(
686-
self._convolve2d(deg_i, torch.ones((2, 2)) / 4.0)
689+
self._convolve2d(
690+
deg_i,
691+
torch.ones((2, 2), device=deg_i.device, dtype=deg_i.dtype) / 4.0,
692+
)
687693
)
688694
coefficients_ref_q = torch.abs(
689-
self._convolve2d(ref_q, torch.ones((2, 2)) / 4.0)
695+
self._convolve2d(
696+
ref_q,
697+
torch.ones((2, 2), device=ref_q.device, dtype=ref_q.dtype) / 4.0,
698+
)
690699
)
691700
coefficients_deg_q = torch.abs(
692-
self._convolve2d(deg_q, torch.ones((2, 2)) / 4.0)
701+
self._convolve2d(
702+
deg_q,
703+
torch.ones((2, 2), device=deg_q.device, dtype=deg_q.dtype) / 4.0,
704+
)
693705
)
694706

695707
B, _, H, W = ref_y.shape

deepinv/physics/generator/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ def __init__(
7373
device
7474
), f"The random generator is not on the same device as the Physics Generator. Got random generator on {rng.device} and the Physics Generator named {self.__class__.__name__} on {self.device}."
7575
self.rng = rng
76+
77+
# NOTE: There is no use in moving RNG states from one device to another
78+
# as Generator.set_state only supports inputs living on the CPU. Yet,
79+
# by registering the initial random state as a buffer, it might be
80+
# moved to another device. This might hinder performance as the tensor
81+
# will need to be moved back to the CPU if it needs to be used later.
82+
# We could fix that by letting it be a regular class attribute instead
83+
# of a buffer but it would prevent it from being included in the
84+
# state dicts which is undesirable.
7685
self.register_buffer("initial_random_state", self.rng.get_state().to(device))
7786

7887
# Set attributes
@@ -114,7 +123,8 @@ def reset_rng(self):
114123
r"""
115124
Reset the random number generator to its initial state.
116125
"""
117-
self.rng.set_state(self.initial_random_state)
126+
# NOTE: Generator.set_state expects a tensor living on the CPU.
127+
self.rng.set_state(self.initial_random_state.cpu())
118128

119129
def __add__(self, other):
120130
r"""

deepinv/physics/noise.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ def __init__(self, noise_model: Callable = None, rng: torch.Generator = None):
2222
self.noise_model = noise_model
2323
self.rng = rng
2424
if rng is not None:
25+
# NOTE: There is no use in moving RNG states from one device to another
26+
# as Generator.set_state only supports inputs living on the CPU. Yet,
27+
# by registering the initial random state as a buffer, it might be
28+
# moved to another device. This might hinder performance as the tensor
29+
# will need to be moved back to the CPU if it needs to be used later.
30+
# We could fix that by letting it be a regular class attribute instead
31+
# of a buffer but it would prevent it from being included in the
32+
# state dicts which is undesirable.
2533
self.register_buffer("initial_random_state", rng.get_state())
2634

2735
def forward(self, input: torch.Tensor, seed: int = None) -> torch.Tensor:
@@ -71,7 +79,8 @@ def reset_rng(self):
7179
Reset the random number generator to its initial state.
7280
"""
7381
if self.rng is not None:
74-
self.rng.set_state(self.initial_random_state)
82+
# NOTE: Generator.set_state expects a tensor living on the CPU.
83+
self.rng.set_state(self.initial_random_state.cpu())
7584
else:
7685
warnings.warn(
7786
"Cannot reset state for random number generator because it was not initialized. This is ignored."
@@ -511,10 +520,23 @@ def forward(self, x, gain=None, seed: int = None, **kwargs):
511520
self.to(x.device)
512521
gain = self.gain[(...,) + (None,) * (x.dim() - 1)]
513522

514-
y = torch.poisson(
515-
torch.clip(x / gain, min=0.0) if self.clip_positive else x / gain,
516-
generator=self.rng,
517-
)
523+
if self.clip_positive:
524+
z = torch.clip(x / gain, min=0.0)
525+
else:
526+
# NOTE: PyTorch operations are generally run asynchronously on CUDA
527+
# devices and the CUDA kernel behind
528+
# torch.poisson typically raises a CUDA-level assertion error
529+
# when its input has negative entries. Those errors can't be
530+
# recovered from using Python's exception system due to their
531+
# asynchronous nature. For this reason we add a manual check if the
532+
# RNG is on a CUDA device.
533+
if self.rng is not None and self.rng.device.type == "cuda":
534+
assert gain > 0, "Gain must be positive"
535+
assert torch.all(x >= 0), "Input tensor must be non-negative"
536+
537+
z = x / gain
538+
539+
y = torch.poisson(z, generator=self.rng)
518540
if self.normalize:
519541
y = y * gain
520542
return y
@@ -618,6 +640,17 @@ def forward(self, x, gain=None, sigma=None, seed: int = None, **kwargs):
618640
if self.clip_positive:
619641
y = torch.poisson(torch.clip(x / gain, min=0.0), generator=self.rng) * gain
620642
else:
643+
# NOTE: PyTorch operations are generally run asynchronously on CUDA
644+
# devices and the CUDA kernel behind
645+
# torch.poisson typically raises a CUDA-level assertion error
646+
# when its input has negative entries. Those errors can't be
647+
# recovered from using Python's exception system due to their
648+
# asynchronous nature. For this reason we add a manual check if the
649+
# RNG is on a CUDA device.
650+
if self.rng is not None and self.rng.device.type == "cuda":
651+
assert gain > 0, "Gain must be positive"
652+
assert torch.all(x >= 0), "Input tensor must be non-negative"
653+
621654
y = torch.poisson(x / gain, generator=self.rng) * gain
622655

623656
y = y + self.randn_like(x) * sigma

deepinv/tests/test_generators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def choose_inpainting_generator(name, img_size, split_ratio, pixelwise, device,
372372
return dinv.physics.generator.MultiplicativeSplittingMaskGenerator(
373373
img_size=img_size,
374374
split_generator=mri_gen,
375+
device=device,
375376
)
376377
else:
377378
raise Exception("The generator chosen doesn't exist")

deepinv/tests/test_loss.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def model(x):
100100
assert torch.allclose(regfnel2, reg_fne_target, rtol=1e-3)
101101

102102

103-
def choose_loss(loss_name, rng=None):
103+
def choose_loss(loss_name, rng=None, device="cpu"):
104104
loss = []
105105
if loss_name == "mcei":
106106
loss.append(dinv.loss.MCLoss())
@@ -115,7 +115,7 @@ def choose_loss(loss_name, rng=None):
115115
"installed with `pip install kornia`",
116116
)
117117
loss.append(dinv.loss.MCLoss())
118-
loss.append(dinv.loss.EILoss(dinv.transform.Homography()))
118+
loss.append(dinv.loss.EILoss(dinv.transform.Homography(device=device)))
119119
elif loss_name == "splittv":
120120
loss.append(dinv.loss.SplittingLoss(split_ratio=0.25))
121121
loss.append(dinv.loss.TVLoss())
@@ -293,7 +293,7 @@ def test_losses(
293293
non_blocking_plots, loss_name, tmp_path, dataset, physics, imsize, device, rng
294294
):
295295
# choose training losses
296-
loss = choose_loss(loss_name, rng)
296+
loss = choose_loss(loss_name, rng, device=device)
297297

298298
save_dir = tmp_path / "dataset"
299299
# choose backbone denoiser
@@ -434,6 +434,7 @@ def test_measplit(device, loss_name, rng):
434434
dinv.physics.generator.BernoulliSplittingMaskGenerator(
435435
imsize, 0.5, device=device, rng=rng
436436
),
437+
device=device,
437438
)
438439
loss = dinv.loss.mri.WeightedSplittingLoss(
439440
mask_generator=gen, physics_generator=physics.gen
@@ -444,6 +445,7 @@ def test_measplit(device, loss_name, rng):
444445
dinv.physics.generator.BernoulliSplittingMaskGenerator(
445446
imsize, 0.5, device=device, rng=rng
446447
),
448+
device=device,
447449
)
448450
loss = dinv.loss.mri.RobustSplittingLoss(
449451
mask_generator=gen,

deepinv/tests/test_models.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,14 +493,22 @@ def test_wavelet_decomposition(channels, dimension, batch_size, device):
493493
# 1 decomposition
494494
out = model.dwt(x)
495495
x_hat = model.iwt(out)
496-
assert x_hat.shape == x.shape and torch.allclose(x, x_hat, rtol=1e-5, atol=1e-5)
496+
497+
# For reasons not yet investigated, the DWT/IWT round trip is over 100x less precise on GPU, so the tolerance is loosened there.
498+
tol = 1e-3 if torch.device(device).type == "cuda" else 1e-5
499+
500+
# NOTE: Tensors are broadcast in torch.allclose so
501+
# they might pass the test even if they have different shapes. For this
502+
# reason we also check the shapes.
503+
assert x_hat.shape == x.shape
504+
assert torch.allclose(x, x_hat, rtol=tol, atol=tol)
497505

498506
# 2 decomposition
499507
cA1, cD1 = model.dwt(x)
500508
cA2, cD2 = model.dwt(cA1)
501509

502510
x_hat = model.iwt((cA2, cD2, cD1))
503-
assert torch.allclose(x, x_hat, rtol=1e-5, atol=1e-5)
511+
assert torch.allclose(x, x_hat, rtol=tol, atol=tol)
504512

505513

506514
def test_drunet_inputs(imsize_1_channel, device):

deepinv/tests/test_trainer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def step(self, batch_size=1, seed=None, **kwargs):
110110
"f": torch.rand((batch_size,), generator=self.rng, device=device).item()
111111
}
112112

113-
return DummyPhysicsGenerator(rng=rng)
113+
return DummyPhysicsGenerator(rng=rng, device=device)
114114

115115

116116
@pytest.mark.parametrize(
@@ -262,7 +262,7 @@ def test_trainer_physics_generator_params(
262262
):
263263
N = 10
264264
rng1 = rng
265-
rng2 = torch.Generator().manual_seed(0)
265+
rng2 = torch.Generator(device).manual_seed(0)
266266

267267
class DummyPhysics(Physics):
268268
# Dummy physics which sums images, and multiplies by a parameter f
@@ -377,6 +377,7 @@ def forward(self, y=0.0, physics=None, **kwargs):
377377
return self.dummy_param * y
378378

379379
dummy_model = DummyModel()
380+
dummy_model.to(device)
380381
optimizer = torch.optim.Adam(dummy_model.parameters(), lr=1e-2, weight_decay=0.0)
381382

382383
trainer = Trainer(
@@ -439,6 +440,7 @@ def forward(self, y=0.0, physics=None, **kwargs):
439440
return self.dummy_param * torch.ones_like(y)
440441

441442
dummy_model = DummyModel()
443+
dummy_model.to(device)
442444
optimizer = torch.optim.Adam(dummy_model.parameters(), lr=1e-2, weight_decay=0.0)
443445

444446
trainer = Trainer(
@@ -574,9 +576,11 @@ def __len__(self):
574576

575577
def __getitem__(self, i):
576578
params = generator.step(1)
579+
# NOTE: The test relies on changing params in place.
577580
params["mask"] = params["mask"].squeeze(0)
578-
x = torch.ones(imsize)
579-
y = x * params["mask"]
581+
mask = params["mask"]
582+
x = torch.ones(imsize, device=mask.device, dtype=mask.dtype)
583+
y = x * mask
580584
if ground_truth:
581585
if measurements:
582586
if generate_params:
@@ -627,8 +631,8 @@ def __getitem__(self, i):
627631

628632
# fmt: off
629633
def assert_x_none(x): assert x is None
630-
def assert_x_full(x): assert x.mean() == 1.
631-
def assert_physics_unchanged(physics): assert physics.mask.mean() == 1. # params not loaded
634+
def assert_x_full(x): assert math.isclose(x.mean(), 1.0, abs_tol=1e-7)
635+
def assert_physics_unchanged(physics): assert math.isclose(physics.mask.mean(), 1.0, abs_tol=1e-7) # params not loaded
632636
def assert_physics_offline(physics): assert physics.mask.mean() < .2
633637
def assert_physics_online(physics): assert physics.mask.mean() > .8
634638
def assert_y_offline(y): assert y.mean() < .2

deepinv/tests/test_unfolded.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def test_DEQ(unfolded_algo, imsize, dummy_dataset, device):
124124
anderson_acceleration_backward=and_acc,
125125
jacobian_free=jac_free,
126126
)
127+
model.to(device)
127128

128129
for idx, (name, param) in enumerate(model.named_parameters()):
129130
assert param.requires_grad

0 commit comments

Comments
 (0)