Skip to content

Commit 66e0d48

Browse files
add fp8 sm120
1 parent d935dbd commit 66e0d48

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

qutlass/csrc/gemm.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,18 @@ void matmul_host_mxf8_bf16_tn(torch::Tensor& D,
378378
ElementB, LayoutBTag, AlignmentB>::Gemm, cutlass::float_ue8m0_t
379379
>(D, A, B, A_sf, B_sf, alpha, m, n, k, A.device());
380380
}
381+
#elif TARGET_CUDA_ARCH == 120
382+
using ArchTag = cutlass::arch::Sm120;
383+
384+
using MmaTileShape = Shape<_128,_128,_128>;
385+
using ClusterShape = Shape<_1,_1,_1>;
386+
using PerSmTileShape_MNK = Shape<_128,_128,_128>;
387+
388+
runGemm<FpGemm<MmaTileShape, ClusterShape, PerSmTileShape_MNK,
389+
ArchTag,
390+
ElementA, LayoutATag, AlignmentA,
391+
ElementB, LayoutBTag, AlignmentB>::Gemm, cutlass::float_ue8m0_t
392+
>(D, A, B, A_sf, B_sf, alpha, m, n, k, A.device());
381393
#else
382394
TORCH_CHECK(false, "Unsupported CUDA arch");
383395
#endif

0 commit comments

Comments
 (0)