Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c976047
modify inconsistent layer and add pretrain
yiheng-wang-nv Oct 26, 2022
1a86031
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
yiheng-wang-nv Oct 26, 2022
07b60a8
fix black
yiheng-wang-nv Oct 26, 2022
4624443
Merge branch '5386-modify-hovernet-and-add-pretrain' of github.com:yi…
yiheng-wang-nv Oct 26, 2022
ab7f3f3
add freeze encoder func
yiheng-wang-nv Oct 26, 2022
385f993
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
bhashemian Oct 27, 2022
aeb3c23
add option of padding to decoder
yiheng-wang-nv Oct 27, 2022
93946ce
Merge branch '5386-modify-hovernet-and-add-pretrain' of github.com:yi…
yiheng-wang-nv Oct 27, 2022
4b6e7dd
add doc-string for pretrained encoder
yiheng-wang-nv Oct 28, 2022
e21dc7c
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
bhashemian Oct 31, 2022
76dc20b
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
bhashemian Oct 31, 2022
349f961
add unittests provided by Jonny
yiheng-wang-nv Nov 1, 2022
fb0fea2
fix docstring error
yiheng-wang-nv Nov 1, 2022
27dbf70
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
yiheng-wang-nv Nov 1, 2022
fbb64fd
add decoder padding unit tests
yiheng-wang-nv Nov 1, 2022
b8d20a4
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
yiheng-wang-nv Nov 1, 2022
52525a8
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
bhashemian Nov 2, 2022
65fa292
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
bhashemian Nov 4, 2022
38bcdcf
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
yiheng-wang-nv Nov 7, 2022
e3c5fc3
remove pretrained part and update structure
yiheng-wang-nv Nov 7, 2022
7606725
add load func
yiheng-wang-nv Nov 7, 2022
0e625ba
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
yiheng-wang-nv Nov 7, 2022
393fef1
Update docstring
yiheng-wang-nv Nov 7, 2022
29d2f88
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
yiheng-wang-nv Nov 7, 2022
a3945ed
adjust the order of docstring
yiheng-wang-nv Nov 7, 2022
bed8496
Merge branch 'dev' into 5386-modify-hovernet-and-add-pretrain
bhashemian Nov 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 117 additions & 29 deletions monai/networks/nets/hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@
# }
# =========================================================================

import os
import re
import warnings
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Sequence, Type, Union

import torch
import torch.nn as nn

from monai.apps.utils import download_url
from monai.networks.blocks import UpSample
from monai.networks.layers.factories import Conv, Dropout
from monai.networks.layers.utils import get_act_layer, get_norm_layer
Expand All @@ -52,6 +56,7 @@ def __init__(
act: Union[str, tuple] = ("relu", {"inplace": True}),
norm: Union[str, tuple] = "batch",
kernel_size: int = 3,
padding: int = 0,
) -> None:
"""
Args:
Expand All @@ -62,6 +67,7 @@ def __init__(
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
kernel_size: size of the kernel for >1 convolutions (dependent on mode)
padding: padding value for >1 convolutions.
"""
super().__init__()

