-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[Prototype] Adding lowering to persistent-tma device kernel for _scaled_mm #138536
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ed_mm [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138536
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 51bbc38 with merge base be90d3c ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…l for _scaled_mm"
Code I am using to iterate w/
```Python
import torch
from dataclasses import dataclass
from jsonargparse import CLI
import logging
from pathlib import Path
from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function
from torchao.float8.inference import (
addmm_float8_unwrapped_inference,
preprocess_data,
Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_device_tma_persistent,
)
from enum import Enum
logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
class FP8Kernel(Enum):
PERSISTENT = "Persistent"
PERSISTENT_TMA = "Persistent-TMA"
DEVICE_TMA = "Device-TMA"
SCALED_MM = "Scaled-MM"
class ScalingStrategy(Enum):
PER_TENSOR = "PerTensor"
PER_ROW = "PerRow"
dataclass(frozen=True)
class ExperimentConfig:
M: int
K: int
N: int
scaling_strategy: ScalingStrategy
fp8_kernel: FP8Kernel
compile: bool
def get_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
scaling_strategy: ScalingStrategy,
fp8_kernel: FP8Kernel,
):
A_fp8 = A.to(torch.float8_e4m3fn)
B_fp8 = B.to(torch.float8_e4m3fn)
A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))
if scaling_strategy == ScalingStrategy.PER_TENSOR:
a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
elif scaling_strategy == ScalingStrategy.PER_ROW:
a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
else:
raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")
assert fp8_kernel == FP8Kernel.SCALED_MM
return lambda: addmm_float8_unwrapped_inference(
A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
)
def run_matmul(config: ExperimentConfig):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)
fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)
if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")
_ = fp8_matmul()
return
def main():
torch.random.manual_seed(123)
# Define your experiment configuration here
config = ExperimentConfig(
M=8192,
K=8192,
N=8192,
scaling_strategy=ScalingStrategy.PER_TENSOR,
fp8_kernel=FP8Kernel.SCALED_MM,
compile=True,
)
run_matmul(config)
if __name__ == "__main__":
CLI(main)
```
Generating the following output code: https://gist.github.com/drisspg/5083632ae57e268a43100555d3890a19
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov
[ghstack-poisoned]
…l for _scaled_mm"
Code I am using to iterate w/
```Python
import torch
from dataclasses import dataclass
from jsonargparse import CLI
import logging
from pathlib import Path
from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function
from torchao.float8.inference import (
addmm_float8_unwrapped_inference,
preprocess_data,
Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_device_tma_persistent,
)
from enum import Enum
logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
class FP8Kernel(Enum):
PERSISTENT = "Persistent"
PERSISTENT_TMA = "Persistent-TMA"
DEVICE_TMA = "Device-TMA"
SCALED_MM = "Scaled-MM"
class ScalingStrategy(Enum):
PER_TENSOR = "PerTensor"
PER_ROW = "PerRow"
dataclass(frozen=True)
class ExperimentConfig:
M: int
K: int
N: int
scaling_strategy: ScalingStrategy
fp8_kernel: FP8Kernel
compile: bool
def get_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
scaling_strategy: ScalingStrategy,
fp8_kernel: FP8Kernel,
):
A_fp8 = A.to(torch.float8_e4m3fn)
B_fp8 = B.to(torch.float8_e4m3fn)
A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))
if scaling_strategy == ScalingStrategy.PER_TENSOR:
a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
elif scaling_strategy == ScalingStrategy.PER_ROW:
a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
else:
raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")
assert fp8_kernel == FP8Kernel.SCALED_MM
return lambda: addmm_float8_unwrapped_inference(
A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
)
def run_matmul(config: ExperimentConfig):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)
fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)
if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")
_ = fp8_matmul()
return
def main():
torch.random.manual_seed(123)
# Define your experiment configuration here
config = ExperimentConfig(
M=8192,
K=8192,
N=8192,
scaling_strategy=ScalingStrategy.PER_TENSOR,
fp8_kernel=FP8Kernel.SCALED_MM,
compile=True,
)
run_matmul(config)
if __name__ == "__main__":
CLI(main)
```
Generating the following output code: https://gist.github.com/drisspg/5083632ae57e268a43100555d3890a19
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov
[ghstack-poisoned]
torch/_inductor/kernel/mm_scaled.py
Outdated
|
|
||
| import torch | ||
| from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate | ||
| from torch.utils._triton import has_triton_tma |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@drisspg this one checks for existence of the host-side TMA API (i.e., create_1d_tma_descriptor and create_2d_tma_descriptor). Here you need device-side TMA API which seems to have been added in a later PR triton-lang/triton#4633. I guess, we need another helper function checking if experimental_device_tensormap_create1d and experimental_device_tensormap_create2d are available?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do
torch/_inductor/kernel/mm_scaled.py
Outdated
| f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." | ||
| ) | ||
| NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count | ||
| tma_size = 128 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems unused.
|
This has been my main benchmarking script: drisspg/transformer_nuggets#39 I tried this and for some reason on the second iteration I always run into: If I set dynamic=False, all works, so clearly there is some dyanmic shape bug getting in here. |
|
Without Tma store: |
|
So there seems to be up to 3% perf improvement from TMA store vs |
|
Could you also share the baseline CSV (w/o any persistent+TMA)? |
|
@aakhundov definitely much improved results: https://gist.github.com/drisspg/698ea0a7eac8ae4542350f939f05f9d1 |
|
Updated and ran the existing tests getting: Thread 0x00007f9f8e7fc640 (most recent call first): I am cc @aakhundov |
|
Closing cause this got stale and ghstack yelled at me: #142045 |
Stack from ghstack (oldest at bottom):
Code I am using to iterate w/
Generating the following output code: https://gist.github.com/drisspg/5083632ae57e268a43100555d3890a19
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov