[fp8 rowwise] Retune the tile heuristics to increase perf #134781
Closed
Dr. CI (automated, updates every 15 minutes): 🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/134781. As of commit 05c422b with merge base 1f1e2ee: ✅ You can merge normally! (1 unrelated failure, likely due to flakiness present on trunk.) Note: links to docs will display an error until the docs builds have completed.
drisspg approved these changes (Sep 3, 2024)
eqy approved these changes (Sep 3, 2024)
Contributor (Author): @pytorchbot merge
Collaborator: Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request (Sep 20, 2024):
[fp8 rowwise] Retune the tile heuristics to increase perf (#134781)

I propose a new heuristic function to select tile size, cluster size, and transposition given M, N and K. It improves performance across the board (on average) while remaining simple and relying only on a handful of kernels (to limit build time and binary size).

Across the shapes I benchmarked, the new heuristic gives a (geometric) mean speedup of +16.5%. Some shapes worsen, but 98.6% of the shapes retain their old performance (up to 5% to allow for noise) or improve it.

I benchmarked on over 5.4k different shapes:
- For M and N I swept across all values which are sums of two powers of 2 (limited to multiples of 64, capped at 16,384)
- For K I only used powers of 2 between 1,024 and 8,192 (based on the intuition that the optimal config doesn't depend on K, which turned out to be the case)

Here's the detailed speedup for each shape: [per-shape speedup heatmaps omitted]

This is the code I used to benchmark:

```
import torch
import torch.utils.benchmark

# M and N sweep over all sums of two powers of 2 (min 64), capped at 2**14;
# K sweeps over powers of 2 between 2**10 and 2**13.
s = set()
for i in range(6, 15):
    s.add(2**i)
    for j in range(6, i):
        s.add(2**i + 2**j)
ms = [i for i in sorted(s) if i <= 2**14]
ns = [i for i in sorted(s) if i <= 2**14]
ks = [2**i for i in range(10, 14)]


def make_graph(n_iters, f):
    # Capture n_iters calls of f in a CUDA graph to amortize launch overhead.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for _ in range(n_iters):
            f()
    return g


def rowwise_scale(t, dtype_t):
    # Compute one scale per row of t and quantize t to the target fp8 dtype.
    min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
    scale_t = torch.clamp(t.abs().amax(dim=-1, keepdim=True).float(), min=1e-12) / max_v
    t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
    return t_fp8, scale_t


for m in ms:
    for n in ns:
        for k in ks:
            a = torch.randn((m, k), device="cuda", dtype=torch.float)
            b_t = torch.randn((n, k), device="cuda", dtype=torch.float)
            a_fp8, scale_a = rowwise_scale(a, torch.float8_e4m3fn)
            b_t_fp8, scale_b_t = rowwise_scale(b_t, torch.float8_e4m3fn)
            func = lambda: torch._scaled_mm(
                a_fp8,
                b_t_fp8.t(),
                scale_a=scale_a,
                scale_b=scale_b_t.t(),
                bias=None,
                use_fast_accum=True,
                out_dtype=torch.bfloat16,
            )
            print(f"{m=},{n=},{k=}")
            print(
                torch.utils.benchmark.Timer(
                    "g.replay()", globals={"g": make_graph(1000, func)}
                ).blocked_autorange(min_run_time=1).mean
                / 1000
            )
```

This is the code I used for the plots:

```
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np  # needed for the np.* calls below (missing in the original)
import pandas as pd
from matplotlib.cm import ScalarMappable
from matplotlib.colors import FuncNorm
from mpl_toolkits.axes_grid1 import ImageGrid


def batched(iterable, n):
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        yield batch


def try_to_convert(v):
    if v == "False":
        return False
    if v == "True":
        return True
    return int(v)


def get_from_paste(filename):
    # Parse the alternating config/latency lines printed by the benchmark script.
    text = open(filename, "rt").read()
    headers = []
    data = []
    for config, value in batched(text.splitlines(), 2):
        config_elems = config.split(",")
        if not headers:
            headers = [e.partition("=")[0] for e in config_elems]
        data.append(
            (*(try_to_convert(e.partition("=")[-1]) for e in config_elems), float(value))
        )
    return pd.DataFrame(data, columns=headers + ["latency"])


old_latencies = get_from_paste(...)
new_latencies = get_from_paste(...)

ratios = pd.merge(
    new_latencies,
    old_latencies,
    how="left",
    left_on=["m", "n", "k"],
    right_on=["m", "n", "k"],
    suffixes=("_new", "_old"),
)
ratios = ratios.assign(ratio=ratios.latency_old / ratios.latency_new)

# One log-scale heatmap of the speedup ratio per value of K.
fig = plt.figure(figsize=(40.0, 10.0))
grid = ImageGrid(
    fig,
    111,
    nrows_ncols=(1, 4),
    axes_pad=0.5,
    share_all=True,
    cbar_location="right",
    cbar_mode="single",
    cbar_size="7%",
    cbar_pad=0.15,
)
log_amax = np.max(np.abs(np.log(ratios.ratio.to_numpy())))
for K, ax in zip([1024, 2048, 4096, 8192], grid):
    pivoted = ratios[(ratios.k == K)].pivot_table(index="m", columns="n", values="ratio")
    im = ax.imshow(np.log(pivoted.to_numpy()), origin="lower", vmin=-log_amax, vmax=log_amax, cmap="PiYG")
    m_vals, n_vals = pivoted.axes
    ax.set_xticks(np.arange(len(n_vals)), labels=[f"N={i}" for i in n_vals.values], fontsize=12)
    ax.set_yticks(np.arange(len(m_vals)), labels=[f"M={i}" for i in m_vals.values], fontsize=12)
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")
    ax.grid(False)
    ax.set_title(f"K={K}", fontsize=20)
norm = FuncNorm((lambda x: np.log(x), lambda x: np.exp(x)), np.exp(-log_amax), np.exp(log_amax))
ax.cax.colorbar(ScalarMappable(norm=norm, cmap="PiYG"))
plt.show()

# Histogram of the log-ratios, displayed on a log x-axis.
counts, bins = np.histogram(np.log(ratios.ratio.to_numpy()), bins=500)
plt.stairs(counts, np.exp(bins), fill=True)
plt.xscale("function", functions=(lambda x: np.log(x), lambda x: np.exp(x)))
```

I only benchmarked fast_accum=True and out_dtype=torch.bfloat16, supposing that these are the most commonly used flags (e.g., with fast_accum=False, row-wise scaling is much slower than tensor-wise scaling and hence impractical).

Pull Request resolved: pytorch#134781
Approved by: https://github.com/drisspg, https://github.com/eqy
ghstack dependencies: pytorch#134773
pytorchmergebot pushed a commit that referenced this pull request (Jul 11, 2025):
Most of the work had already been done by @jeffdaily in #154680, but there was one remaining check that needed to be modified in order for `torch._scaled_mm` to use cuBLAS over CUTLASS when available.

I tested this change by rebuilding PyTorch locally with CUDA 12.9 and running `torch._scaled_mm` under the profiler, and observed that the kernel being launched is called `nvjet_qqtst_128x128_128x6_1x1_h_bz_coopA_algo2_ovscale_TNT` (where `ovscale` stands for "outer vector scaling", I believe, which is cuBLAS's name for this scaling mode).

I then benchmarked the new kernels against the old CUTLASS ones on a standard 700W H100 GPU. I used the same approach as in #134781, and obtained these speed-ups: [speed-up plots omitted]

We see that the two kernels perform very closely (I'm surprised; I would have expected cuBLAS to outperform CUTLASS across the board), with some thin/skewed shapes becoming worse but some very large shapes becoming better. I guess the questions are whether we consider this a net-zero change (given that there are improvements _and_ degradations), and how large we consider the burden of maintaining our own CUTLASS kernels.

Pull Request resolved: #157905
Approved by: https://github.com/eqy, https://github.com/Skylion007, https://github.com/drisspg
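For reference, a minimal sketch of how one might check which kernel `torch._scaled_mm` dispatches to under the profiler; the shapes and the all-ones scales below are illustrative assumptions, not the configuration used for the measurements above:

```
import torch

# Illustrative shapes; any rowwise-compatible fp8 problem works.
a = torch.randn(1024, 512, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(2048, 512, device="cuda").to(torch.float8_e4m3fn)
scale_a = torch.ones(1024, 1, device="cuda")  # one scale per row of a
scale_b = torch.ones(1, 2048, device="cuda")  # one scale per column of b.t()

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA]
) as prof:
    torch._scaled_mm(a, b.t(), scale_a=scale_a, scale_b=scale_b,
                     out_dtype=torch.bfloat16, use_fast_accum=True)

# The launched kernel name (e.g. an nvjet_* kernel when cuBLAS is picked,
# depending on the build, GPU, and CUDA version) appears in this table.
print(prof.key_averages().table(sort_by="cuda_time_total"))
```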
Labels: ciflow/trunk, Merged, release notes: cuda, topic: not user facing
Stack from ghstack (oldest at bottom):
I propose a new heuristic function to select tile size, cluster size, and transposition given M, N and K. It improves performance across the board (on average) while remaining simple and relying only on a handful of kernels (to limit build time and binary size).
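As an illustration of what such a heuristic can look like, here is a hypothetical sketch; the function name, thresholds, and tile/cluster values below are invented for exposition and are not the ones chosen in this PR:

```
# Hypothetical sketch only: thresholds and tile/cluster values are invented,
# not the configurations selected by this PR.
def pick_config(m: int, n: int, k: int):
    # Transpose the problem so the larger output dimension maps to the
    # dimension the kernel tiles more efficiently.
    transpose = m < n
    if transpose:
        m, n = n, m
    # Small problems favor small tiles (more CTAs to fill the GPU);
    # large problems favor big tiles (better data reuse per CTA).
    if m <= 128:
        tile_m, tile_n, cluster = 64, 128, (1, 1)
    elif m <= 2048:
        tile_m, tile_n, cluster = 128, 128, (1, 2)
    else:
        tile_m, tile_n, cluster = 128, 256, (2, 1)
    # K intentionally does not affect the choice, matching the observation
    # below that the optimal config doesn't depend on K.
    return tile_m, tile_n, cluster, transpose
```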
Across the shapes I benchmarked, the new heuristic gives a (geometric) mean speedup of +16.5%. Some shapes worsen, but 98.6% of the shapes retain their old performance (up to 5% to allow for noise) or improve it.
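For clarity, a geometric mean over per-shape speedup ratios is the exponential of the mean log-ratio; a minimal sketch with made-up example values:

```
import numpy as np

# Made-up example ratios (old latency / new latency); a geomean of ~1.165
# would correspond to the +16.5% reported above.
ratios = np.array([1.30, 0.98, 1.10, 1.25])
geomean = np.exp(np.log(ratios).mean())
print(f"geometric mean speedup: {(geomean - 1) * 100:+.1f}%")
```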

I benchmarked on over 5.4k different shapes:
- For M and N I swept across all values which are sums of two powers of 2 (limited to multiples of 64, capped at 16,384)
- For K I only used powers of 2 between 1,024 and 8,192 (based on the intuition that the optimal config doesn't depend on K, which turned out to be the case)
Here's the detailed speedup for each shape: [per-shape speedup heatmaps omitted]
This is the code I used to benchmark (reproduced in the commit message above).
This is the code I used for the plots (reproduced in the commit message above).
I only benchmarked fast_accum=True and out_dtype=torch.bfloat16, supposing that these are the most commonly used flags (e.g., with fast_accum=False, row-wise scaling is much slower than tensor-wise scaling and hence impractical).
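To make the comparison concrete, here is a minimal sketch contrasting the two scaling modes of `torch._scaled_mm`; the shapes and all-ones scales are illustrative assumptions only:

```
import torch

a = torch.randn(256, 128, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(512, 128, device="cuda").to(torch.float8_e4m3fn)

# Tensor-wise scaling: a single scalar scale per operand.
out_tensorwise = torch._scaled_mm(
    a, b.t(),
    scale_a=torch.tensor(1.0, device="cuda"),
    scale_b=torch.tensor(1.0, device="cuda"),
    out_dtype=torch.bfloat16,
)

# Row-wise scaling: one scale per row of a and per column of b.t(),
# the mode benchmarked in this PR (with use_fast_accum=True).
out_rowwise = torch._scaled_mm(
    a, b.t(),
    scale_a=torch.ones(256, 1, device="cuda"),
    scale_b=torch.ones(1, 512, device="cuda"),
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)
```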