
[SGLang-Diffusion LLM] Add inference support for d3LLM models (arXiv:2601.07568)#20615

Open
flowermouse wants to merge 2 commits into sgl-project:main from flowermouse:feat/dllm-llada-dream-support

Conversation

@flowermouse

Summary

This PR adds SGLang serving support for d3LLM (arXiv:2601.07568), an ultra-fast diffusion language model based on pseudo-trajectory distillation. d3LLM achieves significantly higher tokens-per-forward (TPF) than vanilla diffusion LLMs while maintaining competitive accuracy, enabling up to 3×-5× end-to-end speedup over autoregressive baselines on H800 and B200.

Two models are supported:

  • d3LLM-LLaDA (8B) — an ultra-fast diffusion LLM, distilled from LLaDA, using full-sequence bidirectional attention
  • d3LLM-Dream (7B) — an ultra-fast diffusion LLM, distilled from Dream, using full-sequence bidirectional attention

Both models require full-sequence bidirectional attention (unlike the block-causal diffusion used by the existing LLaDA 2.0/2.1 support), which requires a new dLLM decoding method in SGLang.
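To make the attention difference concrete, here is a toy sketch of the two mask shapes (plain boolean lists standing in for attention masks; purely illustrative, not SGLang's actual mask code):

```python
def causal_mask(n):
    """Token i attends only to positions <= i (autoregressive / block-causal)."""
    return [[j <= i for j in range(n)] for i in range(n)]

def bidirectional_mask(n):
    """Every token attends to every position: the full-sequence bidirectional
    attention that d3LLM models require, where masked positions are visible too."""
    return [[True] * n for _ in range(n)]

print(sum(sum(row) for row in causal_mask(4)))         # 10 attended pairs
print(sum(sum(row) for row in bidirectional_mask(4)))  # 16 attended pairs
```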

Key Changes

New Files:

  • models/d3llm_llada.py, models/dream.py: Model implementations for d3LLM-LLaDA and d3LLM-Dream
  • dllm/algorithm/entropy_threshold.py: EntropyThreshold decoding algorithm
  • dllm/algorithm/full_attn_multi_block.py: FullAttnMultiBlock decoding algorithm for d3LLM multi-block parallel decoding with bidirectional attention
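The PR body does not reproduce the algorithm internals, but the core idea of entropy-threshold acceptance can be sketched as follows (function names are hypothetical, not the ones in `entropy_threshold.py`): positions whose predicted distribution has low entropy are confidently decoded in the current denoising step.

```python
import math

def entropy(probs):
    """Shannon entropy (natural log) of one token's predicted distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def accept_by_entropy(step_probs, threshold):
    """Unmask every position whose distribution entropy is below `threshold`;
    low entropy means the model is confident about that token. Returns the
    indices accepted in this denoising step."""
    return [i for i, probs in enumerate(step_probs) if entropy(probs) < threshold]

# Two masked positions: one confident, one maximally uncertain.
confident = [0.97, 0.01, 0.01, 0.01]
uncertain = [0.25, 0.25, 0.25, 0.25]
print(accept_by_entropy([confident, uncertain], threshold=0.5))  # [0]
```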

Modified Files:

  • dllm/config.py: Add needs_full_prefill and pad_full_generation flags to DllmConfig
  • dllm/mixin/req.py, dllm/mixin/scheduler.py: Handle full-prefill mode in request lifecycle
  • flashinfer_backend.py: Zero out prefix_lens when needs_full_prefill is enabled
  • schedule_batch.py: Skip tree-cache matching for full-prefill models; free old KV slots before re-extend
  • schedule_policy.py: Adjust token budget and truncation for full-prefill mode
  • forward_batch_info.py: Build positions from full seq_len for bidirectional attention
  • cuda_graph_runner.py: Disable CUDA graph for variable-length full-prefill inputs; work around Blackwell (SM≥10) multi-BS capture instability
  • http_server.py: Increase max_new_tokens for dLLM health checks to ensure proper warmup
  • radix_cache.py: Add None guard for node in inc_lock_ref / dec_lock_ref

Tests & Docs:

  • test/registered/dllm/test_dllm_gsm8k.py: GSM8K benchmark test for d3LLM models
  • docs/supported_models/text_generation/diffusion_language_models.md: Updated documentation

Benchmark Results

Dataset: GSM8K-CoT (zero-shot)
Decoding: FullAttnMultiBlock
TP Size: 1

| Model | Threshold | Batch Size | B200 TPS | H800 TPS | A800 TPS | TPF | Accuracy |
|---|---|---|---|---|---|---|---|
| d3LLM-LLaDA (8B dense) | 0.5 | 1 | 1240.99 | 545.31 | 251.61 | 9.91 | 75.36% |
| d3LLM-LLaDA (8B dense) | 0.5 | 4 | 1310.18 | 551.87 | 249.98 | 8.56 | 75.12% |
| d3LLM-Dream (7B dense) | 0.4 | 1 | 586.77 | 280.48 | 125.57 | 4.89 | 80.89% |
| d3LLM-Dream (7B dense) | 0.4 | 4 | 676.81 | 281.82 | 127.85 | 4.22 | 80.76% |

TPS = Tokens Per Second; TPF = Tokens Per Forward (average number of tokens decoded per forward pass)
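For reference, the two metrics relate as follows (a toy computation with made-up counts, not numbers from the table above):

```python
def throughput_metrics(tokens_generated, wall_seconds, forward_passes):
    """TPS measures wall-clock throughput; TPF measures how many tokens each
    forward pass decodes. Parallel diffusion decoding pushes TPF well above
    the autoregressive baseline of 1 token per forward."""
    return {
        "tps": tokens_generated / wall_seconds,
        "tpf": tokens_generated / forward_passes,
    }

m = throughput_metrics(tokens_generated=500, wall_seconds=2.0, forward_passes=125)
print(m)  # {'tps': 250.0, 'tpf': 4.0}
```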

Usage Example

```bash
# Launch server with d3LLM-LLaDA
python -m sglang.launch_server \
    --model d3LLM/d3LLM_LLaDA \
    --trust-remote-code \
    --attention-backend flashinfer \
    --dllm-algorithm FullAttnMultiBlock \
    --mem-fraction-static 0.8 \
    --cuda-graph-max-bs 32

# Launch server with d3LLM-Dream
python -m sglang.launch_server \
    --model d3LLM/d3LLM_Dream \
    --trust-remote-code \
    --attention-backend flashinfer \
    --dllm-algorithm FullAttnMultiBlock \
    --mem-fraction-static 0.8 \
    --cuda-graph-max-bs 32
```
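Once a server is up, requests can be sent to SGLang's native `/generate` endpoint. A minimal client sketch (the prompt text and sampling parameters are illustrative; the default port 30000 is assumed):

```python
import json

# Payload for SGLang's /generate endpoint; adjust for your workload.
payload = {
    "text": "Question: A robe takes 2 bolts of blue fiber and half that much "
            "white fiber. How many bolts in total does it take? Answer:",
    "sampling_params": {"max_new_tokens": 256, "temperature": 0.0},
}
body = json.dumps(payload)

# Send with, e.g.:
#   requests.post("http://localhost:30000/generate", json=payload)
print(json.loads(body)["sampling_params"])  # {'max_new_tokens': 256, 'temperature': 0.0}
```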

Test Plan

  • Verified accuracy matches HuggingFace reference implementation
  • Tested on B200, H800, A800 GPUs
  • Added test_dllm_gsm8k.py for CI integration
  • Confirmed no regression on existing LLaDA 2.x models

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request integrates d3LLM models into SGLang, significantly enhancing its capability to serve ultra-fast diffusion language models. The core objective is to enable efficient inference for models that utilize full-sequence bidirectional attention, which differs from traditional autoregressive or block-causal diffusion LLMs. This required a fundamental re-evaluation and adjustment of how requests are processed, how KV caches are managed, and how CUDA graphs are utilized, ensuring optimal performance and compatibility with these novel architectures.

Highlights

  • New Model Support: Added inference support for d3LLM models, specifically d3LLM-LLaDA (8B) and d3LLM-Dream (7B), which are ultra-fast diffusion language models based on pseudo-trajectory distillation.
  • New Decoding Algorithm: Introduced the FullAttnMultiBlock decoding algorithm to support the full-sequence bidirectional attention required by d3LLM models, enabling multi-block parallel decoding.
  • Core Infrastructure Changes: Implemented new configuration flags (needs_full_prefill, pad_full_generation) and adjusted request scheduling, KV cache management, and CUDA graph handling to optimize performance for bidirectional, full-prefill models.
  • Performance Benchmarks: Included benchmark results on the GSM8K-CoT dataset, demonstrating significant Tokens Per Second (TPS) and competitive accuracy for d3LLM-LLaDA and d3LLM-Dream on various GPUs.
  • Expanded Documentation and Testing: Updated documentation to reflect the new models and decoding algorithm, and added comprehensive GSM8K benchmark tests for CI integration.


Changelog
  • docs/supported_models/text_generation/diffusion_language_models.md
    • Updated supported DLLM algorithms to include FullAttnMultiBlock.
    • Added configuration details for FullAttnMultiBlock.
    • Expanded the supported models table to include d3LLM-LLaDA and d3LLM-Dream.
  • python/sglang/srt/dllm/algorithm/entropy_threshold.py
    • Added a new EntropyThreshold algorithm for dLLM decoding.
  • python/sglang/srt/dllm/algorithm/full_attn_multi_block.py
    • Added a new FullAttnMultiBlock algorithm for multi-block parallel decoding with full attention and entropy thresholding.
  • python/sglang/srt/dllm/config.py
    • Added needs_full_prefill and pad_full_generation flags to DllmConfig.
    • Updated from_server_args to detect DreamModel and LLaDAModelLM architectures and set needs_full_prefill accordingly.
    • Determined pad_full_generation based on needs_full_prefill and the FullAttnMultiBlock algorithm.
  • python/sglang/srt/dllm/mixin/req.py
    • Added dllm_ids for full-attention bidirectional models.
    • Modified _init_fill_ids_for_dllm to handle full-attention models by padding all max_new_tokens masks or using dllm_ids.
  • python/sglang/srt/dllm/mixin/scheduler.py
    • Modified process_batch_result_dllm to sync output_ids from fill_ids for needs_full_prefill dLLMs.
    • Prevented insertion into radix cache for needs_full_prefill dLLMs.
  • python/sglang/srt/entrypoints/http_server.py
    • Increased max_new_tokens for dLLM health checks and server warmup to ensure proper decoding.
  • python/sglang/srt/layers/attention/flashinfer_backend.py
    • Modified init_forward_metadata_capture_cuda_graph and init_forward_metadata_replay_cuda_graph to set prefix_lens to zero for bidirectional models requiring full prefill.
  • python/sglang/srt/managers/schedule_batch.py
    • Modified init_next_round_input to skip tree cache matching and clear prefix_indices for bidirectional dLLMs.
    • Added logic in prepare_for_extend to free old KV slots for needs_full_prefill models before reallocating.
    • Added dllm_origin_input_lens to ModelWorkerBatch.
  • python/sglang/srt/managers/schedule_policy.py
    • Adjusted rem_dllm_tokens calculation for bidirectional models to use a larger token budget.
    • Modified _get_dllm_remain_tokens and _add_dllm_req to handle full sequence processing for bidirectional models.
    • Added None guard for req.last_node in _req_inc_lock_ref and _lock_node.
  • python/sglang/srt/managers/scheduler.py
    • Added logic to skip cache_unfinished_req for needs_full_prefill dLLMs.
  • python/sglang/srt/mem_cache/radix_cache.py
    • Added node is None guard to inc_lock_ref and dec_lock_ref methods.
  • python/sglang/srt/model_executor/cuda_graph_runner.py
    • Added a workaround for Blackwell GPUs to capture only the maximum batch size for dLLM/speculative modes.
    • Added _disable_dllm_cuda_graph flag.
    • Modified can_run to check for is_dllm_supported for variable-length inputs in full-prefill dLLMs.
  • python/sglang/srt/model_executor/forward_batch_info.py
    • Added dllm_origin_input_lens to ForwardBatch.
    • Modified init_new to build positions covering the full sequence for bidirectional models.
  • python/sglang/srt/models/d3llm_llada.py
    • Added a new model class LLaDAModelLM for d3LLM-LLaDA, extending LlamaForCausalLM.
    • Patched LLaDA config fields, set attention type to ENCODER_ONLY, and implemented weight remapping.
  • python/sglang/srt/models/dream.py
    • Added a new model class DreamModel for d3LLM-Dream, extending Qwen2ForCausalLM.
    • Set attention type to ENCODER_ONLY and implemented logic to right-shift hidden states for correct logits alignment.
  • test/registered/dllm/test_dllm_gsm8k.py
    • Added new tests for d3LLM-LLaDA and d3LLM-Dream models, including GSM8K accuracy and batch size 1 speed benchmarks.
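The right-shift mentioned for `models/dream.py` can be illustrated with a toy sketch (plain lists standing in for hidden-state tensors; the helper name is hypothetical, not the code in the PR):

```python
def right_shift_hidden(hidden, pad):
    """Shift hidden states right by one position so that the logits computed
    at position i come from the hidden state at position i-1, aligning each
    token's prediction with the preceding position. `pad` fills position 0."""
    return [pad] + hidden[:-1]

h = [[1.0], [2.0], [3.0]]
print(right_shift_hidden(h, pad=[0.0]))  # [[0.0], [1.0], [2.0]]
```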


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This is a significant feature addition: support for d3LLM models, a family of fast diffusion language models. The changes span model implementations, new decoding algorithms, and modifications to the serving core to support bidirectional-attention models. The code changes look solid and are accompanied by new tests. My main feedback is on improving the documentation for the new FullAttnMultiBlock algorithm so that the configuration keys and their descriptions are consistent with the implementation, which is important for users.

Comment on lines +56 to +70
```yaml
# Confidence threshold for accepting predicted tokens
# Range: 0.0 - 1.0
threshold: 0.5
# Additional threshold increment per decoding step
block_add: 0.1
# Threshold for considering a token as "decoded"
decoded_thresh: 0.95
# Sub-block size for parallel decoding
sub_block_size: 32
# Number of iterations to delay before caching
cache_delay_iter: 2
# Interval for refreshing the attention cache
refresh_interval: 10000
```

Severity: medium

The configuration keys and descriptions for FullAttnMultiBlock in this documentation seem to be inconsistent with the implementation in python/sglang/srt/dllm/algorithm/full_attn_multi_block.py.

Specifically:

  • The key block_add in the YAML should likely be block_add_threshold. Its description "Additional threshold increment per decoding step" is also misleading. Based on the code, it's the "previous block progress threshold to add the next block".
  • The key decoded_thresh should likely be decoded_token_threshold. Its description "Threshold for considering a token as 'decoded'" is also not quite accurate. It's the "previous block progress threshold for full activation".
  • The key sub_block_size is documented here but does not appear to be used in the FullAttnMultiBlock algorithm implementation.

Could you please update the documentation to match the implementation for clarity and correctness? This will help users configure the algorithm correctly.

Suggested change
```yaml
# Confidence threshold for accepting predicted tokens
# Range: 0.0 - 1.0
threshold: 0.5
# Additional threshold increment per decoding step
block_add: 0.1
# Threshold for considering a token as "decoded"
decoded_thresh: 0.95
# Sub-block size for parallel decoding
sub_block_size: 32
# Number of iterations to delay before caching
cache_delay_iter: 2
# Interval for refreshing the attention cache
refresh_interval: 10000
```
```yaml
# Confidence threshold for accepting predicted tokens
# Range: 0.0 - 1.0
threshold: 0.5
# Previous block progress threshold to add the next block
# Range: 0.0 - 1.0
block_add_threshold: 0.1
# Previous block progress threshold for a block to be considered fully active
# Range: 0.0 - 1.0
decoded_token_threshold: 0.95
# Number of iterations to delay before caching
cache_delay_iter: 2
# Interval for refreshing the attention cache
refresh_interval: 10000
```

@flowermouse
Author

/tag-and-rerun-ci


Labels

documentation: Improvements or additions to documentation
