Skip to content

Commit 7c17b77

Browse files
committed
ram denoise + gamma -> gain
1 parent 088bd41 commit 7c17b77

File tree

2 files changed

+55
-13
lines changed

2 files changed

+55
-13
lines changed

deepinv/models/ram.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import deepinv as dinv
88
from deepinv.physics import LinearPhysicsMultiScaler, PhysicsCropper
99
from deepinv.utils.tensorlist import TensorList
10-
from deepinv.models.base import Reconstructor
10+
from deepinv.models.base import Reconstructor, Denoiser
1111

1212

13-
class RAM(Reconstructor):
13+
class RAM(Reconstructor, Denoiser):
1414
r"""
1515
Reconstruct Anything Model.
1616
@@ -109,10 +109,10 @@ def constant2map(self, value, x):
109109
)
110110
return value_map
111111

112-
def base_conditioning(self, x, sigma, gamma):
112+
def base_conditioning(self, x, sigma, gain):
113113
noise_level_map = self.constant2map(sigma, x)
114-
gamma_map = self.constant2map(gamma, x)
115-
return torch.cat((x, noise_level_map, gamma_map), 1)
114+
gain_map = self.constant2map(gain, x)
115+
return torch.cat((x, noise_level_map, gain_map), 1)
116116

117117
def realign_input(self, x, physics, y):
118118
r"""
@@ -165,7 +165,7 @@ def realign_input(self, x, physics, y):
165165

166166
return model_input
167167

168-
def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
168+
def forward_unet(self, x0, sigma=None, gain=None, physics=None, y=None):
169169
r"""
170170
Forward pass of the UNet model.
171171
@@ -186,7 +186,7 @@ def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
186186
if y is not None:
187187
x0 = self.realign_input(x0, physics, y)
188188

189-
x0 = self.base_conditioning(x0, sigma, gamma)
189+
x0 = self.base_conditioning(x0, sigma, gain)
190190

191191
x1 = self.m_head(x0)
192192

@@ -214,17 +214,25 @@ def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
214214

215215
return x
216216

217-
def forward(self, y=None, physics=None):
217+
def forward(self, y, physics=None, sigma=None, gain=None):
218218
r"""
219219
Reconstructs a signal estimate from measurements y
220220
221221
:param torch.Tensor y: measurements
222222
:param deepinv.physics.Physics physics: forward operator
223223
:return: torch.Tensor: reconstructed signal estimate
224224
"""
225+
assert (
226+
physics is not None or sigma is not None or gain is not None
227+
), "Either physics, sigma or gain must be provided to the RAM model."
228+
225229
if physics is None:
230+
gain = 1e-3 if gain is None else gain
231+
sigma = self.sigma_threshold if sigma is None else sigma
232+
226233
physics = dinv.physics.Denoising(
227-
noise_model=dinv.physics.GaussianNoise(sigma=0.0), device=y.device
234+
noise_model=dinv.physics.PoissonGaussianNoise(sigma=sigma, gain=gain),
235+
device=y.device,
228236
)
229237

230238
x_temp = physics.A_adjoint(y)
@@ -236,13 +244,14 @@ def forward(self, y=None, physics=None):
236244
sigma = (
237245
physics.noise_model.sigma if hasattr(physics.noise_model, "sigma") else 1e-3
238246
)
239-
sigma = torch.tensor(max(sigma, self.sigma_threshold), device=y.device)
240-
gamma = (
247+
sigma = self._handle_sigma(max(sigma, self.sigma_threshold))
248+
249+
gain = (
241250
physics.noise_model.gain if hasattr(physics.noise_model, "gain") else 1e-3
242251
)
243-
gamma = torch.tensor(max(gamma, 1e-3), device=y.device)
252+
gain = self._handle_sigma(max(gain, 1e-3))
244253

245-
out = self.forward_unet(x_in, sigma=sigma, gamma=gamma, physics=physics, y=y)
254+
out = self.forward_unet(x_in, sigma=sigma, gain=gain, physics=physics, y=y)
246255

247256
out = physics.remove_pad(out)
248257

examples/unfolded/demo_ram.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,39 @@
4949
figsize=(8, 3),
5050
)
5151

52+
# %%
53+
# This model was also trained on various denoising problems, in particular on Poisson-Gaussian denoising.
54+
55+
sigma, gain = 0.2, 0.5
56+
physics = dinv.physics.Denoising(
57+
noise_model=dinv.physics.PoissonGaussianNoise(sigma=sigma, gain=gain),
58+
device=device,
59+
)
60+
61+
# generate measurement
62+
y = physics(x)
63+
64+
# run inference
65+
with torch.no_grad():
66+
x_hat = model(y, physics=physics)
67+
# or alternatively, we can use the model without physics:
68+
# x_hat = model(y, sigma=sigma, gain=gain)
69+
70+
# compute PSNR
71+
in_psnr = dinv.metric.PSNR()(x, y).item()
72+
out_psnr = dinv.metric.PSNR()(x, x_hat).item()
73+
74+
# plot
75+
dinv.utils.plot(
76+
[x, y, x_hat],
77+
[
78+
"Original",
79+
"Measurement\n PSNR = {:.2f}dB".format(in_psnr),
80+
"Reconstruction\n PSNR = {:.2f}dB".format(out_psnr),
81+
],
82+
figsize=(8, 3),
83+
)
84+
5285
# %%
5386
# This model is not trained on all degradations, so it may not perform well on all inverse problems.
5487
# For instance, it is not trained on image demosaicing. Applying it to a demosaicing problem will yield poor results,

0 commit comments

Comments
 (0)