
Improve the Inductor-generated kernel for the pattern of output1 = pointwise(input); output2 = transpose(output1) #130015

@y-sq

🚀 The feature, motivation and pitch

Context:
While using float8 training, the pattern fp8 = cast_to_fp8(input_tensor); fp8_t = fp8.t().contiguous().t() appears in every float8 linear. We rely on Inductor to generate a performant fused kernel for these ops, but the generated kernels have some inefficiencies.
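The .t().contiguous().t() idiom returns a tensor with the same shape and values as its input but with column-major strides, i.e. it materializes a transposed copy. A minimal snippet illustrating the stride behavior (the small shapes here are only for illustration):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
y = x.t().contiguous().t()     # same shape and values as x

# x is row-major, y is column-major: (8, 1) vs. (1, 4)
print(x.stride(), y.stride())
assert torch.equal(x, y)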

Examples
The following is an example of the cast+transpose pattern.

import torch


def test() -> None:
    ref_dtype = torch.bfloat16
    M, K, N = 4096, 4096, 3072

    input_tensor = torch.randn(M, K, device="cuda", dtype=ref_dtype, requires_grad=False)
    scale = torch.Tensor([10.0]).to("cuda")

    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

    def test_pattern1(tensor_x_inp, scale_x):
        # Pointwise part: scale, clamp to the e4m3 range, cast to float8_e4m3fn.
        tensor_x = tensor_x_inp * scale_x
        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)

        # Materialize a column-major copy of the fp8 tensor.
        tensor_fp8_t = tensor_fp8.t().contiguous().t()

        return (tensor_fp8, tensor_fp8_t)

    test_pattern = torch.compile(test_pattern1)
    tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)
    print(tensor_fp8.stride(), tensor_fp8_t.stride())


test()

# TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_OUTPUT=/tmp/profile.txt  TORCH_LOGS="fusion, +inductor,+schedule,output_code" TORCH_COMPILE_DEBUG=1 python test.py

The issue in "test_pattern1"
In test_pattern1, tensor_fp8.t().contiguous().t() is not fused with the casting kernel. Instead, Inductor generates two kernels (triton_poi_fused__to_copy_clamp_mul_0 and triton_poi_fused_clone_1).

From the Inductor log:

V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:1842] [0/0] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:2084] [0/0] [__fusion] fuse_nodes_once, candidates:
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:2086] [0/0] [__fusion]   SchedulerNode(name='buf0'), Pointwise(['[4096, 4096]', 'origins={mul, convert_element_type, clamp_max, clamp_min}'])
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:2086] [0/0] [__fusion]   SchedulerNode(name='buf1'), Pointwise(['[4096, 4096]', 'origins={clone}'])
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:645] [0/0] [__fusion] cannot fuse buf0 with buf1: no shared data

Alternatives

As a short-term fix, we added a handwritten kernel for the fused cast+transpose and use the Inductor pattern matcher to replace these ops with it; a sketch of what such a kernel might look like is shown below. It would still be nice to have a general fix in Inductor.
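For illustration only, here is a minimal sketch of a handwritten fused kernel, written in Triton. This is not the kernel we use internally; the names fused_cast_transpose_kernel and fused_cast_transpose, the block sizes, and the wrapper are assumptions, and the tl.float8e4nv conversion assumes a recent Triton and an fp8-capable GPU. Each program loads one tile of the input, applies the same scale/clamp/cast as test_pattern1, and stores the result twice: once into a row-major (M, K) buffer and once into a (K, M) buffer that is viewed back as the column-major output, so the input is read from global memory only once instead of being re-read by a separate clone kernel.

import torch
import triton
import triton.language as tl


@triton.jit
def fused_cast_transpose_kernel(
    x_ptr, scale_ptr, out_ptr, out_t_ptr,
    M, K,
    stride_xm, stride_xk,
    FP8_MAX: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Each program handles one BLOCK_M x BLOCK_K tile of the (M, K) input.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)

    x = tl.load(
        x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
        mask=mask,
    ).to(tl.float32)
    scale = tl.load(scale_ptr)

    # Same pointwise math as test_pattern1: scale, clamp, cast to e4m3.
    y = tl.minimum(tl.maximum(x * scale, -FP8_MAX), FP8_MAX).to(tl.float8e4nv)

    # Row-major output: out[m, k] in a contiguous (M, K) buffer.
    tl.store(out_ptr + offs_m[:, None] * K + offs_k[None, :], y, mask=mask)
    # Transposed output: out_t[k, m] in a contiguous (K, M) buffer, i.e. the
    # column-major copy that fp8.t().contiguous().t() would have produced.
    tl.store(out_t_ptr + offs_k[None, :] * M + offs_m[:, None], y, mask=mask)


def fused_cast_transpose(x, scale):
    M, K = x.shape
    out = torch.empty(M, K, device=x.device, dtype=torch.float8_e4m3fn)
    # Allocate a contiguous (K, M) buffer and view it as (M, K) with
    # column-major strides, matching fp8.t().contiguous().t().
    out_t = torch.empty(K, M, device=x.device, dtype=torch.float8_e4m3fn).t()
    grid = (triton.cdiv(M, 64), triton.cdiv(K, 64))
    fused_cast_transpose_kernel[grid](
        x, scale, out, out_t,
        M, K, x.stride(0), x.stride(1),
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        BLOCK_M=64, BLOCK_K=64,
    )
    return out, out_t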

Additional context

No response

cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @yanbing-j @vkuzo

Labels

module: floatx (formerly float8), module: inductor, oncall: pt2, triaged
