Closed
Labels
OSS contribution wanted (PR from open source contributors welcome to solve this issue), good first issue, onnx-triaged (triaged by ONNX team)
Description
🐛 Describe the bug
ONNX opset 16 GridSample does not support a 5D volumetric input tensor, yet PyTorch still exports such a model without complaint. The correct behavior should be to prevent the GridSample export (raise an error) when the input tensor is 5D volumetric. The exported ONNX file is also attached.
import torch


class GridSample(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_tensor, grid_tensor):
        output_tensor = torch.nn.functional.grid_sample(input=input_tensor,
                                                        grid=grid_tensor,
                                                        mode='bilinear',
                                                        padding_mode='zeros',
                                                        align_corners=False)
        return output_tensor


N = 4
C = 32
D_in = 64
H_in = 64
W_in = 64
D_out = 128
H_out = 128
W_out = 128

# 5D volumetric input (N, C, D, H, W) and the corresponding 3D sampling grid
input_shape = [N, C, D_in, H_in, W_in]
grid_shape = [N, D_out, H_out, W_out, 3]
output_shape = [N, C, D_out, H_out, W_out]

input_tensor = torch.rand(*input_shape)
grid_tensor = torch.rand(*grid_shape)

grid_sample_module = GridSample()
output_tensor = grid_sample_module(input_tensor=input_tensor, grid_tensor=grid_tensor)
assert list(output_tensor.shape) == output_shape

# Export succeeds even though opset 16 GridSample only accepts 4D input
torch.onnx.export(grid_sample_module,
                  {"input_tensor": input_tensor, "grid_tensor": grid_tensor},
                  "grid_sample.onnx",
                  verbose=False,
                  opset_version=16)
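For reference, the behavior being asked for can be emulated on the user side. Below is a minimal sketch (the export_grid_sample_2d helper is hypothetical, not a PyTorch API) that refuses to export when the sampled input is not 4D, which is all that ONNX opset 16 GridSample defines:

import torch

def export_grid_sample_2d(module, input_tensor, grid_tensor, path, opset_version=16):
    # Hypothetical user-side guard: ONNX opset 16 GridSample is only defined for
    # 4D input (N, C, H_in, W_in); a 5D volumetric input has no valid lowering,
    # so refuse to export instead of silently producing an invalid model.
    if input_tensor.dim() != 4:
        raise ValueError(
            f"ONNX opset {opset_version} GridSample supports only 4D input, "
            f"got a {input_tensor.dim()}D tensor"
        )
    torch.onnx.export(module, (input_tensor, grid_tensor), path,
                      verbose=False, opset_version=opset_version)

The proper fix would presumably live in the exporter itself (e.g. rejecting the 5D case in the opset 16 grid_sampler symbolic), but the guard above illustrates the expected behavior.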
Versions
PyTorch version: 1.13.0a0+936e930
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.31
Python version: 3.8.10 (default, Jun 22 2022, 20:18:18) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 515.86.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] functorch==1.13.0a0+936e930
[pip3] numpy==1.22.2
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.13.0a0+936e930
[pip3] torch-tensorrt==1.3.0a0
[pip3] torchtext==0.13.0a0+fae8e8c
[pip3] torchvision==0.15.0a0
[conda] Could not collect
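As a quick sanity check on the attached grid_sample.onnx, the exported graph can be inspected with the standard onnx Python package; this is a sketch, and the expectation that a GridSample node appears is an assumption based on the repro above:

import onnx

# Load the exported file and list the operators in the graph; a GridSample node
# is expected to appear even though the model was exported with a 5D input.
model = onnx.load("grid_sample.onnx")
print([node.op_type for node in model.graph.node])

# Structural validation; depending on the onnx/checker version this may or may
# not flag the unsupported 5D input rank.
onnx.checker.check_model(model)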