Type mismatch error with torch.nn.functional.grid_sample() under AMP #42218

@jh-jeong

🐛 Bug

Training a model that calls grid_sample(input, grid) throws the following error under AMP:

RuntimeError: grid_sampler(): expected input and grid to have same dtype, but input has c10::Half and grid has float
  • This also happens in the reverse form, (input, grid) = (float, c10::Half), depending on the model definition.
  • However, I'm not sure how to reproduce this error in a minimal code snippet.
  • Casting both arguments to float, i.e. (input, grid) -> (input.float(), grid.float()), works around the issue.
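
The workaround in the last bullet can be sketched as a small wrapper. This is a minimal sketch, not part of PyTorch: the name grid_sample_fp32 is hypothetical, and the tensor shapes and the identity transform are made up for illustration. It runs on CPU and only shows that casting both arguments gives grid_sample matching dtypes.

```python
import torch
import torch.nn.functional as F


def grid_sample_fp32(input, grid, **kwargs):
    """Hypothetical helper (not a PyTorch API): sample in float32.

    Casting both tensors to float32 guarantees that grid_sample sees
    matching dtypes, whichever of the two was left in half precision.
    """
    return F.grid_sample(input.float(), grid.float(), **kwargs)


# Illustrative data only: a half-precision input and a float32 grid,
# mirroring the dtype pair named in the error message above.
inp = torch.randn(1, 3, 4, 4).half()
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity affine transform
grid = F.affine_grid(theta, size=(1, 3, 4, 4), align_corners=False)

out = grid_sample_fp32(inp, grid, align_corners=False)
print(out.dtype)  # torch.float32
```

The result is always float32, so if later layers expect the original dtype, the caller would need to cast it back (for example with out.to(inp.dtype)).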

To Reproduce

Steps to reproduce the behavior:

  1. Construct a complex model that includes grid_sample()
  2. Run it inside a with autocast(): block

Expected behavior

AMP is expected to handle every native function properly in an autocast-enabled region.

Environment

Collecting environment information...
PyTorch version: 1.6.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.8
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp

Nvidia driver version: 418.87.00
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.2
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.4
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.6
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.7

Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] torch==1.6.0
[pip] torchlars==0.1.2
[pip] torchvision==0.7.0
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] mkl                       2020.1                      217
[conda] mkl-service               2.3.0            py38he904b0f_0
[conda] mkl_fft                   1.0.15           py38ha843d7b_0
[conda] mkl_random                1.1.0            py38h962f231_0
[conda] numpy                     1.18.1           py38h4f9e942_0
[conda] numpy-base                1.18.1           py38hde5b4d6_1
[conda] pytorch                   1.6.0           py3.8_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] torchlars                 0.1.2                    pypi_0    pypi
[conda] torchvision               0.7.0                py38_cu101    pytorch

cc @mcarilli

Labels: module: amp (automated mixed precision), autocast, triaged
