Skip to content

Commit 105c3b8

Browse files
committed
Fixes typing issues in perceptual loss
1 parent 086b8a9 commit 105c3b8

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

monai/losses/perceptual.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818

1919
LPIPS, _ = optional_import("lpips", name="LPIPS")
2020
torchvision, _ = optional_import("torchvision")
21-
from torchvision.models import ResNet50_Weights, resnet50
22-
from torchvision.models.feature_extraction import create_feature_extractor
21+
2322

2423

2524
class PerceptualLoss(nn.Module):
@@ -79,6 +78,7 @@ def __init__(
7978
torch.hub.set_dir(cache_dir)
8079

8180
self.spatial_dims = spatial_dims
81+
self.perceptual_function : nn.Module
8282
if spatial_dims == 3 and is_fake_3d is False:
8383
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
8484
elif "radimagenet_" in network_type:
@@ -168,7 +168,7 @@ class MedicalNetPerceptualSimilarity(nn.Module):
168168
def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
169169
super().__init__()
170170
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
171-
self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose)
171+
self.model = torch.hub.load("marksgraham/MedicalNet-models", model=net, verbose=verbose, force_reload=True)
172172
self.eval()
173173

174174
for param in self.parameters():
@@ -196,7 +196,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
196196
feats_input = normalize_tensor(outs_input)
197197
feats_target = normalize_tensor(outs_target)
198198

199-
results = (feats_input - feats_target) ** 2
199+
results : torch.Tensor = (feats_input - feats_target) ** 2
200200
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
201201

202202
return results
@@ -266,7 +266,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
266266
feats_input = normalize_tensor(outs_input)
267267
feats_target = normalize_tensor(outs_target)
268268

269-
results = (feats_input - feats_target) ** 2
269+
results: torch.Tensor = (feats_input - feats_target) ** 2
270270
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
271271

272272
return results
@@ -345,7 +345,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
345345
feats_input = normalize_tensor(outs_input)
346346
feats_target = normalize_tensor(outs_target)
347347

348-
results = (feats_input - feats_target) ** 2
348+
results : torch.Tensor = (feats_input - feats_target) ** 2
349349
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
350350

351351
return results

0 commit comments

Comments
 (0)