[inductor] [cpu] [graph optimization] output size calculation behaves differently for ConvTranspose1d, ConvTranspose2d, ConvTranspose3d when followed by sigmoid #144013

@shaoyuyoung

🐛 Describe the bug

symptom:
trigger condition 1: for ConvTranspose (with the default dilation and output padding), the output size is $O = (I - 1) \times S + K - 2 \times P$. When $O$ is zero, eager raises an error, but inductor passes the check and returns an empty tensor.
trigger condition 2: ConvTranspose must be followed by sigmoid; otherwise, inductor also raises the error.

I guess the ConvTranspose-sigmoid optimization on the CPP backend loses the output size check.
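
To make trigger condition 1 concrete, the sketch below evaluates the output size formula for the shapes used in this report. The helper name `conv_transpose_out_size` is hypothetical (it is not a PyTorch API), and it assumes the default `dilation=1` and `output_padding=0`.

```python
def conv_transpose_out_size(i: int, s: int, k: int, p: int) -> int:
    """Output length along one spatial dim: O = (I - 1) * S + K - 2 * P.

    Assumes dilation=1 and output_padding=0 (the ConvTranspose defaults).
    """
    return (i - 1) * s + k - 2 * p


# Shapes from this report: input 1, stride 1, kernel 2, padding 1.
print(conv_transpose_out_size(i=1, s=1, k=2, p=1))  # 0, the size eager rejects
# Dropping the padding gives a non-empty output.
print(conv_transpose_out_size(i=1, s=1, k=2, p=0))  # 2
```

With `kernel_size=2` and `padding=1` on a length-1 input, every spatial dimension comes out as zero, which is the case where eager and inductor disagree.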

device: cpu only

exposed area: ConvTranspose1d, ConvTranspose2d, ConvTranspose3d

import torch


class Model(torch.nn.Module):

    def __init__(self, dim):
        super().__init__()
        # Look up ConvTranspose{1,2,3}d by name instead of going through eval.
        conv_cls = getattr(torch.nn, f"ConvTranspose{dim}d")
        self.conv_t = conv_cls(1, 1, kernel_size=(2,) * dim, padding=(1,) * dim)

    def forward(self, x):
        x = self.conv_t(x)
        x = torch.sigmoid(x)  # trigger condition
        return x


def run_test(dim, mode):
    x = torch.randn(*([1] * (dim + 2)))

    inputs = [x]
    model = Model(dim)

    if mode == "inductor":
        model = torch.compile(model)

    try:
        output = model(*inputs)
        print(f"success on {mode}: {output}")
    except Exception as e:
        print(e)


run_test(1, "eager")
run_test(1, "inductor")

run_test(2, "eager")
run_test(2, "inductor")

run_test(3, "eager")
run_test(3, "inductor")

Error logs

Given input size per channel: (1 x 1). Calculated output size per channel: (1 x 0). Output size is too small
success on inductor: tensor([], size=(1, 1, 0), grad_fn=<CompiledFunctionBackward>)
Given input size per channel: (1 x 1). Calculated output size per channel: (0 x 0). Output size is too small
success on inductor: tensor([], size=(1, 1, 0, 0), grad_fn=<CompiledFunctionBackward>)
Given input size per channel: (1 x 1 x 1). Calculated output size per channel: (0 x 0 x 0). Output size is too small
success on inductor: tensor([], size=(1, 1, 0, 0, 0), grad_fn=<CompiledFunctionBackward>)
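
The logs above show eager raising while inductor succeeds. A minimal sketch of how that mismatch could be asserted in a regression test follows. The helper names `outcome` and `same_behavior` and the two stand-in functions are hypothetical; the stand-ins only mimic the two behaviours and do not call PyTorch.

```python
from typing import Any, Callable, Tuple


def outcome(fn: Callable[[], Any]) -> Tuple[str, str]:
    """Run fn; return ("ok", repr(result)) or ("error", exception type name)."""
    try:
        return "ok", repr(fn())
    except Exception as e:
        return "error", type(e).__name__


def same_behavior(eager_fn: Callable[[], Any], compiled_fn: Callable[[], Any]) -> bool:
    """True when both callables succeed or both raise."""
    return outcome(eager_fn)[0] == outcome(compiled_fn)[0]


def eager_stand_in():
    # Mimics eager mode rejecting a zero-sized output.
    raise RuntimeError("Output size is too small")


def compiled_stand_in():
    # Mimics the compiled model returning an empty result.
    return []


print(same_behavior(eager_stand_in, compiled_stand_in))  # False: the behaviours diverge
```

In a real test, the stand-ins would presumably be replaced by closures over the eager and `torch.compile`d `Model` from the repro.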

Versions

PyTorch version: 20241230
OS: Ubuntu 20.04.6 LTS (x86_64)
CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
GPU: V100

click for detailed env

PyTorch version: 2.6.0.dev20241230+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 16.0.1
CMake version: version 3.26.0
Libc version: glibc-2.31

Python version: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct  4 2024, 13:27:36) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-204-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB

Nvidia driver version: 560.35.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      40 bits physical, 48 bits virtual
CPU(s):                             20
On-line CPU(s) list:                0-19
Thread(s) per core:                 1
Core(s) per socket:                 20
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              85
Model name:                         Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
Stepping:                           7
CPU MHz:                            2499.998
BogoMIPS:                           4999.99
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          640 KiB
L1i cache:                          640 KiB
L2 cache:                           80 MiB
L3 cache:                           16 MiB
NUMA node0 CPU(s):                  0-19
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        KVM: Vulnerable
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Mitigation; IBRS
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch topoext cpuid_fault invpcid_single pti ssbd ibrs ibpb fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat umip pku ospke avx512_vnni

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.20.1
[pip3] onnxscript==0.1.0.dev20241205
[pip3] optree==0.13.1
[pip3] pytorch-triton==3.2.0+git0d4682f0
[pip3] torch==2.6.0.dev20241230+cu126
[pip3] torchaudio==2.6.0.dev20241230+cu126
[pip3] torchvision==0.22.0.dev20241230+cu126
[pip3] triton==3.0.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.6.4.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.6.80                  pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.6.77                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.6.77                  pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.5.1.17                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.3.0.4                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.7.77                pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.7.1.2                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.5.4.2                 pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.3                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.6.85                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.6.77                  pypi_0    pypi
[conda] optree                    0.13.1                   pypi_0    pypi
[conda] pytorch-triton            3.2.0+git0d4682f0          pypi_0    pypi
[conda] torch                     2.6.0.dev20241230+cu126          pypi_0    pypi
[conda] torchaudio                2.6.0.dev20241230+cu126          pypi_0    pypi
[conda] torchvision               0.22.0.dev20241230+cu126          pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi

cc @chauhang @penguinwu
