🚀 The feature, motivation and pitch
This moves "issue 2" from #130015 so it can be tracked separately.
Context:
During float8 training, every float8 linear runs the ops fp8 = cast_to_fp8(input_tensor); fp8_t = fp8.t().contiguous().t(). We rely on Inductor to generate a performant fused kernel for these ops, but the generated kernels have some inefficiencies.
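For reference, here is a minimal sketch of what the cast side of this pattern does. cast_to_fp8 is just the shorthand used above, not a specific PyTorch API; the explicit scale argument and the e4m3 clamp are assumptions taken from the repro below.

```python
import torch

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

def cast_to_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale, clamp to the representable e4m3 range, then cast down to float8.
    x_scaled = (x * scale).clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    return x_scaled.to(torch.float8_e4m3fn)

# Per float8 linear we need both layouts of the casted input:
# fp8   = cast_to_fp8(input_tensor, scale)   # row-major fp8
# fp8_t = fp8.t().contiguous().t()           # same values, column-major storage
```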
Examples
The following is an example of the cast+transpose pattern:
```python
import torch


def test() -> None:
    ref_dtype = torch.bfloat16
    M, K, N = 4096, 4096, 3072
    input_tensor = torch.randn(M, K, device="cuda", dtype=ref_dtype, requires_grad=False)
    scale = torch.Tensor([10.0]).to("cuda")
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

    def test_pattern2(tensor_x_inp, scale_x):
        tensor_x = tensor_x_inp * scale_x
        tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)

        tensor_x_t = (tensor_x_inp * scale_x).t()
        tensor_x_t = tensor_x_t.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
        tensor_fp8_t = tensor_x_t.to(torch.float8_e4m3fn)
        tensor_fp8_t = tensor_fp8_t.contiguous().t()
        return (tensor_fp8, tensor_fp8_t)

    test_pattern = torch.compile(test_pattern2)
    tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)
    print(tensor_fp8.stride(), tensor_fp8_t.stride())


if __name__ == "__main__":
    test()

# TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_OUTPUT=/tmp/profile.txt TORCH_LOGS="fusion,+inductor,+schedule,output_code" TORCH_COMPILE_DEBUG=1 python test.py
```
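For this 4096x4096 input the print above is expected to show (4096, 1) and (1, 4096): a row-major and a column-major fp8 copy of the same values. A small, hypothetical check that could be appended to the end of test():

```python
# Hypothetical sanity check, placed at the end of test(): both outputs hold
# the same casted values, only the memory layout differs.
assert tensor_fp8.stride() == (4096, 1)     # row-major fp8 copy
assert tensor_fp8_t.stride() == (1, 4096)   # transposed-contiguous fp8 copy
torch.testing.assert_close(
    tensor_fp8.to(torch.float32),
    tensor_fp8_t.to(torch.float32),
)
```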
The issue in "test_pattern2"
With test_pattern2, only one fused kernel (triton_poi_fused__to_copy_clamp_clone_mul_0) is generated. However, that fused kernel reaches only about 1.5 TB/s of bandwidth, below the roughly 2 TB/s expected on an H100.
Looking at the generated kernel:
```python
@triton_heuristics.pointwise(
    size_hints=[4096, 4096], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*fp32', 2: '*fp8e4nv', 3: '*fp8e4nv', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=132), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clamp_clone_mul_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '111C3DDC11B3C29F74BF1795749D65D0F93E66347392E755AD82B8A40774FD60', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'profile_bandwidth': True, 'profile_bandwidth_regex': '', 'profile_bandwidth_output': '/tmp/profile.txt', 'kernel_num_gb': 0.067108868},
    min_elem_per_thread=2
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 4096
    xnumel = 4096
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = tl.full([XBLOCK, YBLOCK], True, tl.int1)
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, YBLOCK], True, tl.int1)
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (x1 + (4096*y0)), None).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, YBLOCK])
    tmp10 = tl.load(in_ptr0 + (y0 + (4096*x1)), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = -448.0
    tmp6 = triton_helpers.maximum(tmp4, tmp5)
    tmp7 = 448.0
    tmp8 = triton_helpers.minimum(tmp6, tmp7)
    tmp9 = tmp8.to(tl.float8e4nv)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tmp11 * tmp3
    tmp13 = triton_helpers.maximum(tmp12, tmp5)
    tmp14 = triton_helpers.minimum(tmp13, tmp7)
    tmp15 = tmp14.to(tl.float8e4nv)
    tl.store(out_ptr0 + (x1 + (4096*y0)), tmp9, None)
    tl.store(out_ptr1 + (x1 + (4096*y0)), tmp15, None)
''', device_str='cuda')
```
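A rough accounting of the memory traffic makes the problem concrete. The kernel_num_gb = 0.067108868 used for the 1.5 TB/s figure appears to count each buffer once, but the kernel loads in_ptr0 twice: once coalesced at x1 + (4096*y0) (tmp0) and once at the transposed index y0 + (4096*x1) (tmp10). Ignoring cache reuse, the per-call traffic looks roughly like this:

```python
# Rough per-call traffic for triton_poi_fused__to_copy_clamp_clone_mul_0 (bytes).
M = K = 4096
bf16_read  = M * K * 2   # one full read of the bf16 input
scale_read = 4           # in_ptr1: a single fp32 scalar
fp8_write  = M * K * 1   # each of out_ptr0 / out_ptr1

counted = bf16_read + scale_read + 2 * fp8_write      # 67,108,868 -> matches kernel_num_gb
actual  = 2 * bf16_read + scale_read + 2 * fp8_write  # 100,663,300: in_ptr0 is read twice
saving  = bf16_read                                   # ~33.5 MB per call if the second read is dropped
```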
One simple change that could be applied to the kernel is to remove the load of tmp10 and instead store the casted tmp0 (i.e., tmp9) to the transposed position of out_ptr1, as sketched below.
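Here is a minimal sketch of that change. This is an illustration of the idea rather than Inductor output: it mirrors the generated kernel but drops the tmp10 load, uses tl.maximum/tl.minimum instead of the generated code's triton_helpers, hardcodes the 4096 strides of this shape, and omits the heuristics decorator and launch configuration.

```python
import triton
import triton.language as tl


@triton.jit
def cast_and_transpose_fp8(in_ptr0, in_ptr1, out_ptr0, out_ptr1,
                           YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):
    # Hypothetical hand-edited variant of the fused cast+transpose kernel.
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    x1 = xindex
    y0 = yindex
    # Single coalesced read of the bf16 input tile; the transposed re-read is gone.
    tmp0 = tl.load(in_ptr0 + (x1 + (4096 * y0)), None).to(tl.float32)
    scale = tl.broadcast_to(tl.load(in_ptr1 + (0)), [XBLOCK, YBLOCK])
    tmp4 = tmp0 * scale
    tmp8 = tl.minimum(tl.maximum(tmp4, -448.0), 448.0)
    tmp9 = tmp8.to(tl.float8e4nv)
    # Row-major fp8 output, unchanged.
    tl.store(out_ptr0 + (x1 + (4096 * y0)), tmp9, None)
    # Reuse the casted value for the transposed output instead of reloading
    # in_ptr0 at the transposed index.
    tl.store(out_ptr1 + (y0 + (4096 * x1)), tmp9, None)
```

Whether the now strided store to out_ptr1 should go through an explicit tile transpose for better coalescing is a separate tuning question; the point here is only that the second read of the bf16 input is not needed for correctness.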
Versions
main branch as of 2024-08-09
cc @msaroufim @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @yanbing-j @albanD