
Conversation

@lw lw (Contributor) commented Aug 29, 2024

Stack from ghstack (oldest at bottom):

I propose a new heuristic function to select the tile size, cluster size, and transposition given M, N and K. It improves performance across the board (on average) while remaining simple and relying on only a handful of kernels (to limit build time and binary size).
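To make the shape of the dispatch concrete, here's a rough sketch; all names, thresholds and configs below are illustrative stand-ins, not the ones actually used in this PR:

```
from dataclasses import dataclass

@dataclass(frozen=True)
class GemmConfig:
    tile_m: int       # CTA tile size along M
    tile_n: int       # CTA tile size along N
    cluster: tuple    # thread-block cluster shape (x, y, z)
    transposed: bool  # whether to compute the GEMM as C^T = B^T @ A^T

def pick_config(m: int, n: int, k: int) -> GemmConfig:
    # Hypothetical thresholds: thin-M shapes get small tiles and the
    # transposed layout, large shapes get big tiles plus a cluster.
    # Note that K deliberately plays no role in the decision.
    if m <= 128:
        return GemmConfig(64, 128, (1, 2, 1), transposed=True)
    if max(m, n) <= 2048:
        return GemmConfig(128, 128, (1, 1, 1), transposed=False)
    return GemmConfig(128, 256, (2, 1, 1), transposed=False)
```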

Across the shapes I benchmarked, the new heuristic gives a (geometric) mean speedup of +16.5%. Some shapes regress, but 98.6% of the shapes either retain their old performance (within 5%, to allow for noise) or improve on it.
![image](https://github.com/user-attachments/assets/bca30583-ac32-4af6-a4f9-37164bdb2430)
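For reference, both summary numbers can be read off the `ratios` DataFrame built by the plotting code below (a sketch, assuming `ratios` is already populated):

```
import numpy as np

# Geometric mean speedup: exponentiate the mean of the log-ratios.
geomean = np.exp(np.mean(np.log(ratios.ratio.to_numpy())))
print(f"geomean speedup: {geomean - 1:+.1%}")

# Fraction of shapes that kept their old performance (within 5%) or improved.
print(f"not regressed: {(ratios.ratio >= 0.95).mean():.1%}")
```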

I benchmarked on over 5.4k different shapes:

- For M and N I swept across all values which are the sums of two powers of 2 (limited to multiples of 64, capped at 16,384)
- For K I only used powers of 2 between 1,024 and 8,192 (based on the intuition that the optimal config doesn't depend on K, which turned out to be the case)

Here's the detailed speedup for each shape:
![image](https://github.com/user-attachments/assets/acac4318-9ee0-455d-861b-c764b8c13d22)

<details>
<summary>
This is the code I used to benchmark
</summary>

```
import torch
import torch.utils.benchmark

s = set()

for i in range(6, 15):
    s.add(2**i)
    for j in range(6, i):
        s.add(2**i + 2**j)

ms = [i for i in sorted(s) if i <= 2**14]
ns = [i for i in sorted(s) if i <= 2**14]
ks = [2**i for i in range(10, 14)]

def make_graph(n_iters, f):
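    # Record n_iters back-to-back calls of f into a single CUDA graph;
    # replaying it amortizes launch overhead so the measurement reflects
    # kernel time rather than per-launch CPU time.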
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for _ in range(n_iters):
            f()
    return g

def rowwise_scale(t, dtype_t):
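    # Quantize to fp8 with one scale per row: derive the scale from the
    # row-wise absmax (clamped away from zero) so each row maps onto the
    # full representable range of the target dtype.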
    min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
    scale_t = torch.clamp(t.abs().amax(dim=-1, keepdim=True).float(), min=1e-12) / max_v
    t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
    return t_fp8, scale_t

for m in ms:
    for n in ns:
        for k in ks:
            a = torch.randn((m, k), device="cuda", dtype=torch.float)
            b_t = torch.randn((n, k), device="cuda", dtype=torch.float)
            a_fp8, scale_a = rowwise_scale(a, torch.float8_e4m3fn)
            b_t_fp8, scale_b_t = rowwise_scale(b_t, torch.float8_e4m3fn)
            func = lambda: torch._scaled_mm(
                a_fp8,
                b_t_fp8.t(),
                scale_a=scale_a,
                scale_b=scale_b_t.t(),
                bias=None,
                use_fast_accum=True,
                out_dtype=torch.bfloat16
            )
            print(f"{m=},{n=},{k=}")
            print(torch.utils.benchmark.Timer("g.replay()", globals={"g": make_graph(1000, func)}).blocked_autorange(min_run_time=1).mean / 1000)
```
</details>

<details>
<summary>
This is the code I used for the plots
</summary>

```
from itertools import islice

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import FuncNorm
from mpl_toolkits.axes_grid1 import ImageGrid


def batched(iterable, n):
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        yield batch

def try_to_convert(v):
    if v == "False":
        return False
    if v == "True":
        return True
    return int(v)

def get_from_paste(filename):
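    # The benchmark output alternates a "m=...,n=...,k=..." config line
    # with a raw latency value, so consume the file two lines at a time.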
    text = open(filename, "rt").read()
    headers = []
    data = []
    for config, value in batched(text.splitlines(), 2):
        config_elems = config.split(",")
        if not headers:
            headers = [e.partition("=")[0] for e in config_elems]
        data.append((*(try_to_convert(e.partition("=")[-1]) for e in config_elems), float(value)))
    return pd.DataFrame(data, columns=headers + ["latency"])

old_latencies = get_from_paste(...)
new_latencies = get_from_paste(...)

ratios = pd.merge(new_latencies, old_latencies, how="left", left_on=["m", "n", "k"], right_on=["m", "n", "k"], suffixes=("_new", "_old"))
ratios = ratios.assign(ratio=ratios.latency_old / ratios.latency_new)

fig = plt.figure(figsize=(40.0, 10.0))
grid = ImageGrid(
    fig,
    111,
    nrows_ncols=(1, 4),
    axes_pad=0.5,
    share_all=True,
    cbar_location="right",
    cbar_mode="single",
    cbar_size="7%",
    cbar_pad=0.15,
)

log_amax = np.max(np.abs(np.log(ratios.ratio.to_numpy())))

for K, ax in zip([1024, 2048, 4096, 8192], grid):
    pivoted = ratios[(ratios.k == K)].pivot_table(index="m", columns="n", values="ratio")
    im = ax.imshow(np.log(pivoted.to_numpy()), origin="lower", vmin=-log_amax, vmax=log_amax, cmap="PiYG")
    m_vals, n_vals = pivoted.axes
    ax.set_xticks(np.arange(len(n_vals)), labels=[f"N={i}" for i in n_vals.values], fontsize=12)
    ax.set_yticks(np.arange(len(m_vals)), labels=[f"M={i}" for i in m_vals.values], fontsize=12)
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")
    ax.grid(False)
    ax.set_title(f"K={K}", fontsize=20)

norm = FuncNorm((lambda x: np.log(x), lambda x: np.exp(x)), np.exp(-log_amax), np.exp(log_amax))
ax.cax.colorbar(ScalarMappable(norm=norm, cmap="PiYG"))
plt.show()

counts, bins = np.histogram(np.log(ratios.ratio.to_numpy()), bins=500)
plt.stairs(counts, np.exp(bins), fill=True)
plt.xscale("function", functions=(lambda x: np.log(x), lambda x: np.exp(x)))
```
</details>

I only benchmarked fast_accum=True and out_dtype=torch.bfloat16, on the assumption that these are the most commonly used flags (e.g., with fast_accum=False row-wise scaling is much slower than tensor-wise scaling, hence impractical).
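For reference, tensor-wise scaling replaces the per-row scales with a single scalar per tensor; a minimal sketch of the analogous helper (hypothetical, using the same clamping scheme as `rowwise_scale` above):

```
def tensorwise_scale(t, dtype_t):
    # One scalar scale for the whole tensor, derived from the global absmax.
    min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
    scale_t = torch.clamp(t.abs().amax().float(), min=1e-12) / max_v
    t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
    return t_fp8, scale_t
```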

@pytorch-bot pytorch-bot bot commented Aug 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134781

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 05c422b with merge base 1f1e2ee:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda label Aug 29, 2024
lw added a commit that referenced this pull request Aug 30, 2024
@lw lw added the ciflow/trunk and topic: not user facing labels Aug 30, 2024
@lw lw marked this pull request as ready for review August 30, 2024 13:40
@lw lw requested review from eqy and syed-ahmed as code owners August 30, 2024 13:40
@lw lw (Contributor, Author) commented Sep 4, 2024

@pytorchbot merge

@pytorchmergebot pytorchmergebot (Collaborator) commented
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Pull Request resolved: pytorch#134781
Approved by: https://github.com/drisspg, https://github.com/eqy
ghstack dependencies: pytorch#134773
@github-actions github-actions bot deleted the gh/lw/16/head branch October 5, 2024 02:05
pytorchmergebot pushed a commit that referenced this pull request Jul 11, 2025
Most of the work had already been done by @jeffdaily in #154680, but there was one remaining check that needed to be modified in order for `torch._scaled_mm` to use cuBLAS over CUTLASS when available.

I tested this change by rebuilding PyTorch locally with CUDA 12.9 and running `torch._scaled_mm` under the profiler, and observed that the kernel being launched is called `nvjet_qqtst_128x128_128x6_1x1_h_bz_coopA_algo2_ovscale_TNT` (where `ovscale` presumably stands for "outer vector scaling", which is what cuBLAS calls this scaling mode).
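For anyone wanting to reproduce the check, a minimal profiling sketch along these lines shows which kernel gets dispatched (the shapes and unit scales here are arbitrary):

```
import torch
from torch.profiler import ProfilerActivity, profile

m = n = k = 4096
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()  # column-major
scale_a = torch.ones(m, 1, device="cuda")  # row-wise scales
scale_b = torch.ones(1, n, device="cuda")

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                     out_dtype=torch.bfloat16, use_fast_accum=True)
print(prof.key_averages().table(sort_by="cuda_time_total"))
```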

I then benchmarked the new kernels against the old CUTLASS ones on a standard 700W H100 GPU. I used the same approach as in #134781, and obtained these speed-ups:
![image](https://github.com/user-attachments/assets/43dfb816-9ccf-40c5-8b2a-571ce9cb511d)
![image](https://github.com/user-attachments/assets/be7ac6f2-e16c-479b-ad5c-f8039caba4b1)

We see that the two kernels perform very closely (I'm surprised; I would have expected cuBLAS to outperform CUTLASS across the board), with some thin/skewed shapes getting worse but some very large shapes getting better.

I guess the questions are whether we consider this a net-zero change (given that there's improvements _and_ degradations), and how large we consider the burden of maintaining our own CUTLASS kernels.

Pull Request resolved: #157905
Approved by: https://github.com/eqy, https://github.com/Skylion007, https://github.com/drisspg