
Training (backward) crashes when using torch.narrow, nested tensors, and scaled_dot_product_attention #136270

@davidbuterez

Description


🐛 Describe the bug

Using nested tensors generated with torch.narrow as inputs to torch.nn.functional.scaled_dot_product_attention works in the forward pass. However, training crashes in the backward pass with both the math and Flash backends.

When using SDPBackend.MATH, I encounter the following error:

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 13107200 (input tensor's size at dimension 0), but got split_sizes=[131072, 131072, 131072, 131072, 131072, 131072, 131072, 131072, 131072, 131072]

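The expected size, 13107200 (100 × 512 × 256), is the size of the full tensor before torch.narrow, whereas split_sizes sum to only 1310720 (10 × 512 × 256), the size of the narrowed batch.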
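Both numbers in this error can be derived from the shapes used in the reproduction below (100 sequences of 512 × 256 features, narrowed to the first 10 along dim 0). A quick check in plain Python:

```python
# Shapes taken from the reproduction: 100 components of shape (512, 256),
# narrowed to the first 10 along dim 0.
full_batch, seq_len, hidden = 100, 512, 256
narrow_len = 10

per_component = seq_len * hidden            # elements per nested component
full_buffer = full_batch * per_component    # size of the un-narrowed buffer
narrowed = narrow_len * per_component       # what the reported split_sizes add up to

print(per_component, full_buffer, narrowed)
```

Each entry of split_sizes is one component (131072 elements), and only ten of them are listed, so the split covers the narrowed batch while the check is made against the full buffer.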

When using SDPBackend.FLASH_ATTENTION, the problem seems to be that there is no available backward implementation:

RuntimeError: derivative for aten::narrow is not implemented

I would also add that, although the different nested layouts are not documented, using torch.jagged as the layout for torch.nested.narrow is not an option, since only dim=1 is allowed for that layout:

RuntimeError: jagged layout only supports dim=1

Ideally, regular slicing of nested tensors would just work. At a minimum, being able to slice or narrow along dimension 0, for example to select a sub-batch, seems essential for model training. In my case, slicing before batching is neither easy nor efficient.
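For comparison, the same sequence of operations (narrow along dim 0, SDPA, backward) trains without errors on ordinary dense tensors: the backward of narrow writes the incoming gradient into the selected rows and leaves the rest at zero. A minimal CPU sketch with small, arbitrary shapes (not the sizes from the reproduction), meant only as a reference for the expected behaviour:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Small dense stand-ins shaped (batch, heads, seq_len, head_dim); sizes are arbitrary.
q = torch.randn(8, 2, 4, 16, requires_grad=True)
k = torch.randn(8, 2, 4, 16, requires_grad=True)
v = torch.randn(8, 2, 4, 16, requires_grad=True)

# Keep only the first 3 batch entries, mirroring torch.narrow(..., dim=0, ...) in the repro.
q_n, k_n, v_n = (t.narrow(0, 0, 3) for t in (q, k, v))

out = F.scaled_dot_product_attention(q_n, k_n, v_n)
out.sum().backward()

# Gradients flow back into the first 3 entries; the narrowed-away rows stay exactly zero.
print(q.grad.shape, q.grad[3:].abs().sum().item())
```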

Perhaps @jbschlosser @cpuhrsch know more about this based on previous contributions.

A minimal example to reproduce these issues:

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm.auto import tqdm

torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc_q = nn.Linear(256, 256)
        self.fc_k = nn.Linear(256, 256)
        self.fc_v = nn.Linear(256, 256)
        self.fc_o = nn.Linear(256, 5)


    def forward(self, x):
        q = self.fc_q(x)
        k = self.fc_k(x)
        v = self.fc_v(x)

        # Treat dim 0 of the (100, 512, 256) projections as the nested batch dimension
        q_nested = torch.nested.as_nested_tensor(q)
        k_nested = torch.nested.as_nested_tensor(k)
        v_nested = torch.nested.as_nested_tensor(v)

        # Split 256 features into 16 heads of size 16 -> (batch, heads, seq_len, head_dim)
        q = q_nested.reshape(100, 512, 16, 16).transpose(1, 2).contiguous()
        k = k_nested.reshape(100, 512, 16, 16).transpose(1, 2).contiguous()
        v = v_nested.reshape(100, 512, 16, 16).transpose(1, 2).contiguous()

        # Keep the first 10 batch entries (forward works, backward crashes)
        q_narrow = torch.narrow(q, dim=0, start=0, length=10)
        k_narrow = torch.narrow(k, dim=0, start=0, length=10)
        v_narrow = torch.narrow(v, dim=0, start=0, length=10)

        # Uncomment the next line (and comment out the MATH line) for the FLASH_ATTENTION error
        # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        with sdpa_kernel(SDPBackend.MATH):
            out = F.scaled_dot_product_attention(
                q_narrow, k_narrow, v_narrow
            )

        out = out.transpose(1, 2).contiguous()
        out = torch.nested.to_padded_tensor(out, padding=0)
        out = out.reshape(10, 512, 256)

        return out


def main():
    x = torch.randn((100, 512, 256), requires_grad=True).cuda()
    y = torch.randint(0, 5, size=(10,)).cuda()

    model = Model().cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scaler = torch.amp.GradScaler()

    for epoch in tqdm(range(100)):
        optimizer.zero_grad()

        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            predictions = model(x).squeeze().sum(dim=1).sum(dim=-1)
            loss = F.cross_entropy(predictions, y.float())

        # The crash happens here, in the backward pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()


if __name__ == "__main__":
    main()
```

Versions

python collect_env.py
Collecting environment information...
PyTorch version: 2.5.0.dev20240909
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-121-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               32
On-line CPU(s) list:                  0-31
Vendor ID:                            AuthenticAMD
Model name:                           AMD Ryzen 9 5950X 16-Core Processor
CPU family:                           25
Model:                                33
Thread(s) per core:                   2
Core(s) per socket:                   16
Socket(s):                            1
Stepping:                             0
Frequency boost:                      enabled
CPU max MHz:                          3400,0000
CPU min MHz:                          2200,0000
BogoMIPS:                             6800.18
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                       AMD-V
L1d cache:                            512 KiB (16 instances)
L1i cache:                            512 KiB (16 instances)
L2 cache:                             8 MiB (16 instances)
L3 cache:                             64 MiB (2 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-31
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; safe RET, no microcode
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] admin-torch==0.1.0
[pip3] numpy==2.1.1
[pip3] performer-pytorch==1.1.4
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.5.0.dev20240909
[pip3] torch_cluster==1.6.3
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch_spline_conv==1.2.2
[pip3] torchaudio==2.5.0.dev20240909
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.4.1
[pip3] torchscale==0.2.0
[pip3] torchvision==0.20.0.dev20240909
[pip3] triton==3.0.0
[conda] admin-torch               0.1.0                    pypi_0    pypi
[conda] blas                      1.0                         mkl    conda-forge
[conda] brotlipy                  0.7.0           py311h9bf148f_1002    pytorch-nightly
[conda] cffi                      1.15.1          py311h9bf148f_3    pytorch-nightly
[conda] cryptography              38.0.4          py311h46ebde7_0    pytorch-nightly
[conda] filelock                  3.9.0                   py311_0    pytorch-nightly
[conda] libblas                   3.9.0            16_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            16_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            16_linux64_mkl    conda-forge
[conda] mkl                       2022.1.0           hc2b9512_224
[conda] mpmath                    1.2.1                   py311_0    pytorch-nightly
[conda] numpy                     2.1.1           py311h71ddf71_0    conda-forge
[conda] performer-pytorch         1.1.4                    pypi_0    pypi
[conda] pysocks                   1.7.1                   py311_0    pytorch-nightly
[conda] pytorch                   2.5.0.dev20240909 py3.11_cuda12.1_cudnn9.1.0_0    pytorch-nightly
[conda] pytorch-cuda              12.1                 ha16c6d3_6    pytorch-nightly
[conda] pytorch-lightning         2.4.0                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torch-cluster             1.6.3                    pypi_0    pypi
[conda] torch-scatter             2.1.2                    pypi_0    pypi
[conda] torch-sparse              0.6.18                   pypi_0    pypi
[conda] torch-spline-conv         1.2.2                    pypi_0    pypi
[conda] torchaudio                2.5.0.dev20240909     py311_cu121    pytorch-nightly
[conda] torchinfo                 1.8.0                    pypi_0    pypi
[conda] torchmetrics              1.4.1                    pypi_0    pypi
[conda] torchscale                0.2.0                     dev_0    <develop>
[conda] torchtriton               3.0.0+757b6a61e7           py311    pytorch-nightly
[conda] torchvision               0.20.0.dev20240909     py311_cu121    pytorch-nightly
[conda] urllib3                   1.26.14                 py311_0    pytorch-nightly

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @erichan1 @mikaylagawarecki

Metadata


Assignees

No one assigned

    Labels

    module: nestedtensor (NestedTensor tag, see issue #25032); module: sdpa (all things related to torch.nn.functional.scaled_dot_product_attention); triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
