Labels
dynamo-triage-june2024, dynamo-variable-tracker, internal ramp-up task, module: dynamo, oncall: pt2, triaged
Description
🐛 Describe the bug
The following code produces two Triton kernels instead of one; moreover, buf0 is computed but never used.
import torch

n = 200
a = torch.randn((n, n, n), device='cuda')

def fn(a):
    t = (a, a + 1, a + 2)
    shape = a.shape
    m = 1
    for dim in range(len(shape)):
        view_shape = [1] * (dim + 1)
        view_shape[dim] = -1
        b = torch.arange(shape[dim], device=a.device).view(view_shape)
        m = torch.mul(m, b)
    return sum(torch.mul(t1, m) for t1 in t)
torch.compile(fn)(a)
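For reference, the Inductor output below can be dumped by enabling output-code logging; a minimal sketch (the TORCH_LOGS mechanism is standard; the file name is illustrative):

# Run the repro with generated-code logging enabled:
#   TORCH_LOGS="output_code" python repro.py
# or turn it on programmatically before compiling:
import torch._logging
torch._logging.set_logs(output_code=True)
torch.compile(fn)(a)

Generated output code: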
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_isuruf/gw/cgwpfgkdk4yia3bpt37vqjdjda6rp635xory4heq5y2ea46iaqyt.py
# Source Nodes: [arange, b, m, m_1, m_2], Original ATen: [aten.arange, aten.mul, aten.view]
# arange => iota
# b => view
# m => mul
# m_1 => mul_1
# m_2 => mul_2
triton_poi_fused_arange_mul_view_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[256],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_arange_mul_view_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 200
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0*x0*x0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
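# Note (added for clarity, not Inductor output): the kernel above fills buf0
# with k*k*k for k in range(200), i.e. the fully folded value of m. An
# eager-mode equivalent (illustrative names) would be:
#   k = torch.arange(200, device='cuda')
#   buf0_equiv = (k * k * k).view(1, 1, 200)   # int64, shape (1, 1, 200)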
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_isuruf/vx/cvx4wnrm6ktqek4r6kmhrqqb3nlid4qanygazjhwmvx26svg4yl7.py
# Source Nodes: [add_2, add_3, add_4, mul_3, mul_4, mul_5, t1, t1_1], Original ATen: [aten.add, aten.mul]
# add_2 => add_2
# add_3 => add_3
# add_4 => add_4
# mul_3 => mul_3
# mul_4 => mul_4
# mul_5 => mul_5
# t1 => add
# t1_1 => add_1
triton_poi_fused_add_mul_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[8388608],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_1', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '13d5379970a55f2f2c4bb8dbeb907c03d2af7e5fb1c9d1b1aa5bf5794d5f2277', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 8000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 200
    tmp0 = tl.load(in_ptr0 + (x2), xmask)
    tmp1 = x0*x0*x0
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 * tmp2
    tmp4 = 0.0
    tmp5 = tmp3 + tmp4
    tmp6 = 1.0
    tmp7 = tmp0 + tmp6
    tmp8 = tmp7 * tmp2
    tmp9 = tmp5 + tmp8
    tmp10 = 2.0
    tmp11 = tmp0 + tmp10
    tmp12 = tmp11 * tmp2
    tmp13 = tmp9 + tmp12
    tl.store(out_ptr0 + (x2), tmp13, xmask)
''', device_str='cuda')
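# Note (added for clarity, not Inductor output): this second kernel recomputes
# x0*x0*x0 inline (tmp1) and never loads buf0. In eager terms (illustrative):
#   m = (k * k * k).to(torch.float32)              # recomputed, not read from buf0
#   out = a * m + (a + 1.0) * m + (a + 2.0) * m    # the three fused mul/add terms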
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (200, 200, 200), (40000, 200, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, 1, 200), (200, 200, 1), torch.int64)
        # Source Nodes: [arange, b, m, m_1, m_2], Original ATen: [aten.arange, aten.mul, aten.view]
        stream0 = get_raw_stream(0)
        triton_poi_fused_arange_mul_view_0.run(buf0, 200, grid=grid(200), stream=stream0)
        buf1 = empty_strided_cuda((200, 200, 200), (40000, 200, 1), torch.float32)
        # Source Nodes: [add_2, add_3, add_4, mul_3, mul_4, mul_5, t1, t1_1], Original ATen: [aten.add, aten.mul]
        triton_poi_fused_add_mul_1.run(arg0_1, buf1, 8000000, grid=grid(8000000), stream=stream0)
        del arg0_1
    return (buf0, buf1, )
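# Note (added for clarity, not Inductor output): buf0 is written by the first
# kernel and returned above, but nothing ever reads it; the second kernel
# recomputes the same values inline, so the first kernel is dead work.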
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((200, 200, 200), (40000, 200, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([arg0_1])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
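Since fn reduces to a single pointwise expression over a (with m a constant broadcast along the last dimension), one fused kernel should suffice. A minimal eager-mode sketch of the expected result, useful for checking correctness (names are illustrative, not taken from the generated code):

n = 200
a = torch.randn((n, n, n), device='cuda')
k = torch.arange(n, device=a.device)
m = (k * k * k).view(1, 1, n)                   # what buf0 materializes
expected = a * m + (a + 1) * m + (a + 2) * m    # one elementwise pass over a
torch.testing.assert_close(torch.compile(fn)(a), expected)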
Versions
Collecting environment information...
PyTorch version: 2.4.0a0+gitf5ad149
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (conda-forge gcc 12.3.0-2) 12.3.0
Clang version: Could not collect
CMake version: version 3.27.6
Libc version: glibc-2.35
Python version: 3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-97-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2060
GPU 1: NVIDIA GeForce RTX 2060
Nvidia driver version: 545.23.08
cuDNN version: Probably one of the following:
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.7.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_adv.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_cnn.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_engines_precompiled.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_engines_runtime_compiled.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_graph.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_heuristic.so.9
/usr/local/cuda-12.3.2/targets/x86_64-linux/lib/libcudnn_ops.so.9
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper 3970X 32-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU max MHz: 3700.0000
CPU min MHz: 2200.0000
BogoMIPS: 7400.38
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization: AMD-V
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 16 MiB (32 instances)
L3 cache: 128 MiB (8 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] onnx==1.15.0
[pip3] onnxruntime==1.17.0
[pip3] onnxscript==0.1.0.dev20240117
[pip3] optree==0.11.0
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.3.0a0+gitf9f602f
[pip3] torchvision==0.16.1+cf89794
[conda] libmagma 2.7.2 h173bb3b_0 conda-forge
[conda] libmagma_sparse 2.7.2 h173bb3b_0 conda-forge
[conda] libopenvino-pytorch-frontend 2023.1.0 h59595ed_1 conda-forge
[conda] magma 2.7.2 h51420fd_0 conda-forge
[conda] mkl 2022.2.1 h84fe81f_16997 conda-forge
[conda] mkl-include 2023.2.0 h84fe81f_50495 conda-forge
[conda] numpy 1.24.3 pypi_0 pypi
[conda] optree 0.11.0 pypi_0 pypi
[conda] pytorch-triton 3.0.0+989adb9a29 pypi_0 pypi
[conda] torch 2.4.0a0+giteae45c6 dev_0 <develop>
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchvision 0.16.1 cpu_py38h901811f_1 conda-forge
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire