
Commit e184391

mengluy0125 authored and pytorchmergebot committed
[PT2][Inductor][Optimus] fix test_pad_mm_bf16 and reland to fix long computation kernel (#136349)
Summary: see D62220158

Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:pad_mm -- --exact 'caffe2/test/inductor:pad_mm - test_pad_mm_bf16 (caffe2.test.inductor.test_pad_mm.PadMMTest)' --run-disabled
```

### H100
Buck UI: https://www.internalfb.com/buck2/e5d85802-cab7-41a5-aacc-95f541796a99
Test UI: https://www.internalfb.com/intern/testinfra/testrun/9570149258587374
Network: Up: 9.1KiB Down: 0B (reSessionID-b339b51b-6a0e-4347-9414-1ba38f26a5d0)
Jobs completed: 9. Time elapsed: 1:15.7s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3)
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 1. Build failure 0

### A100
Buck UI: https://www.internalfb.com/buck2/1082ad6e-56b0-4eb5-8092-ce507ca9a70e
Test UI: https://www.internalfb.com/intern/testinfra/testrun/8444249533824784
Network: Up: 9.2KiB Down: 0B (reSessionID-2b3056ac-f29e-4de4-b6f5-9d994acf566b)
Jobs completed: 9. Time elapsed: 1:36.9s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3)
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

### E2E
see D62220158

Differential Revision: D63040455

Pull Request resolved: #136349
Approved by: https://github.com/dshi7
1 parent 0287146 commit e184391

File tree

2 files changed: +24 −0 lines


torch/_inductor/fx_passes/pad_mm.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -364,6 +364,23 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
     return should_pad
 
 
+def should_pad_mm_bf16(dtype, M, N, K):
+    # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
+    large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
+        "pad_aten_mm_pass"
+    ].get("k_threshold_to_pad", 8388608)
+    if (
+        dtype is torch.bfloat16
+        and K > M
+        and K > N
+        and N % 2 == 1
+        and K >= large_k_threshold_to_pad
+        and torch.cuda.get_device_capability() < (9, 0)
+    ):  # doesnt repro on h100s:
+        return True
+    return False
+
+
 def should_pad_bench(
     match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
 ) -> bool:
@@ -410,6 +427,12 @@ def realize_symbols(ds):
     if torch._inductor.config.force_shape_pad:
         return True
 
+    if (
+        "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
+        and should_pad_mm_bf16(mat1.dtype, m, n, k)
+    ):
+        return True
+
     if not has_triton():
         return False
```

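To see what the new predicate decides in isolation, here is a minimal standalone sketch of the `should_pad_mm_bf16` condition with the torch-specific lookups replaced by plain parameters (`is_bf16`, `sm_major_minor`, and `k_threshold` are hypothetical stand-ins for the dtype check, `torch.cuda.get_device_capability()`, and the `k_threshold_to_pad` config value; 8388608 is the default from the diff above):

```python
def should_pad_mm_bf16_sketch(is_bf16, M, N, K,
                              k_threshold=8_388_608,
                              sm_major_minor=(8, 0)):
    """Standalone sketch of the padding predicate added in pad_mm.py.

    Force-pads a bf16 mm when K dominates both M and N, N is odd,
    K exceeds the threshold, and the GPU is pre-Hopper (< SM 9.0),
    where the long-computation-kernel regression was observed.
    """
    return (
        is_bf16
        and K > M
        and K > N
        and N % 2 == 1
        and K >= k_threshold
        and sm_major_minor < (9, 0)
    )


# Large-K bf16 mm with odd N on an A100 (SM 8.0): force-padded.
print(should_pad_mm_bf16_sketch(True, 1024, 1023, 16_000_000))  # True
# Same shape on an H100 (SM 9.0): regression doesn't repro, no forced pad.
print(should_pad_mm_bf16_sketch(True, 1024, 1023, 16_000_000,
                                sm_major_minor=(9, 0)))  # False
```

Note that the predicate is checked before the `has_triton()` bail-out in `should_pad_bench`, so the forced pad applies even without benchmarking.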
torch/_inductor/fx_passes/split_cat.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -65,6 +65,7 @@
     "decompose_mm_pass",
     "unbind_stack_aten_pass",
     "shape_padding_multiplier",
+    "pad_aten_mm_pass",
 ]
 
 for pass_name in pre_grad_pass_names:
```
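With `pad_aten_mm_pass` registered as a post-grad pass name, the feature is opt-in via the inductor config. A sketch of how a caller might enable it (the `k_threshold_to_pad` key and its 8388608 default come from the pad_mm.py diff; the exact override mechanism used in production is not shown in this commit):

```python
import torch

# Enable the pass; should_pad_mm_bf16 reads the threshold from this dict,
# falling back to 8388608 when the key is absent.
torch._inductor.config.post_grad_fusion_options["pad_aten_mm_pass"] = {
    "k_threshold_to_pad": 8_388_608,
}
```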
