Conversation

@drisspg (Contributor) commented Oct 22, 2024

Stack from ghstack (oldest at bottom):

Code I am using to iterate with:

```Python
import torch
from dataclasses import dataclass
from jsonargparse import CLI
import logging
from pathlib import Path

from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function
from torchao.float8.inference import (
    addmm_float8_unwrapped_inference,
    preprocess_data,
    Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
    matmul_persistent,
    matmul_tma_persistent,
    matmul_device_tma_persistent,
)
from enum import Enum

logging.getLogger("transformer_nuggets").setLevel(logging.INFO)


class FP8Kernel(Enum):
    PERSISTENT = "Persistent"
    PERSISTENT_TMA = "Persistent-TMA"
    DEVICE_TMA = "Device-TMA"
    SCALED_MM = "Scaled-MM"


class ScalingStrategy(Enum):
    PER_TENSOR = "PerTensor"
    PER_ROW = "PerRow"


@dataclass(frozen=True)
class ExperimentConfig:
    M: int
    K: int
    N: int
    scaling_strategy: ScalingStrategy
    fp8_kernel: FP8Kernel
    compile: bool


def get_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    scaling_strategy: ScalingStrategy,
    fp8_kernel: FP8Kernel,
):
    A_fp8 = A.to(torch.float8_e4m3fn)
    B_fp8 = B.to(torch.float8_e4m3fn)
    A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))
    
    if scaling_strategy == ScalingStrategy.PER_TENSOR:
        a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    elif scaling_strategy == ScalingStrategy.PER_ROW:
        a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
        b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
    else:
        raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")

    assert fp8_kernel == FP8Kernel.SCALED_MM
    return lambda: addmm_float8_unwrapped_inference(
        A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
    )


def run_matmul(config: ExperimentConfig):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
    B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)

    fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)


    if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")


    _ = fp8_matmul()

    return


def main():
    torch.random.manual_seed(123)

    # Define your experiment configuration here
    config = ExperimentConfig(
        M=8192,
        K=8192,
        N=8192,
        scaling_strategy=ScalingStrategy.PER_TENSOR,
        fp8_kernel=FP8Kernel.SCALED_MM,
        compile=True,
    )

    run_matmul(config)


if __name__ == "__main__":
    CLI(main)
```

Generating the following output code: https://gist.github.com/drisspg/5083632ae57e268a43100555d3890a19

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov
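
For clarity, the SCALED_MM path above bottoms out in `torch._scaled_mm`; a stripped-down sketch of that path under `torch.compile` is below (per-tensor scales; the keyword names and the column-major requirement on the second operand are assumed from the current `_scaled_mm` schema, not confirmed by this PR). `mode="max-autotune-no-cudagraphs"` is what lets Inductor consider Triton matmul templates for this op.

```Python
import torch

M = K = N = 8192
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand column-major, hence creating (N, K) and transposing.
B = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).t()
a_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
b_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

def fp8_mm():
    # Keyword names assumed from the torch._scaled_mm schema at time of writing.
    return torch._scaled_mm(
        A, B, scale_a=a_scale, scale_b=b_scale,
        out_dtype=torch.bfloat16, use_fast_accum=True,
    )

compiled = torch.compile(fp8_mm, mode="max-autotune-no-cudagraphs")
out = compiled()
```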

@pytorch-bot bot commented Oct 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138536

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 51bbc38 with merge base be90d3c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request Oct 22, 2024
drisspg added a commit that referenced this pull request Oct 25, 2024
drisspg added a commit that referenced this pull request Oct 27, 2024

import torch
from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from torch.utils._triton import has_triton_tma
Contributor:

@drisspg this one checks for the existence of the host-side TMA API (i.e., create_1d_tma_descriptor and create_2d_tma_descriptor). Here you need the device-side TMA API, which seems to have been added in a later PR: triton-lang/triton#4633. I guess we need another helper function checking whether experimental_device_tensormap_create1d and experimental_device_tensormap_create2d are available?

Contributor Author:

will do
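
(For reference, a minimal sketch of what such a helper could look like, mirroring the host-side `has_triton_tma` check; the module path of the device-side tensormap functions is an assumption and may differ across Triton versions.)

```Python
def has_triton_experimental_device_tma() -> bool:
    """Sketch of a device-side TMA availability check (not the actual helper)."""
    try:
        import triton.language as tl
    except ImportError:
        return False
    # Assumption: the device-side tensormap helpers live under tl.extra.cuda;
    # return False if that attribute path does not exist in this Triton build.
    cuda_extra = getattr(getattr(tl, "extra", None), "cuda", None)
    return cuda_extra is not None and all(
        hasattr(cuda_extra, name)
        for name in (
            "experimental_device_tensormap_create1d",
            "experimental_device_tensormap_create2d",
        )
    )
```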

f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
tma_size = 128
Contributor:

This seems unused.

@drisspg (Contributor Author) commented Nov 5, 2024

This has been my main benchmarking script: drisspg/transformer_nuggets#39

I tried this and for some reason on the second iteration I always run into:

  File "/home/drisspg/meta/triton/python/triton/testing.py", line 120, in do_bench
    fn()
  File "/home/drisspg/meta/pytorch/torch/_inductor/autotune_process.py", line 689, in run_with_workspace
    run_method(
  File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 1026, in run
    return launcher(
           ^^^^^^^^^
  File "<string>", line 13, in launcher
  File "/home/drisspg/meta/triton/python/triton/backends/nvidia/driver.py", line 426, in __call__
    self.launch(*args, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: only integer tensors of a single element can be converted to an index

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

If I set dynamic=False, everything works, so clearly some dynamic-shape bug is getting in here.
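
(A minimal sketch of that workaround, assuming the benchmark wraps the matmul closure with torch.compile as in the script above:)

```Python
# Pin shapes at compile time so the autotuned TMA kernels see concrete integer sizes.
fp8_matmul = torch.compile(
    fp8_matmul,
    mode="max-autotune-no-cudagraphs",
    dynamic=False,  # avoids the "only integer tensors ..." launcher failure above
)
```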

@drisspg (Contributor Author) commented Nov 5, 2024

Without TMA store:
without_tma_store.csv
With TMA store:
with_tma_store.csv

@aakhundov (Contributor):
So there seems to be up to a 3% perf improvement from TMA store vs tl.store, most pronounced for smaller Ks (which makes sense, as for larger Ks the store is less frequent).

@aakhundov (Contributor):
Could you also share the baseline CSV (w/o any persistent+TMA)?

@drisspg (Contributor Author) commented Nov 5, 2024

@aakhundov definitely much improved results: https://gist.github.com/drisspg/698ea0a7eac8ae4542350f939f05f9d1

@drisspg requested review from aakhundov and removed the request for aakhundov on November 11, 2024.
@drisspg (Contributor Author) commented Nov 18, 2024

Updated and ran the existing tests, getting:
```
...................................................................unsupported shared memory layout for MMAv3
UNREACHABLE executed at /home/drisspg/meta/triton/python/build/cmake.linux-x86_64-cpython-3.12/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc:400!
Fatal Python error: Aborted

Thread 0x00007f9f8e7fc640 (most recent call first):
File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/compile_tasks.py", line 35 in _reload_python_module
File "/home/drisspg/meta/pytorch/torch/_inductor/codecache.py", line 3084 in load_by_key_path
File "/home/drisspg/meta/pytorch/torch/_inductor/autotune_process.py", line 728 in precompile
File "/home/drisspg/meta/pytorch/torch/_inductor/select_algorithm.py", line 1012 in precompile
File "/home/drisspg/meta/pytorch/torch/_inductor/select_algorithm.py", line 1421 in precompile_with_captured_stdout
File "/home/drisspg/.conda/envs/dev/lib/python3.12/concurrent/futures/thread.py", line 58 in run
File "/home/drisspg/.conda/envs/dev/lib/python3.12/concurrent/futures/thread.py", line 92 in _worker
File "/home/drisspg/.conda/envs/dev/lib/python3.12/threading.py", line 1010 in run
File "/home/drisspg/.conda/envs/dev/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
File "/home/drisspg/.conda/envs/dev/lib/python3.12/threading.py", line 1030 in _bootstrap

I am on Triton commit 057a9d31e16dbc68d954c74fa98910ac4cf1a033. I heard newer Triton main works; I will try to update.

cc @aakhundov

@drisspg (Contributor Author) commented Dec 4, 2024

Closing because this got stale and ghstack yelled at me: #142045

@drisspg closed this on Dec 4, 2024.
@github-actions bot deleted the gh/drisspg/71/head branch on January 4, 2025.