4545 TRTLLMGenFusedMoE ,
4646)
4747from tensorrt_llm ._torch .modules .fused_moe .fused_moe_deepgemm import DeepGemmFusedMoE
48+ from tensorrt_llm ._torch .modules .fused_moe .fused_moe_densegemm import DenseGEMMFusedMoE
4849from tensorrt_llm ._torch .modules .fused_moe .interface import MoE
4950from tensorrt_llm ._torch .utils import ActivationType , is_gated_activation
5051from tensorrt_llm .models .modeling_utils import QuantAlgo
@@ -62,6 +63,7 @@ class MoeBackendType(str, Enum):
6263 TRTLLM = "TRTLLM"
6364 CUTEDSL = "CUTEDSL"
6465 DEEPGEMM = "DEEPGEMM"
66+ DENSEGEMM = "DENSEGEMM"
6567
6668
6769def get_backend_class (backend_type : MoeBackendType ) -> Type [MoE ]:
@@ -71,6 +73,7 @@ def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]:
7173 MoeBackendType .TRTLLM : TRTLLMGenFusedMoE ,
7274 MoeBackendType .CUTEDSL : CuteDslFusedMoE ,
7375 MoeBackendType .DEEPGEMM : DeepGemmFusedMoE ,
76+ MoeBackendType .DENSEGEMM : DenseGEMMFusedMoE ,
7477 }
7578 return backend_class_map [backend_type ]
7679
@@ -589,6 +592,68 @@ def _cdiv(x, y):
589592 return None
590593
591594
def should_skip_densegemm(
    backend_type: MoeBackendType,
    quant_algo: Optional[QuantAlgo] = None,
    model_config: Optional["MoeModelConfig"] = None,
) -> Optional[str]:
    """
    Check DenseGEMM backend specific constraints.

    DenseGEMM reshapes all expert weights into a single dense matrix and performs
    a single large GEMM. It only supports NVFP4 quantization on Blackwell (SM 100/103).

    Checks performed here (in order; the first failing one produces the reason):
        1. quant_algo must be NVFP4.
        2. hidden_size and intermediate_size must be 128-aligned (NVFP4 requirement).
        3. intermediate_size must be a multiple of 256 (FC2 MMA tile size along K).
        4. intermediate_size must be < 14336 (larger sizes are skipped because of
           FP4 error accumulation vs. the per-expert reference).

    Checks 2-4 are only evaluated when ``model_config`` is provided.

    Not checked here: the GPU architecture, and any top_k / num_experts
    constraints of the backend (e.g. top_k >= 2, num_experts > top_k). Callers
    that need those must validate them separately.

    Args:
        backend_type: Backend under test; anything other than DENSEGEMM is
            never skipped by this function.
        quant_algo: Quantization algorithm of the test case.
        model_config: Optional model configuration; only its ``hidden_size``
            and ``intermediate_size`` attributes are read.

    Returns:
        Skip reason string if test should be skipped, None otherwise
    """
    # Alignment required by the NVFP4 dense GEMM kernels.
    nvfp4_alignment = 128
    # The FC2 DenseGEMM kernel tiles K with this MMA tile size. intermediate_size
    # (= weight_per_expert for FC2) must be a multiple of it so that expert
    # boundaries coincide with MMA tile boundaries.
    mma_tile_k = 256
    # From this intermediate_size upwards the FC2 reduction dimension
    # (expert_count * intermediate_size) accumulates too much FP4 error to be
    # compared against the per-expert reference.
    max_intermediate_size = 14336

    if backend_type != MoeBackendType.DENSEGEMM:
        return None

    # DenseGEMM only supports NVFP4
    if quant_algo != QuantAlgo.NVFP4:
        return f"DenseGEMMFusedMoE only supports NVFP4 quantization (got quant_algo={quant_algo})"

    # Without a model config there are no shape constraints to evaluate.
    if model_config is None:
        return None

    hidden_size = model_config.hidden_size
    intermediate_size = model_config.intermediate_size

    if hidden_size % nvfp4_alignment != 0 or intermediate_size % nvfp4_alignment != 0:
        return (
            "DenseGEMMFusedMoE NVFP4 requires 128-aligned sizes "
            f"(got h={hidden_size}, i={intermediate_size})"
        )

    if intermediate_size % mma_tile_k != 0:
        return (
            "DenseGEMMFusedMoE requires intermediate_size to be a multiple "
            f"of {mma_tile_k} (got intermediate_size={intermediate_size}). "
            "FC2 kernel cannot split alpha_scale at non-aligned expert boundaries."
        )

    if intermediate_size >= max_intermediate_size:
        return (
            "[Design Limitation] DenseGEMMFusedMoE NVFP4 with large "
            f"intermediate_size={intermediate_size} has accuracy issues "
            "vs per-expert reference due to FP4 error accumulation."
        )

    return None
655+
656+
592657def should_skip_multi_gpu (
593658 parallel_mode : str ,
594659 model_config : "MoeModelConfig" ,
@@ -748,15 +813,21 @@ def get_quick_skip_reason(
748813 lambda : should_skip_deepgemm (
749814 backend_type , quant_algo = quant_algo , model_config = model_config
750815 ),
816+ lambda : should_skip_densegemm (
817+ backend_type , quant_algo = quant_algo , model_config = model_config
818+ ),
751819 ]
752820 for check in skip_checks :
753821 skip_reason = check ()
754822 if skip_reason :
755823 return skip_reason
756824
757- # DEEPGEMM: float16 reference module constraint
758- if backend_type == MoeBackendType .DEEPGEMM and dtype == torch .float16 :
759- return "DeepGemmFusedMoE reference module requires bfloat16 input"
825+ # DEEPGEMM/DENSEGEMM: float16 reference module constraint
826+ if (
827+ backend_type in (MoeBackendType .DEEPGEMM , MoeBackendType .DENSEGEMM )
828+ and dtype == torch .float16
829+ ):
830+ return f"{ backend_type .value } reference module requires bfloat16 input"
760831
761832 # 128-alignment requirement for quantization
762833 if quant_algo is not None :
0 commit comments