
Failure in generating a kernel with 3 tile groups #141121

@kundaMwiza

🐛 Describe the bug

Compiling a pointwise add kernel with 3 tiling groups fails with an AssertionError during Inductor codegen.

Reproducer:

import functools

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code

# Enable 3D tiling so that up to three tiling groups are generated.
config.triton.max_tiles = 3
config.triton.prefer_nd_tiling = True

# num_block_pointers and num_tiles are unpacked but unused in this reproducer.
full_size, view_size, num_block_pointers, num_tiles = (
    (5, 5, 5, 5, 5),
    (3, 3, 5, 3, 5),
    1,
    2,
)

GPU_TYPE = "cuda"


def get_input() -> torch.Tensor:
    # Build a non-contiguous input: a strided view of a 5D tensor.
    device = torch.device(GPU_TYPE)
    full = torch.randn(full_size).to(device)
    return torch.as_strided(full, view_size, full.stride())


a, b = get_input(), get_input()

opt_fn = torch.compile(functools.partial(torch.add))
code = run_and_get_triton_code(opt_fn, a, b)

This error occurs because LoopBody's iteration variables use the prefix z, which collides with the prefix of the range tree for the z tiling dimension; the sketch below illustrates the clash.
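
A minimal sketch of the collision, using plain sympy symbols to stand in for Inductor's variables (the names mirror the assertion message in the error logs below). Running it raises the same AssertionError:

import sympy

# LoopBody's internal iteration variables are all z-prefixed:
var_ranges = {sympy.Symbol("z0"): 3, sympy.Symbol("z1"): 15, sympy.Symbol("z2"): 15}

# With max_tiles = 3, the kernel's range trees use the prefixes z, y and x,
# so the substituted index variables come back as [z0, y1, x2]:
index = [sympy.Symbol("z0"), sympy.Symbol("y1"), sympy.Symbol("x2")]

# This is the check that fires in LoopBody.indexing_from_args: the
# substituted z0 is indistinguishable from LoopBody's own z0 symbol.
assert all(v not in var_ranges for v in index), f"{var_ranges=}, {index=}"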

Interestingly, a z prefix type is already called out as banned in torch/utils/_sympy/symbol.py at commit c9c8370feb80290dd47f30395a51902265ac0142, and the same file has no ZBLOCK symbol type corresponding to XBLOCK and YBLOCK.
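
For context, a rough sketch of how those typed symbol prefixes are used. The SymT, make_symbol, and symbol_is_type names are taken from the referenced module; treat this as an illustration under that assumption rather than a guaranteed API:

from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type

# Tiling variables carry typed prefixes: "x" for SymT.XBLOCK, "y" for SymT.YBLOCK.
x0 = make_symbol(SymT.XBLOCK, 0)  # sympy.Symbol("x0")
y1 = make_symbol(SymT.YBLOCK, 1)  # sympy.Symbol("y1")

assert symbol_is_type(x0, SymT.XBLOCK)
assert symbol_is_type(y1, SymT.YBLOCK)

# At the referenced commit there is no SymT.ZBLOCK, so a third tiling
# dimension has no typed prefix of its own and lands on the bare "z"
# prefix that LoopBody already claims for its iteration variables.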

Error logs

 def codegen_sync(self):

/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/simd.py in codegen_node(self, node)
   1199         schedule_log.debug("Schedule:\n %s", node_schedule)
   1200 
-> 1201         return self.codegen_node_schedule(
   1202             SIMDKernelFeatures(node_schedule, numel, rnumel)
   1203         )

/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/simd.py in codegen_node_schedule(self, kernel_features)
   1239         )
   1240         for kernel in kernels:
-> 1241             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
   1242         MultiKernel.merge_workspaces_inplace(kernels)
   1243         for kernel in kernels:

/usr/local/lib/python3.10/dist-packages/torch/_inductor/codegen/simd.py in codegen_node_schedule_with_kernel(self, node_schedule, kernel)
   1318                     all_indexing.update(
   1319                         dict.fromkeys(
-> 1320                             node._body.indexing_from_args(index_vars).values()
   1321                         )
   1322                     )

/usr/local/lib/python3.10/dist-packages/torch/_inductor/loop_body.py in indexing_from_args(self, indices)
    391         index = [*itertools.chain.from_iterable(indices)]
    392         assert len(index) == len(self.var_ranges), (index, self.var_ranges)
--> 393         assert all(
    394             v not in self.var_ranges for v in index
    395         ), f"{self.var_ranges=}, {indices=}"

BackendCompilerFailed: backend='inductor' raised:
AssertionError: self.var_ranges={z0: 3, z1: 15, z2: 15}, indices=[[z0, y1, x2], []]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241120+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.30.5
Libc version: glibc-2.35

Python version: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.85+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 535.104.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               2
On-line CPU(s) list:                  0,1
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                           6
Model:                                79
Thread(s) per core:                   2
Core(s) per socket:                   1
Socket(s):                            1
Stepping:                             0
BogoMIPS:                             4399.99
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            32 KiB (1 instance)
L1i cache:                            32 KiB (1 instance)
L2 cache:                             256 KiB (1 instance)
L3 cache:                             55 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0,1
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Mitigation; PTE Inversion
Vulnerability Mds:                    Vulnerable; SMT Host state unknown
Vulnerability Meltdown:               Vulnerable
Vulnerability Mmio stale data:        Vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Vulnerable
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Vulnerable
Vulnerability Spectre v1:             Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:             Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Vulnerable (Syscall hardening enabled)
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] nvtx==0.2.10
[pip3] optree==0.13.1
[pip3] pynvjitlink-cu12==0.4.0
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241120+cu124
[pip3] torchaudio==2.5.1+cu121
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.20.1+cu121
[conda] Could not collect

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @ezyang @blaine-rister

Labels

module: inductor, oncall: pt2, triaged
