77import deepinv as dinv
88from deepinv .physics import LinearPhysicsMultiScaler , PhysicsCropper
99from deepinv .utils .tensorlist import TensorList
10- from deepinv .models .base import Reconstructor
10+ from deepinv .models .base import Reconstructor , Denoiser
1111
1212
13- class RAM (Reconstructor ):
13+ class RAM (Reconstructor , Denoiser ):
1414 r"""
1515 Reconstruct Anything Model.
1616
@@ -109,10 +109,10 @@ def constant2map(self, value, x):
109109 )
110110 return value_map
111111
112- def base_conditioning (self , x , sigma , gamma ):
112+ def base_conditioning (self , x , sigma , gain ):
113113 noise_level_map = self .constant2map (sigma , x )
114- gamma_map = self .constant2map (gamma , x )
115- return torch .cat ((x , noise_level_map , gamma_map ), 1 )
114+ gain_map = self .constant2map (gain , x )
115+ return torch .cat ((x , noise_level_map , gain_map ), 1 )
116116
117117 def realign_input (self , x , physics , y ):
118118 r"""
@@ -165,7 +165,7 @@ def realign_input(self, x, physics, y):
165165
166166 return model_input
167167
168- def forward_unet (self , x0 , sigma = None , gamma = None , physics = None , y = None ):
168+ def forward_unet (self , x0 , sigma = None , gain = None , physics = None , y = None ):
169169 r"""
170170 Forward pass of the UNet model.
171171
@@ -186,7 +186,7 @@ def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
186186 if y is not None :
187187 x0 = self .realign_input (x0 , physics , y )
188188
189- x0 = self .base_conditioning (x0 , sigma , gamma )
189+ x0 = self .base_conditioning (x0 , sigma , gain )
190190
191191 x1 = self .m_head (x0 )
192192
@@ -214,17 +214,25 @@ def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
214214
215215 return x
216216
217- def forward (self , y = None , physics = None ):
217+ def forward (self , y , physics = None , sigma = None , gain = None ):
218218 r"""
219219 Reconstructs a signal estimate from measurements y
220220
221221 :param torch.Tensor y: measurements
222222 :param deepinv.physics.Physics physics: forward operator
223223 :return: torch.Tensor: reconstructed signal estimate
224224 """
225+ assert (
226+ physics is not None or sigma is not None or gain is not None
227+ ), "Either physics, sigma or gain must be provided to the RAM model."
228+
225229 if physics is None :
230+ gain = 1e-3 if gain is None else gain
231+ sigma = self .sigma_threshold if sigma is None else sigma
232+
226233 physics = dinv .physics .Denoising (
227- noise_model = dinv .physics .GaussianNoise (sigma = 0.0 ), device = y .device
234+ noise_model = dinv .physics .PoissonGaussianNoise (sigma = sigma , gain = gain ),
235+ device = y .device ,
228236 )
229237
230238 x_temp = physics .A_adjoint (y )
@@ -236,13 +244,14 @@ def forward(self, y=None, physics=None):
236244 sigma = (
237245 physics .noise_model .sigma if hasattr (physics .noise_model , "sigma" ) else 1e-3
238246 )
239- sigma = torch .tensor (max (sigma , self .sigma_threshold ), device = y .device )
240- gamma = (
247+ sigma = self ._handle_sigma (max (sigma , self .sigma_threshold ))
248+
249+ gain = (
241250 physics .noise_model .gain if hasattr (physics .noise_model , "gain" ) else 1e-3
242251 )
243- gamma = torch . tensor (max (gamma , 1e-3 ), device = y . device )
252+ gain = self . _handle_sigma (max (gain , 1e-3 ))
244253
245- out = self .forward_unet (x_in , sigma = sigma , gamma = gamma , physics = physics , y = y )
254+ out = self .forward_unet (x_in , sigma = sigma , gain = gain , physics = physics , y = y )
246255
247256 out = physics .remove_pad (out )
248257
0 commit comments