-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathmeaniou.py
More file actions
147 lines (120 loc) · 6.84 KB
/
meaniou.py
File metadata and controls
147 lines (120 loc) · 6.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction
from .metric import CumulativeIterationMetric
class MeanIoU(CumulativeIterationMetric):
    """
    Cumulative metric computing the average Intersection over Union (IoU) between two tensors.

    Both multi-class and multi-label tasks are supported. Predictions `y_pred` are compared against
    the ground truth `y`. `y_pred` must already be binarized and `y` must be one-hot encoded; the
    transforms in ``monai.transforms.post`` can be used beforehand to binarize predictions.

    Setting `include_background` to ``False`` drops the first channel (index 0), which by convention
    holds the background. This is useful when foreground structures are small relative to the image,
    since the background signal could otherwise dominate the score.

    `y_pred` and `y` may be given either as a batch-first Tensor (BCHW[D]) or as a list of
    channel-first Tensors (CHW[D]). The usual call sequence follows
    :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        include_background: whether the first channel of the predicted output takes part in the IoU
            computation. Defaults to ``True``.
        reduction: reduction mode applied to the metric values, ignoring NaN entries. One of
            {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}; defaults to ``"mean"``. ``"none"`` disables reduction.
        get_not_nans: when True, `aggregate()` returns the pair (metric, not_nans), where `not_nans`
            counts the non-NaN entries and has the same shape as the metric.
        ignore_empty: how cases with empty ground truth are scored.
            If `True`, such cases are set to NaN.
            If `False`, such cases score 1 when the corresponding prediction is also empty.
    """

    def __init__(
        self,
        include_background: bool = True,
        reduction: MetricReduction | str = MetricReduction.MEAN,
        get_not_nans: bool = False,
        ignore_empty: bool = True,
    ) -> None:
        super().__init__()
        # options forwarded to `compute_iou`
        self.include_background = include_background
        self.ignore_empty = ignore_empty
        # options used when aggregating the buffered per-iteration results
        self.reduction = reduction
        self.get_not_nans = get_not_nans

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """
        Score one batch of predictions against its ground truth.

        Args:
            y_pred: typical segmentation model output, one-hot encoded and batch-first,
                e.g. shape [16, 3, 32, 32]. Values should be binarized.
            y: ground truth used for the mean IoU, one-hot encoded and batch-first.
                Values should be binarized.

        Raises:
            ValueError: when `y_pred` has fewer than three dimensions.
        """
        num_dims = y_pred.ndimension()
        if num_dims < 3:
            raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {num_dims}.")
        # one IoU value per (batch item, channel)
        return compute_iou(
            y_pred=y_pred,
            y=y,
            include_background=self.include_background,
            ignore_empty=self.ignore_empty,
        )

    def aggregate(
        self, reduction: MetricReduction | str | None = None
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Reduce the buffered outputs of `compute_iou`.

        Args:
            reduction: reduction mode applied to the metric values, ignoring NaN entries. One of
                {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                ``"mean_channel"``, ``"sum_channel"``}; a falsy value falls back to `self.reduction`.
                ``"none"`` disables reduction.

        Raises:
            ValueError: when the buffered data is not a PyTorch Tensor.
        """
        buffer = self.get_buffer()
        if not isinstance(buffer, torch.Tensor):
            raise ValueError("the data to aggregate must be PyTorch Tensor.")
        # any falsy argument (None included) defers to the configured reduction
        mode = reduction or self.reduction
        metric, not_nans = do_metric_reduction(buffer, mode)
        if self.get_not_nans:
            return metric, not_nans
        return metric
def compute_iou(
    y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
) -> torch.Tensor:
    """Computes Intersection over Union (IoU) score metric from a batch of predictions.

    Args:
        y_pred: input data to compute, typical segmentation model output.
            It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
            should be binarized.
        y: ground truth to compute mean IoU metric. It must be one-hot format and first dim is batch.
            The values should be binarized.
        include_background: whether to include IoU computation on the first channel of
            the predicted output. Defaults to True.
        ignore_empty: whether to ignore empty ground truth cases during calculation.
            If `True`, NaN value will be set for empty ground truth cases.
            If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.

    Returns:
        IoU scores per batch and per class, (shape [batch_size, num_classes]). The fill values
        (NaN / 1) use the same dtype as the computed ``intersection / union`` ratio.

    Raises:
        ValueError: when `y_pred` and `y` have different shapes, or when `y_pred` has fewer than
            three dimensions (batch, channel, spatial).
    """
    # Validate the caller's tensors before any channel is dropped: the error then reports the
    # shapes actually passed in, and a channel-count mismatch cannot be masked by the slicing.
    if y.shape != y_pred.shape:
        raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
    dims = y_pred.ndimension()
    if dims < 3:
        # without spatial dims the reduction axis list would be empty and torch.sum would
        # collapse every dimension instead of producing a [batch, channel] result
        raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")

    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    # reduce only spatial dimensions (keep batch and channel)
    reduce_axis = list(range(2, y_pred.ndim))
    intersection = torch.sum(y * y_pred, dim=reduce_axis)
    y_o = torch.sum(y, dim=reduce_axis)
    y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    union = y_o + y_pred_o - intersection

    # entries with union == 0 evaluate to 0 / 0 = NaN here; torch.where below always replaces them
    iou = intersection / union
    if ignore_empty:
        # union >= y_o, so every entry kept here has a non-zero denominator
        return torch.where(y_o > 0, iou, torch.tensor(float("nan"), device=iou.device, dtype=iou.dtype))
    return torch.where(union > 0, iou, torch.tensor(1.0, device=iou.device, dtype=iou.dtype))