Expand All @@ -76,7 +82,8 @@ def __init__(
self.layers.add_module("conv1/norm", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
self.layers.add_module("conv1/relu2", get_act_layer(name=act))
self.layers.add_module(
"conv2", conv_type(num_features, out_channels, kernel_size=kernel_size, padding=0, groups=4, bias=False)
"conv2",
conv_type(num_features, out_channels, kernel_size=kernel_size, padding=padding, groups=4, bias=False),
)

if dropout_prob > 0:
Expand All @@ -85,7 +92,7 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:

x1 = self.layers(x)
if x1.shape != x.shape:
if x1.shape[-1] != x.shape[-1]:
trim = (x.shape[-1] - x1.shape[-1]) // 2
x = x[:, :, trim:-trim, trim:-trim]

Expand All @@ -105,6 +112,7 @@ def __init__(
act: Union[str, tuple] = ("relu", {"inplace": True}),
norm: Union[str, tuple] = "batch",
kernel_size: int = 3,
same_padding: bool = False,
) -> None:
"""
Args:
Expand All @@ -116,17 +124,30 @@ def __init__(
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
kernel_size: size of the kernel for >1 convolutions (dependent on mode)
same_padding: whether to do padding for >1 convolutions to ensure
the output size is the same as the input size.
"""
super().__init__()

conv_type: Callable = Conv[Conv.CONV, 2]

self.add_module("conva", conv_type(in_channels, in_channels // 4, kernel_size=kernel_size, bias=False))
padding: int = kernel_size // 2 if same_padding else 0

self.add_module(
"conva", conv_type(in_channels, in_channels // 4, kernel_size=kernel_size, padding=padding, bias=False)
)

_in_channels = in_channels // 4
for i in range(layers):
layer = _DenseLayerDecoder(
num_features, _in_channels, out_channels, dropout_prob, act=act, norm=norm, kernel_size=kernel_size
num_features,
_in_channels,
out_channels,
dropout_prob,
act=act,
norm=norm,
kernel_size=kernel_size,
padding=padding,
)
_in_channels += out_channels
self.add_module("denselayerdecoder%d" % (i + 1), layer)
Expand Down Expand Up @@ -172,22 +193,24 @@ def __init__(
dropout_type: Callable = Dropout[Dropout.DROPOUT, 2]

if not drop_first_norm_relu:
self.layers.add_module("preact_norm", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
self.layers.add_module("preact_relu", get_act_layer(name=act))
self.layers.add_module("preact/bn", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
self.layers.add_module("preact/relu", get_act_layer(name=act))

self.layers.add_module("conv1", conv_type(in_channels, num_features, kernel_size=1, padding=0, bias=False))
self.layers.add_module("norm2", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
self.layers.add_module("relu2", get_act_layer(name=act))
self.layers.add_module("conv1/bn", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
self.layers.add_module("conv1/relu", get_act_layer(name=act))

if in_channels != 64 and drop_first_norm_relu:
self.layers.add_module(
"conv2", conv_type(num_features, num_features, kernel_size=kernel_size, stride=2, padding=2, bias=False)
)
else:
self.layers.add_module("conv2", conv_type(num_features, num_features, kernel_size=1, padding=0, bias=False))
self.layers.add_module(
"conv2", conv_type(num_features, num_features, kernel_size=kernel_size, padding=1, bias=False)
)

self.layers.add_module("norm3", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
self.layers.add_module("relu3", get_act_layer(name=act))
self.layers.add_module("conv2/bn", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
self.layers.add_module("conv2/relu", get_act_layer(name=act))
self.layers.add_module("conv3", conv_type(num_features, out_channels, kernel_size=1, padding=0, bias=False))

if dropout_prob > 0:
Expand All @@ -206,7 +229,7 @@ def __init__(
"""
super().__init__()

self.add_module("norm", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
self.add_module("bn", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
self.add_module("relu", get_act_layer(name=act))


Expand Down Expand Up @@ -250,11 +273,11 @@ def __init__(
layer = _DenseLayer(
num_features, in_channels, out_channels, dropout_prob, act=act, norm=norm, drop_first_norm_relu=True
)
self.layers.add_module("prim_denselayer_1", layer)
self.layers.add_module("denselayer_0", layer)

for i in range(1, layers):
layer = _DenseLayer(num_features, out_channels, out_channels, dropout_prob, act=act, norm=norm)
self.layers.add_module(f"main_dense_layer_{i + 1}", layer)
self.layers.add_module(f"denselayer_{i}", layer)

self.bna_block = _Transition(out_channels, act=act, norm=norm)

Expand Down Expand Up @@ -287,6 +310,7 @@ def __init__(
dropout_prob: float = 0.0,
out_channels: int = 2,
kernel_size: int = 3,
same_padding: bool = False,
) -> None:
"""
Args:
Expand All @@ -296,6 +320,8 @@ def __init__(
dropout_prob: dropout rate after each dense layer.
out_channels: number of the output channel.
kernel_size: size of the kernel for >1 convolutions (dependent on mode)
same_padding: whether to do padding for >1 convolutions to ensure
the output size is the same as the input size.
"""
super().__init__()
conv_type: Callable = Conv[Conv.CONV, 2]
Expand All @@ -316,6 +342,7 @@ def __init__(
act=act,
norm=norm,
kernel_size=kernel_size,
same_padding=same_padding,
)
self.decoder_blocks.add_module(f"decoderblock{i + 1}", block)
_in_channels = 512
Expand All @@ -335,7 +362,7 @@ def __init__(
_seq_block = nn.Sequential(
OrderedDict(
[
("norm", get_norm_layer(name=norm, spatial_dims=2, channels=64)),
("bn", get_norm_layer(name=norm, spatial_dims=2, channels=64)),
("relu", get_act_layer(name=act)),
("conv", conv_type(64, out_channels, kernel_size=1, stride=1)),
]
Expand All @@ -358,7 +385,8 @@ def forward(self, xin: torch.Tensor, short_cuts: List[torch.Tensor]) -> torch.Te
x = self.upsample(x)
block_number -= 1
trim = (short_cuts[block_number].shape[-1] - x.shape[-1]) // 2
x += short_cuts[block_number][:, :, trim:-trim, trim:-trim]
if trim > 0:
x += short_cuts[block_number][:, :, trim:-trim, trim:-trim]

for block in self.output_features:
x = block(x)
Expand All @@ -375,14 +403,26 @@ class HoVerNet(nn.Module):
and classification of nuclei in multi-tissue histology images,
Medical Image Analysis 2019

https://github.com/vqdang/hover_net

Args:
mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or
a faster implementation (`HoVerNetMODE.FAST` or "fast"). Defaults to `HoVerNetMODE.FAST`.
in_channels: number of the input channel.
np_out_channels: number of the output channel of the nucleus prediction branch.
out_classes: number of the nuclear type classes.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
decoder_padding: whether to do padding on convolution layers in the decoders. In the conic branch
of the referred repository, the architecture is changed to do padding on convolution layers in order to
get the same output size as the input, and this changed version is used on CoNIC challenge.
Please note that to get consistent output size, `HoVerNetMode.FAST` mode should be employed.
dropout_prob: dropout rate after each dense layer.
pretrained_url: if specifying, will loaded the pretrained weights downloaded from the url.
The weights should be ImageNet pretrained preact-resnet50 weights coming from the referred hover_net
repository, each user is responsible for checking the content of model/datasets and the applicable licenses
and determining if suitable for the intended use. please check the following link for more details:
https://github.com/vqdang/hover_net#data-format
"""

Mode = HoVerNetMode
Expand All @@ -392,10 +432,13 @@ def __init__(
self,
mode: Union[HoVerNetMode, str] = HoVerNetMode.FAST,
in_channels: int = 3,
np_out_channels: int = 2,
out_classes: int = 0,
act: Union[str, tuple] = ("relu", {"inplace": True}),
norm: Union[str, tuple] = "batch",
decoder_padding: bool = False,
dropout_prob: float = 0.0,
pretrained_url: Optional[str] = None,
) -> None:

super().__init__()
Expand All @@ -404,6 +447,11 @@ def __init__(
mode = mode.upper()
self.mode = look_up_option(mode, HoVerNetMode)

if self.mode == "ORIGINAL" and decoder_padding is True:
warnings.warn(
"'decoder_padding=True' only works when mode is 'FAST', otherwise the output size may not equal to the input."
)

if out_classes > 128:
raise ValueError("Number of nuclear types classes exceeds maximum (128)")
elif out_classes == 1:
Expand All @@ -426,15 +474,12 @@ def __init__(

conv_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2]

self.input_features = nn.Sequential(
self.conv0 = nn.Sequential(
OrderedDict(
[
(
"conv0",
conv_type(in_channels, _init_features, kernel_size=7, stride=1, padding=_pad, bias=False),
),
("norm0", get_norm_layer(name=norm, spatial_dims=2, channels=_init_features)),
("relu0", get_act_layer(name=act)),
("conv", conv_type(in_channels, _init_features, kernel_size=7, stride=1, padding=_pad, bias=False)),
("bn", get_norm_layer(name=norm, spatial_dims=2, channels=_init_features)),
("relu", get_act_layer(name=act)),
]
)
)
Expand All @@ -455,7 +500,7 @@ def __init__(
act=act,
norm=norm,
)
self.res_blocks.add_module(f"residualblock{i + 1}", block)
self.res_blocks.add_module(f"d{i}", block)

_in_channels = _out_channels
_out_channels *= 2
Expand All @@ -471,10 +516,14 @@ def __init__(
)

# decode branches
self.nucleus_prediction = _DecoderBranch(kernel_size=_ksize)
self.horizontal_vertical = _DecoderBranch(kernel_size=_ksize)
self.nucleus_prediction = _DecoderBranch(
kernel_size=_ksize, same_padding=decoder_padding, out_channels=np_out_channels
)
self.horizontal_vertical = _DecoderBranch(kernel_size=_ksize, same_padding=decoder_padding)
self.type_prediction: Optional[_DecoderBranch] = (
_DecoderBranch(out_channels=out_classes, kernel_size=_ksize) if out_classes > 0 else None
_DecoderBranch(out_channels=out_classes, kernel_size=_ksize, same_padding=decoder_padding)
if out_classes > 0
else None
)

for m in self.modules():
Expand All @@ -484,6 +533,12 @@ def __init__(
nn.init.constant_(torch.as_tensor(m.weight), 1)
nn.init.constant_(torch.as_tensor(m.bias), 0)

if pretrained_url is not None:
_load_pretrained_encoder(self, pretrained_url)

def freeze_encoder(self):
    """Freeze the encoder (residual blocks) so its weights are excluded from training."""
    for param in self.res_blocks.parameters():
        param.requires_grad = False

def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:

if self.mode == HoVerNetMode.ORIGINAL.value:
Expand All @@ -493,8 +548,7 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
if x.shape[-1] != 256 or x.shape[-2] != 256:
raise ValueError("Input size should be 256 x 256 when using HoVerNetMode.FAST")

x = x / 255.0 # to 0-1 range to match XY
x = self.input_features(x)
x = self.conv0(x)
short_cuts = []

for i, block in enumerate(self.res_blocks):
Expand All @@ -516,4 +570,38 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
return output


def _load_pretrained_encoder(model: nn.Module, model_url: str):
    """Download pretrained preact-resnet50 weights and load the matching encoder
    weights into ``model`` in place.

    The checkpoint keys follow the naming scheme of the referred hover_net
    repository; they are remapped with regexes to this module's layer names
    before loading. Only remapped keys that exist in ``model`` with an
    identical tensor shape are loaded; everything else is silently skipped.

    Args:
        model: the network whose encoder weights will be updated.
        model_url: URL of the pretrained weight file to download.
    """

    # regexes translating hover_net checkpoint key names to this module's names
    pattern_conv0 = re.compile(r"^(conv0\.\/)(.+)$")  # stem conv keys (checkpoint uses a "conv0./" prefix)
    pattern_block = re.compile(r"^(d\d+)\.(.+)$")  # residual block prefix, e.g. "d0."
    pattern_layer = re.compile(r"^(.+\.d\d+)\.units\.(\d+)(.+)$")  # dense layers inside a block
    pattern_bna = re.compile(r"^(.+\.d\d+)\.blk_bna\.(.+)")  # trailing norm-activation of a block
    # download the pretrained weights into torch hub's default dir
    weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
    download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
    # NOTE(review): torch.load unpickles the downloaded file — only trusted URLs should be used here.
    state_dict = torch.load(weights_dir, map_location=None)["desc"]
    for key in list(state_dict.keys()):
        new_key = None
        if pattern_conv0.match(key):
            new_key = re.sub(pattern_conv0, r"conv0.conv\2", key)
        elif pattern_block.match(key):
            new_key = re.sub(pattern_block, r"res_blocks.\1.\2", key)
            # block keys may need a second remap for their inner dense layers / bna block
            if pattern_layer.match(new_key):
                new_key = re.sub(pattern_layer, r"\1.layers.denselayer_\2.layers\3", new_key)
            elif pattern_bna.match(new_key):
                new_key = re.sub(pattern_bna, r"\1.bna_block.\2", new_key)
        if new_key:
            # move the tensor to its remapped name
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
        if "upsample2x" in key:
            # decoder upsampling weights from the checkpoint are not used by this model
            del state_dict[key]

    # keep only keys present in the model with matching shapes, then load the merged dict
    model_dict = model.state_dict()
    state_dict = {
        k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
    }
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)


# capitalization-variant aliases for backward compatibility
Hovernet = HoVernet = HoverNet = HoVerNet
Loading