Description
🐛 Describe the bug
Using nested tensors produced by torch.narrow as inputs to torch.nn.functional.scaled_dot_product_attention works in the forward pass of the model. However, with both the MATH and FLASH_ATTENTION backends, training fails once the backward pass runs.
When using SDPBackend.MATH, I encounter the following error:
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 13107200 (input tensor's size at dimension 0), but got split_sizes=[131072, 131072, 131072, 131072, 131072, 131072, 131072, 131072, 131072, 131072]
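Note that 13107200 is the number of elements of the tensor before torch.narrow (100 × 512 × 256), whereas the split sizes only sum to the size of the narrowed tensor (10 × 131072 = 1310720).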
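A quick check of the numbers in the error message, using the shapes from the reproducer below (100 sequences of shape (512, 256), narrowed to the first 10 along dim 0):

```python
# Shapes taken from the reproducer: 100 sequences of shape (512, 256),
# narrowed to the first 10 along dim 0.
batch, seq_len, hidden = 100, 512, 256
narrowed_batch = 10

per_component = seq_len * hidden              # elements in one (512, 256) component
full_buffer = batch * per_component           # elements before torch.narrow
split_total = narrowed_batch * per_component  # sum of the reported split_sizes

print(per_component, full_buffer, split_total)  # 131072 13107200 1310720
```

The mismatch is exactly the narrowing factor (100 / 10), which is consistent with the backward pass splitting the full pre-narrow buffer using only the narrowed components' sizes.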
When using SDPBackend.FLASH_ATTENTION, the backward pass fails instead, apparently because there is no backward implementation of narrow for nested tensors:
RuntimeError: derivative for aten::narrow is not implemented
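For contrast, a minimal sketch on a regular dense tensor, where torch.narrow returns a view and backpropagates without issue. This suggests, but does not prove, that the gap is specific to the nested-tensor path; the sizes are arbitrary.

```python
import torch

# A dense tensor: torch.narrow returns a view, and autograd differentiates it.
x = torch.randn(100, 8, requires_grad=True)
y = torch.narrow(x, dim=0, start=0, length=10)
y.sum().backward()

# The gradient is 1 for the 10 narrowed rows and 0 for the remaining 90.
print(x.grad[:10].unique(), x.grad[10:].unique())
```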
Note also that, although the different nested layouts are not documented, passing layout=torch.jagged to torch.nested.narrow does not help here, because the jagged layout only allows narrowing along dim=1:
RuntimeError: jagged layout only supports dim=1
Ideally, plain slicing (e.g. q[0:10]) would just work. At the very least, being able to slice or narrow along dimension 0, for example to form mini-batches, seems essential for model training. In my case, it is neither easy nor efficient to do the slicing before batching.
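One possible stopgap is to select whole components by unbinding the nested tensor and re-nesting the slice with torch.nested.as_nested_tensor, which is documented to preserve autograd history when given a list of tensors. narrow_batch is a hypothetical helper, not a PyTorch API; unlike torch.narrow it copies the selected components into a new buffer, and it is an untested assumption that gradients flow through nested unbind on this nightly.

```python
import torch


def narrow_batch(nt: torch.Tensor, start: int, length: int) -> torch.Tensor:
    """Hypothetical helper: keep components start .. start + length - 1 of a
    (strided) nested tensor by unbinding it and re-nesting the selection.

    This copies the selected components instead of returning a view.
    """
    components = nt.unbind()[start:start + length]
    return torch.nested.as_nested_tensor(list(components))


# Usage on a small ragged batch (arbitrary sizes).
nt = torch.nested.nested_tensor([torch.randn(3, 4), torch.randn(5, 4), torch.randn(2, 4)])
sub = narrow_batch(nt, start=1, length=2)
print([c.shape for c in sub.unbind()])  # [torch.Size([5, 4]), torch.Size([2, 4])]
```

The copy costs memory proportional to the selected components, but it sidesteps the view-based narrow whose backward is failing above.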
@jbschlosser and @cpuhrsch may know more about this, given their previous contributions.
A minimal example to reproduce these issues:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm.auto import tqdm

torch.manual_seed(0)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_q = nn.Linear(256, 256)
        self.fc_k = nn.Linear(256, 256)
        self.fc_v = nn.Linear(256, 256)
        self.fc_o = nn.Linear(256, 5)

    def forward(self, x):
        q = self.fc_q(x)
        k = self.fc_k(x)
        v = self.fc_v(x)
        # Convert the dense (100, 512, 256) projections into nested tensors.
        q_nested = torch.nested.as_nested_tensor(q)
        k_nested = torch.nested.as_nested_tensor(k)
        v_nested = torch.nested.as_nested_tensor(v)
        # Split into 16 heads of size 16: (batch, heads, seq, head_dim).
        q = q_nested.reshape(100, 512, 16, 16).transpose(1, 2).contiguous()
        k = k_nested.reshape(100, 512, 16, 16).transpose(1, 2).contiguous()
        v = v_nested.reshape(100, 512, 16, 16).transpose(1, 2).contiguous()
        # Keep only the first 10 sequences of the batch.
        q_narrow = torch.narrow(q, dim=0, start=0, length=10)
        k_narrow = torch.narrow(k, dim=0, start=0, length=10)
        v_narrow = torch.narrow(v, dim=0, start=0, length=10)
        # Swap the backend to reproduce the FLASH_ATTENTION error instead.
        # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        with sdpa_kernel(SDPBackend.MATH):
            out = F.scaled_dot_product_attention(
                q_narrow, k_narrow, v_narrow
            )
        # Back to (batch, seq, heads, head_dim), then to a dense (10, 512, 256) tensor.
        out = out.transpose(1, 2).contiguous()
        out = torch.nested.to_padded_tensor(out, padding=0)
        out = out.reshape(10, 512, 256)
        return out


def main():
    x = torch.randn((100, 512, 256), requires_grad=True).cuda()
    y = torch.randint(0, 5, size=(10,)).cuda()
    model = Model().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scaler = torch.amp.GradScaler()
    for epoch in tqdm(range(100)):
        optimizer.zero_grad()
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            predictions = model(x).squeeze().sum(dim=1).sum(dim=-1)
            loss = F.cross_entropy(predictions, y.float())
        # The errors above are raised here, during the backward pass.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()


if __name__ == "__main__":
    main()
Versions
python collect_env.py
Collecting environment information...
PyTorch version: 2.5.0.dev20240909
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-121-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 9 5950X 16-Core Processor
CPU family: 25
Model: 33
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU max MHz: 3400,0000
CPU min MHz: 2200,0000
BogoMIPS: 6800.18
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 8 MiB (16 instances)
L3 cache: 64 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] admin-torch==0.1.0
[pip3] numpy==2.1.1
[pip3] performer-pytorch==1.1.4
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.5.0.dev20240909
[pip3] torch_cluster==1.6.3
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch_spline_conv==1.2.2
[pip3] torchaudio==2.5.0.dev20240909
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.4.1
[pip3] torchscale==0.2.0
[pip3] torchvision==0.20.0.dev20240909
[pip3] triton==3.0.0
[conda] admin-torch 0.1.0 pypi_0 pypi
[conda] blas 1.0 mkl conda-forge
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] filelock 3.9.0 py311_0 pytorch-nightly
[conda] libblas 3.9.0 16_linux64_mkl conda-forge
[conda] libcblas 3.9.0 16_linux64_mkl conda-forge
[conda] liblapack 3.9.0 16_linux64_mkl conda-forge
[conda] mkl 2022.1.0 hc2b9512_224
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 2.1.1 py311h71ddf71_0 conda-forge
[conda] performer-pytorch 1.1.4 pypi_0 pypi
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.5.0.dev20240909 py3.11_cuda12.1_cudnn9.1.0_0 pytorch-nightly
[conda] pytorch-cuda 12.1 ha16c6d3_6 pytorch-nightly
[conda] pytorch-lightning 2.4.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torch-cluster 1.6.3 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torch-spline-conv 1.2.2 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240909 py311_cu121 pytorch-nightly
[conda] torchinfo 1.8.0 pypi_0 pypi
[conda] torchmetrics 1.4.1 pypi_0 pypi
[conda] torchscale 0.2.0 dev_0 <develop>
[conda] torchtriton 3.0.0+757b6a61e7 py311 pytorch-nightly
[conda] torchvision 0.20.0.dev20240909 py311_cu121 pytorch-nightly
[conda] urllib3 1.26.14 py311_0 pytorch-nightly
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @erichan1 @mikaylagawarecki