Skip to content

Commit afed3ed

Browse files
committed
Add unit tests for MoE DenseGEMM
Signed-off-by: Zongfei Jing <[email protected]>
1 parent 03c9912 commit afed3ed

4 files changed

Lines changed: 487 additions & 3 deletions

File tree

tests/unittest/_torch/modules/moe/moe_test_utils.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
TRTLLMGenFusedMoE,
4646
)
4747
from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import DeepGemmFusedMoE
48+
from tensorrt_llm._torch.modules.fused_moe.fused_moe_densegemm import DenseGEMMFusedMoE
4849
from tensorrt_llm._torch.modules.fused_moe.interface import MoE
4950
from tensorrt_llm._torch.utils import ActivationType, is_gated_activation
5051
from tensorrt_llm.models.modeling_utils import QuantAlgo
@@ -62,6 +63,7 @@ class MoeBackendType(str, Enum):
6263
TRTLLM = "TRTLLM"
6364
CUTEDSL = "CUTEDSL"
6465
DEEPGEMM = "DEEPGEMM"
66+
DENSEGEMM = "DENSEGEMM"
6567

6668

6769
def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]:
@@ -71,6 +73,7 @@ def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]:
7173
MoeBackendType.TRTLLM: TRTLLMGenFusedMoE,
7274
MoeBackendType.CUTEDSL: CuteDslFusedMoE,
7375
MoeBackendType.DEEPGEMM: DeepGemmFusedMoE,
76+
MoeBackendType.DENSEGEMM: DenseGEMMFusedMoE,
7477
}
7578
return backend_class_map[backend_type]
7679

@@ -589,6 +592,68 @@ def _cdiv(x, y):
589592
return None
590593

591594

595+
def should_skip_densegemm(
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo] = None,
    model_config: Optional["MoeModelConfig"] = None,
) -> Optional[str]:
    """
    Check DenseGEMM backend specific constraints.

    DenseGEMM reshapes all expert weights into a single dense matrix and performs
    a single large GEMM. It is described as supporting only NVFP4 quantization on
    Blackwell (SM 100/103); the GPU architecture is not inspected here.

    Checks performed, in order (the first failing one produces the skip reason):
        1. Backends other than DENSEGEMM are never skipped by this function.
        2. quant_algo must be NVFP4.
        3. hidden_size and intermediate_size must be 128-aligned (NVFP4 requirement).
        4. intermediate_size must be 256-aligned (FC2 MMA tile size along K).
        5. intermediate_size must be below 14336 (accuracy limitation vs the
           per-expert reference).

    Checks 3-5 only run when model_config is provided.

    NOTE(review): the backend is also documented as requiring top_k >= 2
    (fc2_alpha scatter requires multiple expert selections) and
    num_experts > top_k. Neither is enforced by this function - confirm whether
    another skip check covers them.

    Args:
        backend_type: Backend under test.
        quant_algo: Quantization algorithm under test, or None for unquantized.
        model_config: Model dimensions; only hidden_size and intermediate_size
            are read. When None, the size-related checks are not applied.

    Returns:
        Skip reason string if test should be skipped, None otherwise
    """
    if backend_type != MoeBackendType.DENSEGEMM:
        return None

    # DenseGEMM only supports NVFP4
    if quant_algo != QuantAlgo.NVFP4:
        return f"DenseGEMMFusedMoE only supports NVFP4 quantization (got quant_algo={quant_algo})"

    # Without a model config there are no dimensions to validate.
    if model_config is None:
        return None

    hidden_size = model_config.hidden_size
    intermediate_size = model_config.intermediate_size

    # 128-alignment required for NVFP4 dense GEMM kernels
    if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
        return (
            "DenseGEMMFusedMoE NVFP4 requires 128-aligned sizes "
            f"(got h={hidden_size}, i={intermediate_size})"
        )

    # FC2 DenseGEMM kernel tiles K with MMA tile size 256.
    # intermediate_size (= weight_per_expert for FC2) must be 256-aligned
    # so expert boundaries align with MMA tile boundaries.
    _MMA_TILE_K = 256
    if intermediate_size % _MMA_TILE_K != 0:
        return (
            f"DenseGEMMFusedMoE requires intermediate_size to be a multiple "
            f"of {_MMA_TILE_K} (got intermediate_size={intermediate_size}). "
            f"FC2 kernel cannot split alpha_scale at non-aligned expert boundaries."
        )

    # DenseGEMM with very large intermediate_size has accuracy issues vs
    # per-expert reference due to FP4 error accumulation in the large
    # FC2 reduction dimension (expert_count * intermediate_size).
    _ACCURACY_LIMIT_INTERMEDIATE_SIZE = 14336
    if intermediate_size >= _ACCURACY_LIMIT_INTERMEDIATE_SIZE:
        return (
            f"[Design Limitation] DenseGEMMFusedMoE NVFP4 with large "
            f"intermediate_size={intermediate_size} has accuracy issues "
            f"vs per-expert reference due to FP4 error accumulation."
        )

    return None
655+
656+
592657
def should_skip_multi_gpu(
593658
parallel_mode: str,
594659
model_config: "MoeModelConfig",
@@ -748,15 +813,21 @@ def get_quick_skip_reason(
748813
lambda: should_skip_deepgemm(
749814
backend_type, quant_algo=quant_algo, model_config=model_config
750815
),
816+
lambda: should_skip_densegemm(
817+
backend_type, quant_algo=quant_algo, model_config=model_config
818+
),
751819
]
752820
for check in skip_checks:
753821
skip_reason = check()
754822
if skip_reason:
755823
return skip_reason
756824

757-
# DEEPGEMM: float16 reference module constraint
758-
if backend_type == MoeBackendType.DEEPGEMM and dtype == torch.float16:
759-
return "DeepGemmFusedMoE reference module requires bfloat16 input"
825+
# DEEPGEMM/DENSEGEMM: float16 reference module constraint
826+
if (
827+
backend_type in (MoeBackendType.DEEPGEMM, MoeBackendType.DENSEGEMM)
828+
and dtype == torch.float16
829+
):
830+
return f"{backend_type.value} reference module requires bfloat16 input"
760831

761832
# 128-alignment requirement for quantization
762833
if quant_algo is not None:

tests/unittest/_torch/modules/moe/test_moe_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import itertools
3030
import logging
31+
import os
3132
from typing import List, Optional
3233

3334
import pytest
@@ -220,6 +221,7 @@ def run_backend_moe(
220221
MoeBackendType.TRTLLM,
221222
MoeBackendType.CUTEDSL,
222223
MoeBackendType.DEEPGEMM,
224+
MoeBackendType.DENSEGEMM,
223225
]
224226

225227
# Data types to test
@@ -466,6 +468,10 @@ def test_moe_backend(
466468
3. Different sequence lengths use appropriate tactics
467469
4. swiglu_gptoss_style (SwiGlu with custom parameters) works correctly
468470
"""
471+
# DENSEGEMM: disable fused fc2_alpha path for backend-level testing.
472+
if backend_type == MoeBackendType.DENSEGEMM:
473+
os.environ["TRTLLM_MOE_FUSED_FC2_ALPHA"] = "0"
474+
469475
is_gated = is_gated_activation(activation_type)
470476
swiglu_gptoss_style = False
471477
if is_gated:

tests/unittest/_torch/modules/moe/test_moe_module.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ def init_worker(custom_paths, comm_method_type):
706706
MoeBackendType.TRTLLM,
707707
MoeBackendType.CUTEDSL,
708708
MoeBackendType.DEEPGEMM,
709+
MoeBackendType.DENSEGEMM,
709710
]
710711

711712
# Data types to test
@@ -1055,6 +1056,10 @@ def test_configurable_moe_single_gpu(
10551056
3. Autotune captures and replays all tactics properly
10561057
4. swiglu_gptoss_style (SwiGLU with custom parameters) works correctly
10571058
"""
1059+
# DENSEGEMM: disable fused fc2_alpha path for testing against per-expert reference.
1060+
if moe_backend == MoeBackendType.DENSEGEMM.value:
1061+
os.environ["TRTLLM_MOE_FUSED_FC2_ALPHA"] = "0"
1062+
10581063
swiglu_gptoss_style = swiglu_alpha != 1 or swiglu_beta != 0 or swiglu_limit != float("inf")
10591064
ci_skip = should_skip_to_accelerate_ci(
10601065
backend_type=MoeBackendType(moe_backend),

0 commit comments

Comments
 (0)