Using torch.nn.functional.interpolate with 'nearest' mode introduces misalignment #34808

@Ali2500

🐛 Bug

When resizing images and their corresponding segmentation masks, it is common practice to use bilinear interpolation for the images and nearest neighbor sampling for segmentation masks. However, I have observed that using the interpolate method in torch.nn.functional with mode='nearest' causes the resized segmentation masks to be shifted to the bottom right when compared to the original image.

To Reproduce

This requires PIL and OpenCV to be installed. The code below loads an image and its corresponding segmentation mask and produces 4 visualizations of the mask overlaid on top of the image:

  1. Without any resizing, i.e., a control baseline to assess mask accuracy.
  2. Both image and mask are upsampled using torch.nn.functional.interpolate with mode='bilinear' and align_corners=False.
  3. The image is upsampled as in (2), but the mask is upsampled with mode='nearest' (this is where the problem occurs).
  4. The image is upsampled as in (2), but the mask is upsampled using the Image.resize method in PIL. This appears to perform much better than the setting in (3).

The image and mask used can be downloaded from:

Point the IMAGE_PATH and MASK_PATH variables in the code below to the location of these two files.

EDIT: Also added a visualization of the absolute difference between (3) and (4). This requires matplotlib.

from PIL import Image

import cv2
import matplotlib.pyplot as plt  # needed by plot_torch_pil_diff()
import numpy as np
import torch
import torch.nn.functional as F

IMAGE_PATH = "/path/to/image.jpg"
MASK_PATH = "/path/to/mask.png"


def visualize_mask_on_image(image, mask):
    color = np.array([0, 255, 0], np.uint8)[None, None, :]
    masked_image = np.where(mask, color, image).astype(np.float32)
    masked_image = np.round((0.6 * masked_image) + (0.4 * image.astype(np.float32)))
    return masked_image.astype(np.uint8)


def resize_torch(interpolation_mode, align_corners):
    # open image and mask files
    image = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
    mask = cv2.imread(MASK_PATH, cv2.IMREAD_UNCHANGED)

    # convert to torch tensor
    image = torch.from_numpy(image).permute(2, 0, 1).float()  # HWC to CHW and byte to float
    mask = torch.from_numpy(mask > 0).unsqueeze(0).float()

    # upsample image and mask by 4x
    image = F.interpolate(image.unsqueeze(0), size=(800, 1200), mode='bilinear', align_corners=False)
    mask = F.interpolate(mask.unsqueeze(0), size=(800, 1200), mode=interpolation_mode, align_corners=align_corners)

    # convert to numpy arrays
    image = image.squeeze(0).permute(1, 2, 0).byte().numpy()
    mask = (mask.squeeze(0).permute(1, 2, 0) > 0.5).byte().numpy()

    # overlay the mask on the image; also return the raw mask, since callers unpack two values
    return visualize_mask_on_image(image, mask), mask


def resize_pil_nearest():
    # open image and mask files
    image = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
    mask = Image.open(MASK_PATH)

    # convert segmentation mask to binary fg/bg mask
    colormap = [0] + [1 for _ in range(255)]
    mask = mask.point(colormap)

    # convert to torch tensor
    image = torch.from_numpy(image).permute(2, 0, 1).float()  # HWC to CHW and byte to float

    # upsample image and mask by 4x
    image = F.interpolate(image.unsqueeze(0), size=(800, 1200), mode='bilinear', align_corners=False)
    mask = np.array(mask.resize((1200, 800), resample=Image.NEAREST))

    # convert to numpy arrays
    image = image.squeeze(0).permute(1, 2, 0).byte().numpy()
    mask = mask[:, :, None]

    # overlay the mask on the image
    return visualize_mask_on_image(image, mask), mask


def no_resize_baseline():
    # open image and mask files
    image = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
    mask = cv2.imread(MASK_PATH, cv2.IMREAD_UNCHANGED)[:, :, None]

    return visualize_mask_on_image(image, mask), mask


def plot_torch_pil_diff():
    _, mask_pil = resize_pil_nearest()
    _, mask_torch = resize_torch('nearest', None)

    mask_pil = mask_pil.astype(np.float32)
    mask_torch = mask_torch.astype(np.float32)

    diff = np.abs(mask_pil - mask_torch)[:, :, 0]
    diff_img = plt.imshow(diff)
    plt.colorbar(diff_img)
    plt.show()


resized_torch_bilinear, _ = resize_torch('bilinear', False)
resized_torch_nearest, _ = resize_torch('nearest', None)
resized_PIL_nearest, _ = resize_pil_nearest()
no_resize, _ = no_resize_baseline()

cv2.namedWindow('Torch Bilinear', cv2.WINDOW_NORMAL)
cv2.namedWindow('Torch Nearest', cv2.WINDOW_NORMAL)
cv2.namedWindow('PIL Nearest', cv2.WINDOW_NORMAL)
cv2.namedWindow('No Resize', cv2.WINDOW_NORMAL)

cv2.imshow('Torch Bilinear', resized_torch_bilinear)
cv2.imshow('Torch Nearest', resized_torch_nearest)
cv2.imshow('PIL Nearest', resized_PIL_nearest)
cv2.imshow('No Resize', no_resize)

plot_torch_pil_diff()
cv2.waitKey(0)

Expected behavior

Settings (3) and (4) should result in roughly the same output, but they don't. Let's zoom in on the handlebars of the bicycles in the overlaid images to get a better idea:

Here's what setting (1) looks like (no resizing):

[Image: No Resize_screenshot_16 03 2020]

Here's what setting (2) looks like (both image and mask resized with bilinear):

[Image: Torch Bilinear_screenshot_16 03 2020]

Here's what setting (3) looks like (mask resized with mode='nearest'). Notice the large black area close to the top right where the mask is not covering the object.

[Image: Torch Nearest_screenshot_16 03 2020]

And finally, here's what setting (4) looks like (mask resized using PIL's resize method with Image.NEAREST mode):

[Image: PIL Nearest_screenshot_16 03 2020]

Notice how (3) has a larger misalignment than (4). The fact that (2) is better is not surprising considering that it involves bilinear upsampling, but then I would expect PIL's resize method to perform similarly badly. As you can see though, (2) and (4) are similarly good, whereas (3) is worse.
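One plausible cause of the gap between (3) and (4) is a difference in how each sampler maps an output pixel back to a source pixel. The sketch below is an illustration under assumptions, not a statement about either library's actual implementation: it assumes the shifted result comes from a mapping of the form floor(dst * in_size / out_size), and the better-aligned result from a pixel-center mapping of the form floor((dst + 0.5) * in_size / out_size). The function names and the 5-pixel mask row are hypothetical.

```python
def floor_index(dst, in_size, out_size):
    """Assumed 'corner' convention: src = floor(dst * in_size / out_size)."""
    return min((dst * in_size) // out_size, in_size - 1)


def center_index(dst, in_size, out_size):
    """Assumed pixel-center convention: src = floor((dst + 0.5) * in_size / out_size).

    Written with integers only, so there is no floating-point rounding.
    """
    return min(((2 * dst + 1) * in_size) // (2 * out_size), in_size - 1)


def resize_1d(row, out_size, index_fn):
    """Nearest-neighbour resize of a 1D list using the given index mapping."""
    return [row[index_fn(dst, len(row), out_size)] for dst in range(out_size)]


# Hypothetical mask row: foreground starts at source index 2.
row = [0, 0, 1, 1, 1]

floor_row = resize_1d(row, 8, floor_index)
center_row = resize_1d(row, 8, center_index)

print("floor :", floor_row)
print("center:", center_row)
```

With this 5-to-8 resize, the foreground edge lands one output pixel further to the right under the floor mapping than under the pixel-center mapping, which is the same direction as the shift observed in (3). For integer upscaling factors the two mappings agree, so the effect only shows up for non-integer factors or when downsampling.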

Just to confirm, let's look at the plot for the absolute difference between the resized masks produced by (3) and (4):

[Image: diff]

This might seem like a small difference, but it was wreaking havoc with my segmentation network training. My output masks would always be shifted to the right/bottom which was reducing the final IoU score by ~2%.
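To give a feel for how a shift of a single pixel can cost a few points of IoU, here is a toy calculation. It is only a sketch: the 100x100 square object, its position, and the helper names are hypothetical and have nothing to do with the actual dataset or network.

```python
def box_mask(top, left, size):
    """Return the set of (row, col) pixels covered by a square mask."""
    return {(r, c) for r in range(top, top + size) for c in range(left, left + size)}


def iou(mask_a, mask_b):
    """Intersection over union of two masks given as pixel sets."""
    union = mask_a | mask_b
    return len(mask_a & mask_b) / len(union) if union else 1.0


# Hypothetical 100x100 object, and the same object shifted one pixel down and right.
reference = box_mask(top=50, left=50, size=100)
shifted = box_mask(top=51, left=51, size=100)

shift_iou = iou(reference, shifted)
print(f"IoU after a one-pixel diagonal shift: {shift_iou:.4f}")
```

The loss grows as objects get smaller or thinner, since a one-pixel shift then removes a larger fraction of the overlap.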

P.S. OpenCV's resize method seems to have the same problem as PyTorch's interpolate.

Environment

  • PyTorch Version (e.g., 1.0): 1.1.0
  • OS (e.g., Linux): Ubuntu 18.04.3 LTS
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): N/A
  • Python version: 3.7
  • CUDA/cuDNN version: 10.0.130/7.6.5
  • GPU models and configuration: GPU 0: GeForce GTX 980, GPU 1: GeForce GTX TITAN X
  • Any other relevant information: N/A

cc @albanD @mruberry @jbschlosser @fmassa @vfdev-5 @pmeier


Labels: module: interpolation, module: nn, module: vision, triaged
