Add Qwen3MoeForCausalLM support#383

Closed
Mandark-droid wants to merge 2 commits into cactus-compute:main from Mandark-droid:qwen3-moe

Conversation

@Mandark-droid Mandark-droid commented Feb 21, 2026

Add Qwen3MoeForCausalLM support

Adds native Qwen3 MoE (Mixture of Experts) support to the Cactus v1.x graph engine. Uses the generalized moe_layer() graph operation from #374, combining QwenModel's decoder (GQA, QK norm, RoPE, INT8 KV cache) with SwiGLU (3-weight) expert FFN routing.

What

C++ model (model_qwen_moe.cpp, model.h)

  • Qwen3MoeModel class using gb->moe_layer() — matches the LFM2MoE pattern
  • ExpertWeights struct (w1/w3/w2) consistent with upstream convention
  • QK norm (per-head RMSNorm on Q, K before RoPE)
  • INT8 KV cache via attention_int8_hybrid (standard path)
  • MoE detection via config fields (num_experts > 0) — no new enum needed
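The per-head QK norm deserves a quick illustration: each head's Q and K vectors are RMS-normalized independently over the head dimension, before RoPE is applied. A minimal NumPy sketch (function names are hypothetical, not the actual kernel code):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm over the last axis: x / sqrt(mean(x^2) + eps) * weight
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight

def qk_norm(q, k, q_weight, k_weight):
    # q, k: (num_heads, head_dim). Normalization is per head: each head's
    # head_dim vector is normalized independently, before RoPE is applied.
    return rms_norm(q, q_weight), rms_norm(k, k_weight)
```

Because the norm runs over the last axis only, heads never mix — which is what "per-head" means here.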

Engine integration (engine_model.cpp)

  • model_type=qwen3_moe is parsed as ModelType::QWEN; the factory branches on num_experts > 0 to create Qwen3MoeModel
  • Config parsing for moe_intermediate_size, norm_topk_prob, num_experts_per_tok
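The "no new enum" detection is config-driven. A minimal sketch of the idea (field names are the ones listed above; the function name and defaults are hypothetical, not the engine's actual parser):

```python
def parse_moe_config(raw):
    # Hypothetical sketch: the MoE-relevant fields, with sane defaults so a
    # dense Qwen3 config (no num_experts key) parses as non-MoE.
    cfg = {
        "num_experts": int(raw.get("num_experts", 0)),
        "num_experts_per_tok": int(raw.get("num_experts_per_tok", 0)),
        "moe_intermediate_size": int(raw.get("moe_intermediate_size", 0)),
        "norm_topk_prob": bool(raw.get("norm_topk_prob", False)),
    }
    # No new enum: MoE is purely a property of the config.
    cfg["is_moe"] = cfg["num_experts"] > 0
    return cfg
```

The same ModelType::QWEN value covers both dense and MoE variants; only the presence of expert fields changes which model class the factory builds.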

Python conversion (config_utils.py, converter.py, weight_patterns.py, tensor_io.py)

  • Detect Qwen3MoeForCausalLM / qwen3_moe model type
  • Extract num_experts_per_tok, moe_intermediate_size, norm_topk_prob
  • Per-expert SwiGLU weight iteration (supports both individual and fused tensor formats)
  • FP16 auto-promotion for MoE router weights
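The dual-format expert iteration could look like the sketch below. Tensor names are modeled on Hugging Face Qwen3-MoE checkpoints (per-expert gate_proj/up_proj/down_proj); the fused-tensor name and its stacked layout are assumptions for illustration, not the converter's actual code:

```python
import numpy as np

def iter_expert_weights(state_dict, layer, num_experts):
    """Yield (expert_idx, w1, w3, w2) for one MoE layer.

    Supports two checkpoint layouts: individual per-expert tensors, or a
    hypothetical fused tensor with experts stacked on axis 0 and gate/up
    concatenated on axis 1.
    """
    prefix = f"model.layers.{layer}.mlp.experts"
    fused = f"{prefix}.gate_up_proj"  # hypothetical fused-tensor name
    if fused in state_dict:
        gate_up = state_dict[fused]            # (num_experts, 2*inter, hidden)
        down = state_dict[f"{prefix}.down_proj"]
        inter = gate_up.shape[1] // 2
        for e in range(num_experts):
            yield e, gate_up[e, :inter], gate_up[e, inter:], down[e]
    else:
        for e in range(num_experts):
            yield (e,
                   state_dict[f"{prefix}.{e}.gate_proj.weight"],   # w1
                   state_dict[f"{prefix}.{e}.up_proj.weight"],     # w3
                   state_dict[f"{prefix}.{e}.down_proj.weight"])   # w2
```

Yielding a uniform (w1, w3, w2) tuple lets the rest of the converter stay format-agnostic.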

Sampling fixes (kernel_nn.cpp, kernel.h, graph_core.cpp, graph.h)

  • Move token_history to file scope with clear_sample_history() to prevent cross-model contamination
  • Fix greedy sampling (temperature=0) to use pure argmax
  • FP32 softmax support for MoE router weight precision

Architecture support

Parameterized for the full Qwen3 MoE family:

| Parameter | Supported range |
|---|---|
| Expert count | 16–32 |
| Top-k routing | 2–8 |
| Hidden size | 512–2048 |
| Head dim | 64–128 |
| RoPE theta | 1M–10M |
| Expert FFN | SwiGLU (silu(gate) ⊙ up → down) |
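The expert FFN row is the standard SwiGLU computation. A NumPy sketch using the w1/w3/w2 naming from the ExpertWeights struct (a model of the math, not the engine's kernel):

```python
import numpy as np

def silu(x):
    # SiLU (a.k.a. swish): x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_expert(x, w1, w3, w2):
    # One expert's FFN: w1 = gate proj, w3 = up proj, w2 = down proj.
    # down( silu(gate(x)) * up(x) )
    return (silu(x @ w1.T) * (x @ w3.T)) @ w2.T
```

With hidden size h and expert intermediate size i, w1 and w3 are (i, h) and w2 is (h, i), so the expert maps (seq, h) back to (seq, h).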

Benchmarks

Loggenix-MoE-0.62B (16 experts, top-2, 512 hidden, 12 layers) on Android ARM64 (Pixel 7a, Tensor G2):

| Test | Tokens | TTFT | Decode | TPS |
|---|---|---|---|---|
| Greedy (temp=0) | 32 | 77 ms | 23.8 ms/tok | 42.1 |
| Chat (2+2) | 64 | 169 ms | 26.3 ms/tok | 38.1 |
| 128-token generation | 128 | 105 ms | 27.3 ms/tok | 36.6 |
| Multi-turn | 5 | 248 ms | 26.1 ms/tok | 38.4 |
| 50-token stability | 50 | 86 ms | 26.7 ms/tok | 37.4 |

Model init: ~352ms. All 5 tests passing.

Zero new graph ops

All operations already exist: TOPK, SOFTMAX, MATMUL, SILU, MULTIPLY, RMSNORM, ROPE, INDEX, plus the generalized moe_layer from #374.
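A sketch of how those existing ops compose into the router path (a hypothetical Python model; representing `experts` as callables is an illustration, not the graph API):

```python
import numpy as np

def moe_route(hidden, router_w, experts, top_k, norm_topk_prob=True):
    # MATMUL: router logits, kept in FP32 for precision (the FP32 softmax fix)
    logits = (hidden @ router_w.T).astype(np.float32)
    # SOFTMAX over all experts
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # TOPK: select the top_k expert indices
    idx = np.argsort(probs)[-top_k:]
    w = probs[idx]
    if norm_topk_prob:
        w = w / w.sum()  # renormalize the selected weights to sum to 1
    # INDEX + weighted sum of the selected experts' outputs
    out = np.zeros_like(hidden, dtype=np.float32)
    for weight, e in zip(w, idx):
        out += weight * experts[int(e)](hidden)
    return out
```

This is the routing math moe_layer composes; no op in it is new to the graph engine.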

@Mandark-droid force-pushed the qwen3-moe branch 3 times, most recently from dd57ef9 to 509cca1 on February 21, 2026 at 13:20
@Mandark-droid (Author) commented:

Hey @HenryNdubuaku @rshemet — this PR adds native Qwen3 MoE (Qwen3MoeForCausalLM) support to Cactus. It's been rebased on top of #374 (Karen's LFM2 MoE) so both can land cleanly. Tested end-to-end on Android ARM64 with the Loggenix 0.6B model — prefill and decode produce coherent output. Would appreciate a review when you get a chance!

@HenryNdubuaku (Collaborator) commented:

thanks so much for this @Mandark-droid — one thing: you often want to wait for pending PRs to be merged, else you might build on faulty code. For instance, I refactored that PR to have a generalised moe_layer and merged it, so you now have to work with that rather than recreating everything :(

Mandark-droid added a commit to Mandark-droid/cactus that referenced this pull request Feb 23, 2026
Addresses review feedback on PR cactus-compute#383: refactors the Qwen3MoeForCausalLM
implementation to use the generalized moe_layer graph operation introduced
in PR cactus-compute#374, instead of a custom per-expert loop.

Key changes:
- build_mlp uses gb->moe_layer() matching the LFM2MoEModel pattern
- WeightNodeIDs uses ExpertWeights struct (w1/w3/w2) consistent with upstream
- Weight file naming follows upstream convention (moe_expert_ prefix)
- MoE detection via config fields (num_experts > 0) like LFM2, no new enum
- Attention uses INT8 KV cache with attention_int8_hybrid (standard path)
- QK normalization per-head before RoPE (Qwen3 architecture requirement)

Bug fixes included:
- Fix greedy sampling (temperature=0) to use pure argmax regardless of top_p/top_k
- Move token_history to file scope with clear_sample_history() to prevent
  cross-model sampling contamination
- Add FP32 softmax support for MoE router weight precision
- Fix config parsing to strip \r\n from values

Python conversion:
- Detect Qwen3MoeForCausalLM model type
- Per-expert SwiGLU weight extraction (fused and individual tensor formats)
- FP16 auto-promotion for MoE router weights

Signed-off-by: Mandark-droid <[email protected]>

@Mandark-droid (Author) commented:

Thanks for the feedback @HenryNdubuaku — good lesson learned on building on pending PRs. I've rebased on main and refactored to use the generalized moe_layer() from #374. The Qwen3 MoE build_mlp now calls gb->moe_layer() directly, matching the LFM2MoE pattern instead of a custom per-expert loop.

Validated on Android ARM64 (Pixel 7a, Tensor G2) with the Loggenix 0.6B model — all tests passing:

| Test | Tokens | TTFT | Decode | TPS |
|---|---|---|---|---|
| Greedy (temp=0) | 32 | 77 ms | 23.8 ms/tok | 42.1 |
| Chat (2+2) | 64 | 169 ms | 26.3 ms/tok | 38.1 |
| 128-token generation | 128 | 105 ms | 27.3 ms/tok | 36.6 |
| Multi-turn | 5 | 248 ms | 26.1 ms/tok | 38.4 |
| 50-token stability | 50 | 86 ms | 26.7 ms/tok | 37.4 |

Model init: ~352ms. Ready for another look when you get a chance!

@Mandark-droid (Author) commented:

Hi @HenryNdubuaku , thanks again for the feedback on PR #383. I wanted to follow up since it was closed rather than merged — was there a specific reason for that? I'm wondering if the Qwen3 MoE changes and Loggenix model support ended up in main through a different route, or if there's something I should adjust to get it across. Let me know what makes sense for next steps!
