### 🚀 The feature, motivation and pitch
Context: In float8 training, the operator sequence `fp8 = cast_to_fp8(input_tensor); fp8_t = fp8.t().contiguous().t()` runs in every float8 linear. We rely on Inductor to generate a performant fused kernel for these ops, but the generated kernels have some inefficiencies.
#### Examples
The following is an example of the cast+transpose pattern:
```python
import torch


def test() -> None:
    ref_dtype = torch.bfloat16
    M, K, N = 4096, 4096, 3072
    input_tensor = torch.randn(M, K, device="cuda", dtype=ref_dtype, requires_grad=False)
    scale = torch.Tensor([10.0]).to("cuda")
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

    def test_pattern1(tensor_x_inp, scale_x):
        tensor_x = tensor_x_inp * scale_x
        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)
        tensor_fp8_t = tensor_fp8.t().contiguous().t()
        return (tensor_fp8, tensor_fp8_t)

    test_pattern = torch.compile(test_pattern1)
    tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)
    print(tensor_fp8.stride(), tensor_fp8_t.stride())


if __name__ == "__main__":
    test()

# Repro command:
# TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_OUTPUT=/tmp/profile.txt TORCH_LOGS="fusion, +inductor,+schedule,output_code" TORCH_COMPILE_DEBUG=1 python test.py
```
#### The issue in `test_pattern1`

In `test_pattern1`, `tensor_fp8.t().contiguous().t()` is not fused with the casting kernel. Instead, Inductor generates two kernels (`triton_poi_fused__to_copy_clamp_mul_0` and `triton_poi_fused_clone_1`).
From the Inductor log:

```
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:1842] [0/0] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:2084] [0/0] [__fusion] fuse_nodes_once, candidates:
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:2086] [0/0] [__fusion] SchedulerNode(name='buf0'), Pointwise(['[4096, 4096]', 'origins={mul, convert_element_type, clamp_max, clamp_min}'])
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:2086] [0/0] [__fusion] SchedulerNode(name='buf1'), Pointwise(['[4096, 4096]', 'origins={clone}'])
V0702 16:39:45.362000 139774580315136 torch/_inductor/scheduler.py:645] [0/0] [__fusion] cannot fuse buf0 with buf1: no shared data
```
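For reference, the two-kernel behavior can also be checked programmatically with the `run_and_get_code` test helper from `torch._inductor.utils`. This is only a sketch that reuses the names from the repro above; it is not part of the original report.

```python
from torch._inductor.utils import run_and_get_code

# Reuses test_pattern, input_tensor, and scale from the repro above.
_, source_codes = run_and_get_code(test_pattern, input_tensor, scale)

# Each generated Triton kernel carries an @triton.jit decorator in the
# output code; this currently reports 2 kernels for the pattern above.
print(sum(code.count("@triton.jit") for code in source_codes))
```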
### Alternatives
As a short-term fix, we added a handwritten kernel for the fused cast+transpose and use the Inductor pattern matcher to replace the ops with the handwritten kernel. However, it would be nice to have a general fix in Inductor.
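For illustration, below is a rough sketch of how such a replacement can be wired into Inductor's pattern matcher. The fused op `torch.ops.my_ops.fused_cast_transpose` is hypothetical (it stands in for the handwritten kernel), and the `pattern_matcher` API details, especially around multi-output patterns, may differ across PyTorch versions.

```python
import torch
from torch._inductor import config as inductor_config
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max


# The op sequence to search for (the same computation as test_pattern1 above).
def cast_transpose_pattern(x, scale):
    y = (x * scale).clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    fp8 = y.to(torch.float8_e4m3fn)
    return fp8, fp8.t().contiguous().t()


# Hypothetical handwritten kernel that produces both layouts in one pass.
def cast_transpose_replacement(x, scale):
    return torch.ops.my_ops.fused_cast_transpose(x, scale)


patterns = PatternMatcherPass()
register_replacement(
    cast_transpose_pattern,
    cast_transpose_replacement,
    [
        torch.randn(16, 16, device="cuda", dtype=torch.bfloat16),  # example inputs
        torch.tensor([10.0], device="cuda"),
    ],
    fwd_only,
    patterns,
)

# Run the pass as part of Inductor's post-grad graph passes.
inductor_config.post_grad_custom_post_pass = patterns.apply
```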
### Additional context
No response
cc @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @yanbing-j @vkuzo