Skip to content

dynamo creates unnecessary buffers #124653

@isuruf

Description

@isuruf

🐛 Describe the bug

The following code creates two Triton kernels instead of one. In fact, `buf0` is unused: the second kernel never reads it, and it is only returned as an extra output of `call`.

import torch, itertools

# Reproducer input: a random float32 tensor of shape (n, n, n) on CUDA
# (requires a CUDA device). With n = 200 this is 8,000,000 elements, which
# matches `xnumel = 8000000` in the second generated kernel below.
n = 200
a = torch.randn((n, n, n), device='cuda')

def fn(a):
    """Reproducer for the "unnecessary buffer" issue.

    Builds ``m`` as a running product (starting from the Python int ``1``)
    of one ``arange`` tensor per dimension of ``a``, then returns
    ``sum(t1 * m for t1 in (a, a + 1, a + 2))``. Because the builtin ``sum``
    starts from ``0``, the result includes a ``+ 0.0`` term (``tmp4`` in the
    generated output kernel).

    NOTE(review): ``view_shape`` has ``dim + 1`` entries with the ``-1``
    always in the last slot, so every ``b`` varies only along the *last*
    axis of ``a`` and ``m[..., k] == k ** 3``. The generated kernel confirms
    this (``tmp0 = x0*x0*x0``). If an outer product of per-dimension indices
    (``i * j * k``) was intended, ``view_shape`` would need ``len(shape)``
    entries -- TODO confirm intent.

    NOTE(review): keep this function's text as-is while the issue is open.
    The local names (``arange``, ``b``, ``m``, ``t1``) appear as "Source
    Nodes" in the attached generated code, and restructuring it may change
    what dynamo traces.
    """
    t = (a, a + 1, a + 2)
    shape = a.shape
    m = 1
    for dim in range(len(shape)):
        view_shape = [1]*(dim + 1)
        view_shape[dim] = -1
        b = (torch.arange(shape[dim], device=a.device).view(view_shape))
        m = torch.mul(m, b)
    # The generator expression closes over `m`. Possibly this is why the
    # final `m` (buf0 in the generated code) is kept as an extra graph
    # output -- unconfirmed.
    return sum(torch.mul(t1, m) for t1 in t)
    
# Compile `fn` and run it once on `a`; the module Inductor generated for this
# call is shown below. The trailing `;` has no effect in Python (presumably
# left over from a notebook, where it suppresses echoing the result).
torch.compile(fn)(a);
Details
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

# Short aliases referenced by the generated code below. Only
# assert_size_stride, empty_strided_cuda and async_compile are actually used
# in this module; the rest are presumably emitted unconditionally by Inductor.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Kernels below are submitted to this object and collected by
# `async_compile.wait(globals())` further down.
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_isuruf/gw/cgwpfgkdk4yia3bpt37vqjdjda6rp635xory4heq5y2ea46iaqyt.py
# Source Nodes: [arange, b, m, m_1, m_2], Original ATen: [aten.arange, aten.mul, aten.view]
# arange => iota
# b => view
# m => mul
# m_1 => mul_1
# m_2 => mul_2
# Kernel 0: for x0 in [0, 200), stores x0*x0*x0 into out_ptr0, an int64
# buffer ('*i64' in the signature). This is the final `m` of `fn`; `call`
# launches it into buf0. The second kernel recomputes the same value inline
# (its `tmp1 = x0*x0*x0`), so this kernel's output is not consumed.
triton_poi_fused_arange_mul_view_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[256], 
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_arange_mul_view_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 200
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0*x0*x0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


# kernel path: /tmp/torchinductor_isuruf/vx/cvx4wnrm6ktqek4r6kmhrqqb3nlid4qanygazjhwmvx26svg4yl7.py
# Source Nodes: [add_2, add_3, add_4, mul_3, mul_4, mul_5, t1, t1_1], Original ATen: [aten.add, aten.mul]
# add_2 => add_2
# add_3 => add_3
# add_4 => add_4
# mul_3 => mul_3
# mul_4 => mul_4
# mul_5 => mul_5
# t1 => add
# t1_1 => add_1
# Kernel 1: the fused result of `fn`. Reads in_ptr0 (the input), recomputes
# m = x0*x0*x0 with x0 = xindex % 200 inline (tmp1, cast to float32 as tmp2),
# and stores (a*m + 0.0) + (a+1)*m + (a+2)*m into out_ptr0. It takes no
# pointer to buf0, so the buffer written by kernel 0 is never read here.
triton_poi_fused_add_mul_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[8388608], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_1', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 8000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 200
    tmp0 = tl.load(in_ptr0 + (x2), xmask)
    tmp1 = x0*x0*x0
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 * tmp2
    tmp4 = 0.0
    tmp5 = tmp3 + tmp4
    tmp6 = 1.0
    tmp7 = tmp0 + tmp6
    tmp8 = tmp7 * tmp2
    tmp9 = tmp5 + tmp8
    tmp10 = 2.0
    tmp11 = tmp0 + tmp10
    tmp12 = tmp11 * tmp2
    tmp13 = tmp9 + tmp12
    tl.store(out_ptr0 + (x2), tmp13, xmask)
