Improve inductor codegen for writing out tensor and tensor.t() in the same kernel #133242

@vkuzo

Description

🚀 The feature, motivation and pitch

This is moving "issue 2" from #130015 to be tracked separately.

Context:
While using float8 training, every float8 linear runs the pattern fp8 = cast_to_fp8(input_tensor); fp8_t = fp8.t().contiguous().t(). We rely on Inductor to generate a performant fused kernel for these ops, but there are some inefficiencies in the generated kernels.
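For context, here is a minimal sketch of what that idiom does (a plain .to(torch.float8_e4m3fn) stands in for the real cast_to_fp8 helper): t().contiguous().t() keeps the logical shape and values but materializes a second, column-major copy.

import torch

x = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)
fp8 = x.to(torch.float8_e4m3fn)      # row-major copy: stride (8, 1)
fp8_t = fp8.t().contiguous().t()     # same shape and values, column-major: stride (1, 4)
print(fp8.shape, fp8.stride())       # torch.Size([4, 8]) (8, 1)
print(fp8_t.shape, fp8_t.stride())   # torch.Size([4, 8]) (1, 4)

Inductor fuses the cast and the layout change into a single pointwise kernel; the example below exercises that at a realistic size.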

Examples
The following is an example of the cast+transpose pattern:

import torch

def test() -> None:
    ref_dtype = torch.bfloat16
    M, K, N = 4096, 4096, 3072

    input_tensor = torch.randn(M, K, device="cuda", dtype=ref_dtype, requires_grad=False)
    scale = torch.tensor([10.0], device="cuda")

    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max  # unused here, kept for reference

    def test_pattern2(tensor_x_inp, scale_x):
        # row-major fp8 output
        tensor_x = tensor_x_inp * scale_x
        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)

        # same values, written out column-major via t().contiguous().t()
        tensor_x_t = (tensor_x_inp * scale_x).t()
        tensor_x_t = tensor_x_t.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_fp8_t = tensor_x_t.to(torch.float8_e4m3fn)

        tensor_fp8_t = tensor_fp8_t.contiguous().t()

        return (tensor_fp8, tensor_fp8_t)

    test_pattern = torch.compile(test_pattern2)
    tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)
    print(tensor_fp8.stride(), tensor_fp8_t.stride())


test()


# TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_OUTPUT=/tmp/profile.txt  TORCH_LOGS="fusion, +inductor,+schedule,output_code" TORCH_COMPILE_DEBUG=1 python test.py

The issue in test_pattern2
With test_pattern2, only one fused kernel (triton_poi_fused__to_copy_clamp_clone_mul_0) is generated. However, the fused kernel reaches only about 1.5 TB/s of bandwidth, below the roughly 2 TB/s expected on an H100.
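As a rough back-of-the-envelope (assuming the ideal case where the bf16 input is read once and each fp8 output is written once; this matches the kernel_num_gb ≈ 0.067 in the inductor_meta below):

M, K = 4096, 4096
bytes_moved = M * K * 2 + 2 * M * K * 1            # one bf16 read + two fp8 writes
gb = bytes_moved / 1e9                             # ~0.067 GB
print(f"at 2.0 TB/s: {gb / 2000 * 1e6:.1f} us")    # ~33.6 us (expected)
print(f"at 1.5 TB/s: {gb / 1500 * 1e6:.1f} us")    # ~44.7 us (observed)

So the bandwidth gap translates to roughly 10 us of extra time per call to this kernel.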
Looking at the generated kernel:

@triton_heuristics.pointwise(
    size_hints=[4096, 4096], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*fp32', 2: '*fp8e4nv', 3: '*fp8e4nv', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=132), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clamp_clone_mul_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '111C3DDC11B3C29F74BF1795749D65D0F93E66347392E755AD82B8A40774FD60', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'profile_bandwidth': True, 'profile_bandwidth_regex': '', 'profile_bandwidth_output': '/tmp/profile.txt', 'kernel_num_gb': 0.067108868},
    min_elem_per_thread=2
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 4096
    xnumel = 4096
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = tl.full([XBLOCK, YBLOCK], True, tl.int1)
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, YBLOCK], True, tl.int1)
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (x1 + (4096*y0)), None).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, YBLOCK])
    tmp10 = tl.load(in_ptr0 + (y0 + (4096*x1)), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = -448.0
    tmp6 = triton_helpers.maximum(tmp4, tmp5)
    tmp7 = 448.0
    tmp8 = triton_helpers.minimum(tmp6, tmp7)
    tmp9 = tmp8.to(tl.float8e4nv)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tmp11 * tmp3
    tmp13 = triton_helpers.maximum(tmp12, tmp5)
    tmp14 = triton_helpers.minimum(tmp13, tmp7)
    tmp15 = tmp14.to(tl.float8e4nv)
    tl.store(out_ptr0 + (x1 + (4096*y0)), tmp9, None)
    tl.store(out_ptr1 + (x1 + (4096*y0)), tmp15, None)
''', device_str='cuda')

One simple change that can be applied to the kernel is to remove the load of tmp10 and instead write the casted tmp0 (i.e. tmp9) to the transposed position of out_ptr1.
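For illustration, a sketch of what the tail of the kernel could look like with that change (the prologue and the tmp0 through tmp9 computation stay as above; this is only the idea, not actual Inductor output):

    # tmp10 and the duplicated multiply/clamp/cast (tmp11..tmp15) are no longer needed
    tl.store(out_ptr0 + (x1 + (4096*y0)), tmp9, None)
    # write the already-casted value to the transposed offset of out_ptr1
    tl.store(out_ptr1 + (y0 + (4096*x1)), tmp9, None)

This removes the second read of in_ptr0 and the redundant compute; the transposed access moves from the load side to the store side.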

Versions

main branch as of 2024-08-09

cc @msaroufim @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @yanbing-j @albanD

    Labels

    module: floatx (formerly float8), module: inductor, module: performance, oncall: pt2, triaged
