|
18 | 18 |
|
19 | 19 | LPIPS, _ = optional_import("lpips", name="LPIPS") |
20 | 20 | torchvision, _ = optional_import("torchvision") |
21 | | -from torchvision.models import ResNet50_Weights, resnet50 |
22 | | -from torchvision.models.feature_extraction import create_feature_extractor |
| 21 | + |
23 | 22 |
|
24 | 23 |
|
25 | 24 | class PerceptualLoss(nn.Module): |
@@ -79,6 +78,7 @@ def __init__( |
79 | 78 | torch.hub.set_dir(cache_dir) |
80 | 79 |
|
81 | 80 | self.spatial_dims = spatial_dims |
| 81 | + self.perceptual_function : nn.Module |
82 | 82 | if spatial_dims == 3 and is_fake_3d is False: |
83 | 83 | self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) |
84 | 84 | elif "radimagenet_" in network_type: |
@@ -168,7 +168,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): |
168 | 168 | def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None: |
169 | 169 | super().__init__() |
170 | 170 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True |
171 | | - self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose) |
| 171 | + self.model = torch.hub.load("marksgraham/MedicalNet-models", model=net, verbose=verbose, force_reload=True) |
172 | 172 | self.eval() |
173 | 173 |
|
174 | 174 | for param in self.parameters(): |
@@ -196,7 +196,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
196 | 196 | feats_input = normalize_tensor(outs_input) |
197 | 197 | feats_target = normalize_tensor(outs_target) |
198 | 198 |
|
199 | | - results = (feats_input - feats_target) ** 2 |
| 199 | + results : torch.Tensor = (feats_input - feats_target) ** 2 |
200 | 200 | results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) |
201 | 201 |
|
202 | 202 | return results |
@@ -266,7 +266,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
266 | 266 | feats_input = normalize_tensor(outs_input) |
267 | 267 | feats_target = normalize_tensor(outs_target) |
268 | 268 |
|
269 | | - results = (feats_input - feats_target) ** 2 |
| 269 | + results: torch.Tensor = (feats_input - feats_target) ** 2 |
270 | 270 | results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) |
271 | 271 |
|
272 | 272 | return results |
@@ -345,7 +345,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
345 | 345 | feats_input = normalize_tensor(outs_input) |
346 | 346 | feats_target = normalize_tensor(outs_target) |
347 | 347 |
|
348 | | - results = (feats_input - feats_target) ** 2 |
| 348 | + results : torch.Tensor = (feats_input - feats_target) ** 2 |
349 | 349 | results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) |
350 | 350 |
|
351 | 351 | return results |
|
0 commit comments