🐛 Describe the bug
I have identified a significant memory-usage regression in the F.conv3d operation in PyTorch 2.9.0 when using bfloat16 inputs.
A benchmark comparing 2.8.0 and 2.9.0 shows a substantial increase in peak GPU memory allocation for the bfloat16 path in the newer release, while memory usage for float32 inputs is identical between the two versions.
In 2.8.0, the bfloat16 operation correctly consumes less memory than its float32 counterpart. In 2.9.0, however, it consumes roughly 17x more memory than in 2.8.0 (7.27 GB vs. 0.42 GB) and, unexpectedly, almost 3x more than the float32 operation. This suggests a bug or performance regression in the bfloat16 kernel for conv3d.
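For reference, a back-of-the-envelope calculation of the tensor sizes involved (pure Python, shapes taken from the reproduction script) shows why the bfloat16 numbers in 2.9.0 cannot be explained by the tensors themselves:

```python
from math import prod

# Shapes from the reproduction script below.
input_shape = (1, 96, 6, 626, 626)
weight_shape = (96, 96, 3, 3, 3)
# Output spatial dims with kernel 3, stride 1, padding 0: each shrinks by 2.
output_shape = (1, 96, 6 - 2, 626 - 2, 626 - 2)

GiB = 1024 ** 3
for name, shape in [("input", input_shape),
                    ("weight", weight_shape),
                    ("output", output_shape)]:
    elems = prod(shape)
    print(f"{name}: fp32 {elems * 4 / GiB:.4f} GiB, bf16 {elems * 2 / GiB:.4f} GiB")
```

The float32 input alone (~0.84 GiB) matches the observed "before operations" figure for the float32 run, and the bfloat16 output is only ~0.28 GiB, so the extra ~7 GB seen in 2.9.0 presumably comes from scratch/workspace allocation rather than from any tensor in the computation (that last point is my assumption, not a confirmed diagnosis).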
To reproduce:
The following script can be used to reproduce the issue. The only changes required between runs are the installed PyTorch version and the input data type.
import torch
import torch.nn.functional as F
print(f"PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.reset_peak_memory_stats(device)
torch.cuda.empty_cache()
input_shape = (1, 96, 6, 626, 626)
weight_shape = (96, 96, 3, 3, 3)
bias_shape = (96,)
input_tensor = torch.randn(input_shape, device=device).type(torch.bfloat16)
weight = torch.randn(weight_shape, device=device).type(torch.bfloat16)
bias = torch.randn(bias_shape, device=device).type(torch.bfloat16)
torch.cuda.synchronize()
mem_before_op_gb = torch.cuda.max_memory_allocated(device) / (1024**3)
print(f"Peak memory allocated before operations: {mem_before_op_gb:.4f} GB")
for _ in range(100):
    conv3d_output = F.conv3d(
        input=input_tensor,
        weight=weight,
        bias=bias,
        stride=(1, 1, 1),
        padding=(0, 0, 0),
        dilation=(1, 1, 1),
    )
torch.cuda.synchronize()
mem_after_op_gb = torch.cuda.max_memory_allocated(device) / (1024**3)
print(f"Peak memory allocated after operations: {mem_after_op_gb:.4f} GB")
print(f"Memory usage for all operations: {(mem_after_op_gb - mem_before_op_gb):.4f} GB")

torch==2.8.0 results:
float32 results:
PyTorch version: 2.8.0+cu128
Peak memory allocated before operations: 0.8427 GB
Peak memory allocated after operations: 3.3555 GB
Memory usage for all operations: 2.5128 GB
bfloat16 results:
PyTorch version: 2.8.0+cu128
Peak memory allocated before operations: 1.2622 GB
Peak memory allocated after operations: 1.6773 GB
Memory usage for all operations: 0.4151 GB
torch==2.9.0 results:
float32 results:
PyTorch version: 2.9.0+cu128
Peak memory allocated before operations: 0.8427 GB
Peak memory allocated after operations: 3.3555 GB
Memory usage for all operations: 2.5128 GB
bfloat16 results:
PyTorch version: 2.9.0+cu128
Peak memory allocated before operations: 1.2622 GB
Peak memory allocated after operations: 8.5288 GB
Memory usage for all operations: 7.2665 GB
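One way to test whether cuDNN algorithm/workspace selection is responsible (this is an assumption on my part, not a confirmed diagnosis) would be to rerun the same convolution with cuDNN disabled via the documented torch.backends.cudnn.flags context manager and compare peak memory:

```python
import torch
import torch.nn.functional as F

# Diagnostic sketch: run the same bf16 conv3d with cuDNN turned off, so
# PyTorch falls back to a non-cuDNN implementation. If the extra ~7 GB
# disappears, cuDNN workspace selection is the likely culprit (assumption).
if torch.cuda.is_available():
    device = torch.device("cuda")
    x = torch.randn(1, 96, 6, 626, 626, device=device, dtype=torch.bfloat16)
    w = torch.randn(96, 96, 3, 3, 3, device=device, dtype=torch.bfloat16)
    torch.cuda.reset_peak_memory_stats(device)
    with torch.backends.cudnn.flags(enabled=False):
        y = F.conv3d(x, w)
    torch.cuda.synchronize()
    peak_gb = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    print(f"Peak memory with cuDNN disabled: {peak_gb:.4f} GB")
else:
    print("CUDA not available; diagnostic skipped")
```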
Versions
Collecting environment information...
PyTorch version: 2.9.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 14.2.0-4ubuntu2~24.04) 14.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39
Python version: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-85-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 580.95.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.14.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.14.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 30
On-line CPU(s) list: 0-29
Vendor ID: GenuineIntel
BIOS Vendor ID: QEMU
Model name: Intel(R) Xeon(R) Platinum 8462Y+
BIOS Model name: pc-q35-6.2 CPU @ 2.0GHz
BIOS CPU family: 1
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 1
Socket(s): 30
Stepping: 8
BogoMIPS: 5600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 960 KiB (30 instances)
L1i cache: 960 KiB (30 instances)
L2 cache: 120 MiB (30 instances)
L3 cache: 480 MiB (30 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-29
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
cc @seemethere @malfet @atalman @ptrblck @msaroufim @eqy @jerryzh168