''', device_str='cuda')


# Wait for the kernels submitted to async_compile above to finish compiling
# (presumably binding the compiled kernels into this module's globals), then
# drop the compiler handle.
async_compile.wait(globals())
del async_compile

def call(args):
    """Run the compiled graph.

    ``args`` is a list holding the single input tensor (``arg0_1``). It is
    unpacked and then cleared in place, presumably so the input can be freed
    as early as possible (see also ``del arg0_1`` below).
    Returns the tuple ``(buf0, buf1)``.
    """
    arg0_1, = args
    args.clear()
    # The input must be (200, 200, 200) with contiguous strides.
    assert_size_stride(arg0_1, (200, 200, 200), (40000, 200, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # buf0: (1, 1, 200) int64 -- the final `m` of `fn`, filled by kernel 0.
        buf0 = empty_strided_cuda((1, 1, 200), (200, 200, 1), torch.int64)
        # Source Nodes: [arange, b, m, m_1, m_2], Original ATen: [aten.arange, aten.mul, aten.view]
        stream0 = get_raw_stream(0)
        triton_poi_fused_arange_mul_view_0.run(buf0, 200, grid=grid(200), stream=stream0)
        # buf1: (200, 200, 200) float32 -- the fused result of `fn`.
        buf1 = empty_strided_cuda((200, 200, 200), (40000, 200, 1), torch.float32)
        # Source Nodes: [add_2, add_3, add_4, mul_3, mul_4, mul_5, t1, t1_1], Original ATen: [aten.add, aten.mul]
        # NOTE(review): buf0 is not an argument here. This kernel recomputes
        # x0*x0*x0 itself, so kernel 0 and buf0 contribute nothing to buf1.
        triton_poi_fused_add_mul_1.run(arg0_1, buf1, 8000000, grid=grid(8000000), stream=stream0)
        del arg0_1
    # NOTE(review): buf0 is still returned as an extra output even though
    # nothing in this module reads it. This is the unnecessary buffer (and
    # extra kernel launch) reported in the issue.
    return (buf0, buf1, )


def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on a fresh random input.

    The input has the same size, stride, device and dtype that ``call``
    asserts. ``times`` and ``repeat`` are forwarded unchanged to
    ``print_performance``, whose result is returned.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((200, 200, 200), (40000, 200, 1), device='cuda:0', dtype=torch.float32)
    # A new single-element list per call, because `call` clears the list it is given.
    fn = lambda: call([arg0_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    # When executed as a script, hand benchmark_compiled_module to Inductor's
    # benchmark driver. The 'None' first argument is presumably a placeholder
    # model name -- confirm against compiled_module_main.
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Versions

Collecting environment information...
PyTorch version: 2.4.0a0+gitf5ad149
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (conda-forge gcc 12.3.0-2) 12.3.0
Clang version: Could not collect
CMake version: version 3.27.6
Libc version: glibc-2.35

Python version: 3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36)  [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-97-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 2060
GPU 1: NVIDIA GeForce RTX 2060

Nvidia driver version: 545.23.08
cuDNN version: Probably one of the following:
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_adv.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_cnn.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_engines_precompiled.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_engines_runtime_compiled.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_graph.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_heuristic.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_ops.so.9
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      43 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             64
On-line CPU(s) list:                0-63
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen Threadripper 3970X 32-Core Processor
CPU family:                         23
Model:                              49
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          1
Stepping:                           0
Frequency boost:                    enabled
CPU max MHz:                        3700.0000
CPU min MHz:                        2200.0000
BogoMIPS:                           7400.38
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization:                     AMD-V
L1d cache:                          1 MiB (32 instances)
L1i cache:                          1 MiB (32 instances)
L2 cache:                           16 MiB (32 instances)
L3 cache:                           128 MiB (8 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] onnx==1.15.0
[pip3] onnxruntime==1.17.0
[pip3] onnxscript==0.1.0.dev20240117
[pip3] optree==0.11.0
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.3.0a0+gitf9f602f
[pip3] torchvision==0.16.1+cf89794
[conda] libmagma                  2.7.2                h173bb3b_0    conda-forge
[conda] libmagma_sparse           2.7.2                h173bb3b_0    conda-forge
[conda] libopenvino-pytorch-frontend 2023.1.0             h59595ed_1    conda-forge
[conda] magma                     2.7.2                h51420fd_0    conda-forge
[conda] mkl                       2022.2.1         h84fe81f_16997    conda-forge
[conda] mkl-include               2023.2.0         h84fe81f_50495    conda-forge
[conda] numpy                     1.24.3                   pypi_0    pypi
[conda] optree                    0.11.0                   pypi_0    pypi
[conda] pytorch-triton            3.0.0+989adb9a29          pypi_0    pypi
[conda] torch                     2.4.0a0+giteae45c6           dev_0    <develop>
[conda] torchfix                  0.4.0                    pypi_0    pypi
[conda] torchvision               0.16.1          cpu_py38h901811f_1    conda-forge

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions