🐛 Bug
When resizing images and their corresponding segmentation masks, it is common practice to use bilinear interpolation for the images and nearest neighbor sampling for segmentation masks. However, I have observed that using the interpolate method in torch.nn.functional with mode='nearest' causes the resized segmentation masks to be shifted to the bottom right when compared to the original image.
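Here is a tiny self-contained toy example (no external images needed) that seems to show the same behaviour on a synthetic 3x3 mask; the sizes are arbitrary, I just picked a non-integer scale factor:

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F

# 3x3 toy mask with a single foreground pixel in the centre
small = np.zeros((3, 3), dtype=np.uint8)
small[1, 1] = 1

# upsample to 10x10 with torch's nearest neighbour interpolation
t = torch.from_numpy(small)[None, None].float()
up_torch = F.interpolate(t, size=(10, 10), mode='nearest')[0, 0].byte().numpy()

# upsample the same mask to 10x10 with PIL's NEAREST resampling
up_pil = np.array(Image.fromarray(small).resize((10, 10), resample=Image.NEAREST))

print(up_torch)
print(up_pil)
print("pixels that disagree:", int((up_torch != up_pil).sum()))

The two outputs disagree along the top/left edge of the upsampled foreground region, i.e. the torch mask starts one pixel further towards the bottom right than the PIL one.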
To Reproduce
This requires PIL and OpenCV to be installed. The code below loads an image and its corresponding segmentation mask and produces 4 visualizations of the mask overlaid on top of the image:
1. Without any resizing, i.e., a control baseline to assess mask accuracy.
2. Both image and mask are upsampled using torch.nn.functional.interpolate with mode='bilinear' and align_corners=False.
3. The image is upsampled as in (2), but the mask is upsampled with mode='nearest' (this is where the problem occurs).
4. The image is upsampled as in (2), but the mask is upsampled using the Image.resize method in PIL. This appears to perform much better than the setting in (3).
The image and mask used can be downloaded from:
- Image: https://www.dropbox.com/s/gkgli16v080degr/image.jpg?dl=0
- Mask: https://www.dropbox.com/s/g4vgrdsdrzlg14s/mask.png?dl=0
Point the IMAGE_PATH and MASK_PATH variables in the code below to the location of these two files.
EDIT: Also added a visualization of the absolute difference between (3) and (4). This additionally requires matplotlib.
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
IMAGE_PATH = "/path/to/image.jpg"
MASK_PATH = "/path/to/mask.png"
def visualize_mask_on_image(image, mask):
    color = np.array([0, 255, 0], np.uint8)[None, None, :]
    masked_image = np.where(mask, color, image).astype(np.float32)
    masked_image = np.round((0.6 * masked_image) + (0.4 * image.astype(np.float32)))
    return masked_image.astype(np.uint8)
def resize_torch(interpolation_mode, align_corners):
    # open image and mask files
    image = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
    mask = cv2.imread(MASK_PATH, cv2.IMREAD_UNCHANGED)
    # convert to torch tensor
    image = torch.from_numpy(image).permute(2, 0, 1).float()  # HWC to CHW and byte to float
    mask = torch.from_numpy(mask > 0).unsqueeze(0).float()
    # upsample image and mask by 4x
    image = F.interpolate(image.unsqueeze(0), size=(800, 1200), mode='bilinear', align_corners=False)
    mask = F.interpolate(mask.unsqueeze(0), size=(800, 1200), mode=interpolation_mode, align_corners=align_corners)
    # convert to numpy arrays
    image = image.squeeze(0).permute(1, 2, 0).byte().numpy()
    mask = (mask.squeeze(0).permute(1, 2, 0) > 0.5).byte().numpy()
    # overlay the mask on the image
    return visualize_mask_on_image(image, mask), mask
def resize_pil_nearest():
    # open image and mask files
    image = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
    mask = Image.open(MASK_PATH)
    # convert segmentation mask to binary fg/bg mask
    colormap = [0] + [1 for _ in range(255)]
    mask = mask.point(colormap)
    # convert to torch tensor
    image = torch.from_numpy(image).permute(2, 0, 1).float()  # HWC to CHW and byte to float
    # upsample image and mask by 4x
    image = F.interpolate(image.unsqueeze(0), size=(800, 1200), mode='bilinear', align_corners=False)
    mask = np.array(mask.resize((1200, 800), resample=Image.NEAREST))
    # convert to numpy arrays
    image = image.squeeze(0).permute(1, 2, 0).byte().numpy()
    mask = mask[:, :, None]
    # overlay the mask on the image
    return visualize_mask_on_image(image, mask), mask
def no_resize_baseline():
    # open image and mask files
    image = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
    mask = cv2.imread(MASK_PATH, cv2.IMREAD_UNCHANGED)[:, :, None]
    return visualize_mask_on_image(image, mask), mask
def plot_torch_pil_diff():
    _, mask_pil = resize_pil_nearest()
    _, mask_torch = resize_torch('nearest', None)
    mask_pil = mask_pil.astype(np.float32)
    mask_torch = mask_torch.astype(np.float32)
    diff = np.abs(mask_pil - mask_torch)[:, :, 0]
    diff_img = plt.imshow(diff)
    plt.colorbar(diff_img)
    plt.show()
resized_torch_bilinear, _ = resize_torch('bilinear', False)
resized_torch_nearest, _ = resize_torch('nearest', None)
resized_PIL_nearest, _ = resize_pil_nearest()
no_resize, _ = no_resize_baseline()
cv2.namedWindow('Torch Bilinear', cv2.WINDOW_NORMAL)
cv2.namedWindow('Torch Nearest', cv2.WINDOW_NORMAL)
cv2.namedWindow('PIL Nearest', cv2.WINDOW_NORMAL)
cv2.namedWindow('No Resize', cv2.WINDOW_NORMAL)
cv2.imshow('Torch Bilinear', resized_torch_bilinear)
cv2.imshow('Torch Nearest', resized_torch_nearest)
cv2.imshow('PIL Nearest', resized_PIL_nearest)
cv2.imshow('No Resize', no_resize)
plot_torch_pil_diff()
cv2.waitKey(0)
Expected behavior
Settings (3) and (4) should result in roughly the same output, but they don't. Let's zoom into the handlebars of the bicycles in the overlaid images to get a better idea:
Here's what setting (1) looks like (no resizing):
Here's what setting (2) looks like (both image and mask resized with bilinear):
Here's what setting (3) looks like (mask resized with mode='nearest'). Notice the large black area close to the top right where the mask is not covering the object.
And finally, here's what setting (4) looks like (mask resized using PIL's resize method with Image.NEAREST mode):
Notice how (3) has a larger misalignment than (4). The fact that (2) is better is not surprising, since it uses bilinear upsampling, but in that case I would expect PIL's nearest-neighbor resize to perform just as badly. As you can see, though, (2) and (4) are similarly good, whereas (3) is worse.
Just to confirm, let's look at the plot for the absolute difference between the resized masks produced by (3) and (4):
This might seem like a small difference, but it was wreaking havoc with my segmentation network training: my output masks were always shifted to the right/bottom, which reduced the final IoU score by ~2%.
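For what it's worth, my understanding of why this happens (from experimenting, not from reading the C++ source, so the formulas below are my guess rather than a statement of what the code actually does) is that the two libraries map each output pixel back to a source pixel differently: F.interpolate with mode='nearest' appears to sample at the top-left corner of the output pixel, while PIL's Image.NEAREST samples at its centre. A rough sketch of the two mappings:

import math

def src_index_torch_nearest(dst, src_size, dst_size):
    # what F.interpolate(mode='nearest') appears to do (asymmetric mapping)
    return math.floor(dst * src_size / dst_size)

def src_index_pil_nearest(dst, src_size, dst_size):
    # what Image.resize(..., Image.NEAREST) appears to do (pixel-centre mapping)
    return math.floor((dst + 0.5) * src_size / dst_size)

# example: upsampling a length-3 axis to length 10
for dst in range(10):
    print(dst, src_index_torch_nearest(dst, 3, 10), src_index_pil_nearest(dst, 3, 10))

With the pixel-centre mapping, each source pixel stays roughly centred within the block of output pixels it produces; with the asymmetric mapping the content is effectively dragged towards higher indices, i.e. towards the bottom/right in 2D, which matches the shift described above.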
P.S. OpenCV's resize method seems to have the same problem as PyTorch's interpolate.
Environment
- PyTorch Version (e.g., 1.0): 1.1.0
- OS (e.g., Linux): Ubuntu 18.04.3 LTS
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): N/A
- Python version: 3.7
- CUDA/cuDNN version: 10.0.130/7.6.5
- GPU models and configuration: GPU 0: GeForce GTX 980, GPU 1: GeForce GTX TITAN X
- Any other relevant information: N/A