Multidimensional Tiling with Reduction #142317

@nullplay

Description

🐛 Describe the bug

Hi PyTorch Team,

I found that setting torch._inductor.config.triton.prefer_nd_tiling=True significantly improves my program's performance (up to a 5x speedup). However, it only helps when the graph has no reduction, or when the reduction is small enough to be unrolled. If the reduction size exceeds unroll_reductions_threshold, tiling does not occur.

Here's a minimal example demonstrating the issue:

Code Example

import torch

P = 1000
M = 50

B = 1000
L = 81
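U = 2  # reduction dimension
V = 2  # reduction dimension; U * V = 4 is below the threshold, so this run is tiled.
       # Raise U and V so that U * V exceeds it (the non-tiled kernel below has rnumel = 64).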
W = 16

Input1 = torch.rand((B, U, L), dtype=torch.float32, device="cuda")
Input2 = torch.rand((B, V, L), dtype=torch.float32, device="cuda")
Weight = torch.rand((U, V, W, M), dtype=torch.float32, device="cuda")
Output = torch.zeros((B, W, L), dtype=torch.float32, device="cuda")
Meta = torch.rand((P,), dtype=torch.float32, device="cuda")

icrd1 = torch.randint(L, (P,), device="cuda")
icrd2 = torch.randint(L, (P,), device="cuda")
wcrd = torch.randint(M, (P,), device="cuda")
ocrd = torch.randint(L, (P,), device="cuda")

def test2(icrd1, icrd2, wcrd, ocrd, Meta, Input1, Input2, Weight, Output, B, U, V, W):
    Input1_selected = torch.index_select(Input1, 2, icrd1)  # (B, U, P)
    Input2_selected = torch.index_select(Input2, 2, icrd2)  # (B, V, P)
    Weight_selected = torch.index_select(Weight, 3, wcrd)   # (U, V, W, P)

    Input1_expanded = Input1_selected.view(B, U, 1, 1, -1)  # (B, U, 1, 1, P)
    Input2_expanded = Input2_selected.view(B, 1, V, 1, -1)  # (B, 1, V, 1, P)
    Weight_expanded = Weight_selected.view(1, U, V, W, -1)  # (1, U, V, W, P)
    product = Input1_expanded * Input2_expanded * Weight_expanded

    product = torch.sum(product, dim=(1, 2))                # (B, W, P): reduction on U, V
    Meta_expanded = Meta.view(1, 1, -1)                     # (1, 1, P)
    product = product * Meta_expanded                       # (B, W, P)

    Output.index_add_(2, ocrd, product)                     # (B, W, L)
    return Output

torch._inductor.config.triton.prefer_nd_tiling = True
torch._inductor.config.unroll_reductions_threshold = 8  # default value

compiled2 = torch.compile(test2)
Output = compiled2(icrd1, icrd2, wcrd, ocrd, Meta, Input1, Input2, Weight, Output, B, U, V, W)

Observations

  • Reduction Size: The reduction size is U * V.
  • If U * V is smaller than unroll_reductions_threshold, the code generates tiled Triton kernels.
  • When U * V exceeds the threshold, the reduction is not unrolled: Inductor emits a reduction kernel and no N-D tiling occurs.
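The two kernels in this issue can be summarized with a small, hypothetical model of the decision (predicted_kernel_shape is not an Inductor function). It assumes the reduction is unrolled when U * V is strictly below the threshold, and it simply reproduces the ynumel/xnumel/rnumel values of the generated kernels for B = 1000, W = 16, P = 1000, with U * V = 4 for the tiled case and U * V = 64 for the non-tiled one.

```python
def predicted_kernel_shape(B, W, P, U, V, threshold=8):
    """Hypothetical model of the two kernels shown in this issue.

    Assumption: the U * V reduction is unrolled into a pointwise kernel when
    U * V is strictly below `threshold` (unroll_reductions_threshold).
    """
    rnumel = U * V
    if rnumel < threshold:
        # Unrolled: a pure pointwise kernel over (B * W, P), tiled in 2D
        return {"kind": "pointwise_2d", "ynumel": B * W, "xnumel": P}
    # Not unrolled: one flattened output axis plus one reduction axis
    return {"kind": "reduction", "xnumel": B * W * P, "rnumel": rnumel}


print(predicted_kernel_shape(1000, 16, 1000, U=2, V=2))
# {'kind': 'pointwise_2d', 'ynumel': 16000, 'xnumel': 1000}
print(predicted_kernel_shape(1000, 16, 1000, U=8, V=8))
# {'kind': 'reduction', 'xnumel': 16000000, 'rnumel': 64}
```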

Successfully Tiled Case (Reduction Dimensions U and V Are Small)

@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr1, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 16000
    xnumel = 1000
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    ...
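To make the tiled kernel's index arithmetic concrete, here is a NumPy simulation of it: program_id(0) selects the x block and program_id(1) the y block, `[:, None]` and `[None, :]` broadcast into an XBLOCK x YBLOCK tile, and the masks clip the edge tiles. The helper names and the ceil-division grid sizing are assumptions for illustration, not Inductor's launcher code.

```python
import numpy as np


def tile_indices_2d(pid_x, pid_y, XBLOCK, YBLOCK, xnumel, ynumel):
    """Mirror the index math of the tiled kernel for one program instance."""
    xoffset = pid_x * XBLOCK
    xindex = xoffset + np.arange(XBLOCK)[:, None]  # column vector (XBLOCK, 1)
    yoffset = pid_y * YBLOCK
    yindex = yoffset + np.arange(YBLOCK)[None, :]  # row vector (1, YBLOCK)
    mask = (xindex < xnumel) & (yindex < ynumel)   # broadcasts to (XBLOCK, YBLOCK)
    return xindex, yindex, mask


def coverage(xnumel, ynumel, XBLOCK, YBLOCK):
    """Count how often each (y, x) element is visited over the whole grid."""
    counts = np.zeros((ynumel, xnumel), dtype=int)
    grid_x = -(-xnumel // XBLOCK)  # ceil division: programs along axis 0
    grid_y = -(-ynumel // YBLOCK)  # programs along axis 1
    for px in range(grid_x):
        for py in range(grid_y):
            xi, yi, m = tile_indices_2d(px, py, XBLOCK, YBLOCK, xnumel, ynumel)
            xs, ys = np.broadcast_arrays(xi, yi)
            np.add.at(counts, (ys[m], xs[m]), 1)
    return counts
```

Calling `coverage(10, 7, 4, 3)` should visit every element exactly once, even though neither size is a multiple of its block size.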

Non-Tiled Case (Reduction Dimensions U and V Are Not Unrolled)

@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 16000000
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    ...
    tmp25 = tl.sum(tmp24, 1)[:, None]
    ...
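The non-tiled kernel is a persistent reduction: each program owns XBLOCK flattened output elements and sums all RBLOCK = rnumel = 64 reduction elements in a single `tl.sum(..., 1)`. The NumPy simulation below illustrates that pattern. It assumes the per-element products (tmp24 in the kernel) form an (xnumel, rnumel) array with x flattening (b, w, p) and r flattening (u, v); the elided kernel body does not confirm this layout, and persistent_reduction is a hypothetical name.

```python
import numpy as np


def persistent_reduction(vals, XBLOCK):
    """Simulate the non-tiled kernel: each program owns XBLOCK output rows and
    reduces all RBLOCK == rnumel elements of each row in one step.

    vals: (xnumel, rnumel) array of per-element products.
    """
    xnumel, rnumel = vals.shape
    RBLOCK = rnumel  # the whole reduction fits in one block (64 in the kernel)
    out = np.zeros(xnumel, dtype=vals.dtype)
    for pid in range(-(-xnumel // XBLOCK)):
        xindex = pid * XBLOCK + np.arange(XBLOCK)[:, None]  # (XBLOCK, 1)
        rindex = np.arange(RBLOCK)[None, :]                 # (1, RBLOCK)
        xmask = xindex < xnumel
        # Masked load: out-of-range rows read 0.0 (like tl.load(..., other=0.0))
        safe_x = np.where(xmask, xindex, 0)
        tile = np.where(xmask, vals[safe_x, rindex], 0.0)   # (XBLOCK, RBLOCK)
        row_sums = tile.sum(axis=1)                         # like tl.sum(tile, 1)
        valid = xmask[:, 0]
        out[xindex[valid, 0]] = row_sums[valid]
    return out
```

With B = 1000, W = 16, P = 1000 and U * V = 64, this layout gives xnumel = 16,000,000 and rnumel = 64, the values in the kernel header above; all of the parallelism sits in one flattened axis, which is why no 2D tile appears.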

Questions

  1. Tiling with Reductions:
    Is it possible to enable tiling when there is a reduction operation, especially when the reduction size (U * V) exceeds unroll_reductions_threshold?

  2. Multi-Dimensional Tiling:
    Can Inductor generate Triton kernels that tile more than two dimensions (i.e., three or more)?
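As a purely illustrative sketch of what question 1 asks for (not code that Inductor 2.5.1 generates), the NumPy function below iterates over 2D (y, x) tiles with y = B * W and x = P, as in the tiled kernel, and carries the U * V reduction as an inner loop of RBLOCK-sized chunks inside each tile, so the iteration space is effectively three-dimensional. The function name, the block sizes, and the (b, w) / (u, v) flattening are assumptions.

```python
import numpy as np


def tiled_with_reduction(prod, YBLOCK, XBLOCK, RBLOCK):
    """Hypothetical kernel shape for question 1: 2D tiles over (y = B*W, x = P)
    with the U*V reduction carried as an inner loop in each tile.

    prod: (B, U, V, W, P) array of per-element products.
    Returns a (B*W, P) array: prod summed over U and V.
    """
    B, U, V, W, P = prod.shape
    # Assumed layout: y flattens (b, w), r flattens (u, v), x is p
    vals = prod.transpose(0, 3, 1, 2, 4).reshape(B * W, U * V, P)
    ynumel, rnumel, xnumel = vals.shape
    out = np.zeros((ynumel, xnumel), dtype=prod.dtype)
    for y0 in range(0, ynumel, YBLOCK):
        for x0 in range(0, xnumel, XBLOCK):
            ys = slice(y0, min(y0 + YBLOCK, ynumel))  # clipping plays the role of a mask
            xs = slice(x0, min(x0 + XBLOCK, xnumel))
            acc = np.zeros((ys.stop - ys.start, xs.stop - xs.start), dtype=prod.dtype)
            for r0 in range(0, rnumel, RBLOCK):       # looped reduction inside the tile
                rs = slice(r0, min(r0 + RBLOCK, rnumel))
                acc += vals[ys, rs, xs].sum(axis=1)
            out[ys, xs] = acc
    return out
```

The design point being illustrated is that the reduction axis does not have to consume one of the tiled dimensions: the output can stay tiled in (y, x) while the reduction is accumulated in a per-tile register block.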

I found #137243 and #141709, which seem relevant to this issue.
@jansel

Error logs

No response

Versions

Collecting environment information...
PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.31.0-rc1
Libc version: glibc-2.31

Python version: 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      39 bits physical, 48 bits virtual
CPU(s):                             20
On-line CPU(s) list:                0-19
Thread(s) per core:                 2
Core(s) per socket:                 10
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              165
Model name:                         Intel(R) Core(TM) i9-10900K CPU @ 3.70GHz
Stepping:                           5
CPU MHz:                            3700.000
CPU max MHz:                        5300.0000
CPU min MHz:                        800.0000
BogoMIPS:                           7399.70
Virtualization:                     VT-x
L1d cache:                          320 KiB
L1i cache:                          320 KiB
L2 cache:                           2.5 MiB
L3 cache:                           20 MiB
NUMA node0 CPU(s):                  0-19
Vulnerability Gather data sampling: Vulnerable: No microcode
Vulnerability Itlb multihit:        KVM: Mitigation: VMX disabled
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT vulnerable
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Vulnerable: No microcode
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp pku ospke md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] jax-triton==0.2.0
[pip3] numpy==2.0.1
[pip3] nvidia-cublas-cu12==12.6.3.3
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-nccl-cu12==2.23.4
[pip3] nvidia-nvjitlink-cu12==12.6.77
[pip3] torch==2.5.1
[pip3] torch_cluster==1.6.3+pt24cu121
[pip3] torch-geometric==2.6.1
[pip3] torch_scatter==2.1.2+pt24cu121
[pip3] torch_sparse==0.6.18+pt24cu121
[pip3] torch_spline_conv==1.2.2+pt24cu121
[pip3] torchaudio==2.5.1
[pip3] torchsparse==2.1.0
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] blas                      1.0                         mkl
[conda] cuda-cudart               12.1.55                       0    nvidia/label/cuda-12.1.0
[conda] cuda-cudart-dev           12.1.55                       0    nvidia/label/cuda-12.1.0
[conda] cuda-cudart-static        12.1.55                       0    nvidia/label/cuda-12.1.0
[conda] cuda-cupti                12.1.62                       0    nvidia/label/cuda-12.1.0
[conda] cuda-cupti-static         12.1.62                       0    nvidia/label/cuda-12.1.0
[conda] cuda-libraries            12.1.0                        0    nvidia/label/cuda-12.1.0
[conda] cuda-libraries-dev        12.1.0                        0    nvidia/label/cuda-12.1.0
[conda] cuda-libraries-static     12.1.0                        0    nvidia/label/cuda-12.1.0
[conda] cuda-nvrtc                12.1.55                       0    nvidia/label/cuda-12.1.0
[conda] cuda-nvrtc-dev            12.1.55                       0    nvidia/label/cuda-12.1.0
[conda] cuda-nvrtc-static         12.1.55                       0    nvidia/label/cuda-12.1.0
[conda] cuda-nvtx                 12.1.66                       0    nvidia/label/cuda-12.1.0
[conda] cuda-opencl               12.1.56                       0    nvidia/label/cuda-12.1.0
[conda] cuda-opencl-dev           12.1.56                       0    nvidia/label/cuda-12.1.0
[conda] cuda-runtime              12.1.0                        0    nvidia
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] jax-triton                0.2.0                    pypi_0    pypi
[conda] libcublas                 12.1.0.26                     0    nvidia/label/cuda-12.1.0
[conda] libcublas-dev             12.1.0.26                     0    nvidia/label/cuda-12.1.0
[conda] libcublas-static          12.1.0.26                     0    nvidia/label/cuda-12.1.0
[conda] libcufft                  11.0.2.4                      0    nvidia/label/cuda-12.1.0
[conda] libcufft-dev              11.0.2.4                      0    nvidia/label/cuda-12.1.0
[conda] libcufft-static           11.0.2.4                      0    nvidia/label/cuda-12.1.0
[conda] libcurand                 10.3.2.56                     0    nvidia/label/cuda-12.1.0
[conda] libcurand-dev             10.3.2.56                     0    nvidia/label/cuda-12.1.0
[conda] libcurand-static          10.3.2.56                     0    nvidia/label/cuda-12.1.0
[conda] libcusolver               11.4.4.55                     0    nvidia/label/cuda-12.1.0
[conda] libcusolver-dev           11.4.4.55                     0    nvidia/label/cuda-12.1.0
[conda] libcusolver-static        11.4.4.55                     0    nvidia/label/cuda-12.1.0
[conda] libcusparse               12.0.2.55                     0    nvidia/label/cuda-12.1.0
[conda] libcusparse-dev           12.0.2.55                     0    nvidia/label/cuda-12.1.0
[conda] libcusparse-static        12.0.2.55                     0    nvidia/label/cuda-12.1.0
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] libnvjitlink              12.1.105                      0    nvidia
[conda] libnvjitlink-dev          12.1.55                       0    nvidia/label/cuda-12.1.0
[conda] mkl                       2023.1.0         h213fc3f_46344
[conda] mkl-fft                   1.3.11                   pypi_0    pypi
[conda] mkl-random                1.2.8                    pypi_0    pypi
[conda] mkl-service               2.4.0                    pypi_0    pypi
[conda] mkl_fft                   1.3.11          py310h5eee18b_0
[conda] mkl_random                1.2.8           py310h1128e8f_0
[conda] numpy                     2.0.1                    pypi_0    pypi
[conda] numpy-base                2.0.1           py310hb5e798b_1
[conda] nvidia-cublas-cu12        12.6.3.3                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.6.80                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.6.77                  pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.5.1.17                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.3.0.4                 pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.7.1.2                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.5.4.2                 pypi_0    pypi
[conda] nvidia-nccl-cu12          2.23.4                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.6.77                  pypi_0    pypi
[conda] pytorch                   2.5.1           py3.10_cuda12.1_cudnn9.1.0_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_6    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torch-cluster             1.6.3+pt24cu121          pypi_0    pypi
[conda] torch-geometric           2.6.1                    pypi_0    pypi
[conda] torch-scatter             2.1.2+pt24cu121          pypi_0    pypi
[conda] torch-sparse              0.6.18+pt24cu121          pypi_0    pypi
[conda] torch-spline-conv         1.2.2+pt24cu121          pypi_0    pypi
[conda] torchaudio                2.5.1                    pypi_0    pypi
[conda] torchsparse               2.1.0                    pypi_0    pypi
[conda] torchtriton               3.1.0                     py310    pytorch
[conda] torchvision               0.20.1                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov

Labels

module: inductor, oncall: pt2, triaged
