[TRTLLM-12949][refactor] visual_gen: unify fused QK-norm+rope dispatch#14529
Conversation
…e cross_head op
Goal: collapse the Python-side fused QK-norm + RoPE dispatch in visual_gen
to a single op call with no per-model branching, and remove the orphaned
fused_dit_cross_head_qk_norm_rope op once the full-dim template covers
its envelope.
Kernel changes
- fusedDiTQKNormRopeKernel.cu / .h:
• Full-dim template: lift num_heads cap 32 -> 64 (MAX_N = 64 * HEAD_DIM)
so WAN-14B (40 heads * 128) fits. SMEM budget stays well within B200's
227 KB dynamic SMEM cap.
• Per-head template: add cos_seq_per_batch runtime parameter for
kernel-side cos broadcast over B (matching what the full-dim template
already did). Python dispatcher no longer needs host-side .repeat(B, 1)
for FLUX dual-stream CFG paths.
• Delete fusedDiTCrossHeadQKNormRopeKernel + launchFusedDiTCrossHeadQKNormRope
(kernel + launcher + .h declaration).
- fusedDiTSplitNormKernel.cu + fusedDiTSplitQKNormRopeKernel.cu:
• Lift num_heads cap 32 -> 64 (these caps were SMEM-budget driven, not
block-size driven as the misleading comments suggested -- comments
corrected).
C++ op surface
- fusedDiTQKNormRopeOp.cpp:
• Accept cos_emb / sin_emb with rank in [2, 4]; the op flattens internally
based on whether shape[-2] == num_heads, so callers can pass raw cos
tensors without reshaping.
• Drop the per-head launcher's cos_seq_per_batch == 0 reject (kernel
now supports broadcast).
• Remove the fused_dit_cross_head_qk_norm_rope op entirely (schema +
impl + registration).
- fusedDiTSplitQKNormRopeOp.cpp: same rank-flexible cos handling.
Python -- modules/attention.py
- Remove the qk_norm_rope_kernel: str = ... selector parameter and its
validation assert; the C++ op now auto-dispatches by tensor shape/dtype.
- apply_packed_qk_norm_rope / apply_split_norm_rope / apply_split_norm
collapse to ~6-line passthroughs -- no cos shape inference, no
host-side broadcast tile, no kernel-name branching.
Python -- model files
- models/wan/transformer_wan.py: drop the
qk_norm_rope_kernel="fused_dit_cross_head_qk_norm_rope" arg from
attn1's Attention(...). All WAN variants (12 / 24 / 40 heads * 128)
now route through the full-dim template via the default op.
- models/ltx2/transformer_ltx2.py: remove the head_dim in (64, 128)
hardcoded check from both LTX2Attention.forward's guard and
project_kv. The kernel templates already enforce the head_dim
envelope; if a caller constructs LTX2Attention with an unsupported
head_dim the kernel throws a clear TLLM_THROW.
Test / utility hygiene
- _torch/compilation/utils.py: drop fused_dit_cross_head_qk_norm_rope
from the torch.compile op-allowlist mapping.
- tests/integration/defs/examples/test_visual_gen.py: delete the
_disable_wan_fused_qk_norm_rope_if_unavailable runtime fallback helper
and its call site (cross-head op no longer exists, so the
graceful-fallback wrapping is moot).
- tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py:
delete the cross-head test block.
- benchmarks/bench_fused_dit_cross_head_qk_norm_rope.py: delete entire file.
Validation
- Microbenchmark (B200, 30-40 iters, min): the full-dim template is the
fastest fused QK-norm + RoPE kernel for every WAN num_heads value we
measured (12 / 24 / 32 / 40 / 48 / 56 / 64 heads x 128). Direct vs the
deleted cross-head op on representative WAN shapes:
WAN-1.3B 12h, 32K tok, B=1: 0.166 ms vs 0.674 ms -> 4.06x faster
WAN-TI2V-5B 24h, 32K tok, B=1: 0.233 ms vs 0.872 ms -> 3.75x faster
WAN-14B family 40h, 32K tok, B=1: 0.492 ms vs 1.093 ms -> 2.22x faster
CFG=2 (B=2) shapes: the cross-head op fails outright (no batch broadcast
path); the full-dim path handles it natively via cos_seq_per_batch.
- Bit-similarity full-dim vs cross-head on 5 WAN shapes (12 / 24 / 40 / 48
/ 64 heads * 128): max_abs_diff <= 1.6e-2 (1 bf16 ULP at output ~ 1),
mean ~ 1.24e-8.
- tests/unittest/_torch/visual_gen/test_ltx2_attention.py: 6/6 PASS.
- tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py:
55 passed, 24 skipped (per-head + full-dim suites; cross-head suite
deleted with the op).
- LTX-2 nvfp4 single-stage e2e smoke earlier on this branch (40 steps,
768x1280 x 121 frames, 1 GPU): 12.5 s, 12.9 MB mp4. No regression.
Follow-ups (deferred to keep diff focused):
- PR NVIDIA#13978 (async Ulysses, LTX-2 + WAN): move per-model compute_q/k/v
closure into a base Attention.forward_async() now that the closure's
norm+rope step is a single apply_split_norm_rope call.
- Merge fused_dit_split_norm + fused_dit_split_norm_rope into one op.
- Wire WAN cross-attn (attn2) through fused_dit_split_norm instead of
eager apply_qk_norm.
Signed-off-by: Yiyun Lu <[email protected]>
|
/bot run --disable-fail-fast |
|
PR_Github #50199 [ run ] triggered by Bot. Commit: |
📝 WalkthroughWalkthroughThis PR consolidates separate per-head and cross-head fused QK Norm+RoPE kernel paths into a single unified kernel with broadcast-aware RoPE embedding support. The per-head kernel gains ChangesQK Norm+RoPE Kernel Unification and Consolidation
🎯 4 (Complex) | ⏱️ ~45 minutes Possibly Related PRs
Suggested Labels
Suggested Reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu (1)
546-575:⚠️ Potential issue | 🟠 Major | ⚡ Quick winAdd device opt-in dynamic-SMEM guard before full-dim kernel launch
In
cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu,cfg.dynamicSmemBytesis unconditionally applied viacudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, cfg.dynamicSmemBytes)beforecudaLaunchKernelEx, with no check againstcudaDevAttrMaxSharedMemoryPerBlockOptin; this can cause launch/attribute failures on GPUs that can’t opt in to that dynamic SMEM. Same pattern also appears incpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.cu.Suggested fix
+ int device = 0; + cudaGetDevice(&device); + int max_optin_smem = 0; + cudaDeviceGetAttribute(&max_optin_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + TLLM_CHECK_WITH_INFO(static_cast<int>(cfg.dynamicSmemBytes) <= max_optin_smem, + "fusedDiTQKNormRopeFullDim: requested dynamic SMEM (%zu) exceeds device opt-in limit (%d)", + cfg.dynamicSmemBytes, max_optin_smem); + `#define` LAUNCH(HEAD_DIM, INTERLEAVE, PER_HEAD, COS_T) \ do \ { \ auto* kptr = fusedDiTQKNormFullDimRopeKernel<HEAD_DIM, INTERLEAVE, PER_HEAD, COS_T>; \🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu` around lines 546 - 575, The kernel currently unconditionally calls cudaFuncSetAttribute(...) with cfg.dynamicSmemBytes which can fail on devices that don't support the dynamic-SMEM opt-in; fix by querying the device opt-in limit via cudaDeviceGetAttribute(&optin_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device*/0) and only call cudaFuncSetAttribute for fusedDiTQKNormFullDimRopeKernel (and the analogous split kernel) when optin_bytes > 0 and cfg.dynamicSmemBytes <= (size_t)optin_bytes (or clamp to optin_bytes if you prefer); otherwise skip the cudaFuncSetAttribute call so the subsequent cudaLaunchKernelEx is not given an unsupported dynamic SMEM attribute.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/thop/fusedDiTQKNormRopeOp.cpp`:
- Around line 49-59: Before flattening, ensure the raw shapes of cos_emb and
sin_emb are identical so we don't accidentally pair mismatched rows after
reshape: add a TORCH_CHECK comparing their full sizes (e.g., cos_emb.sizes() ==
sin_emb.sizes()) just after the existing dim checks and before computing
cos_last_raw/fold_last_two/cos_new_last; if you need to allow the two supported
layouts, explicitly validate that corresponding leading dimensions match (and if
fold_last_two logic is used, ensure both size(-2) and size(-1) match
expectations on cos_emb and sin_emb) so reshape({-1, cos_new_last}) cannot
silently misalign rows. Use the same error style/messages as existing
TORCH_CHECKs and reference cos_emb, sin_emb, fold_last_two, cos_last_raw, and
cos_new_last in your checks.
In `@cpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpp`:
- Around line 37-47: Add a raw-shape equality check for cos_emb/sin_emb before
flattening so differently ordered layouts cannot slip through: verify that
sin_emb.sizes() == cos_emb.sizes() (or at minimum that every dimension except
the flattened-leading dims matches when computing fold_last_two) and reject with
TORCH_CHECK if they differ; then proceed to compute fold_last_two and produce
cos_2d and sin_2d as before. Reference cos_emb, sin_emb, fold_last_two,
cos_last_raw, cos_new_last, cos_2d and sin_2d to locate where to insert this
validation.
---
Outside diff comments:
In `@cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu`:
- Around line 546-575: The kernel currently unconditionally calls
cudaFuncSetAttribute(...) with cfg.dynamicSmemBytes which can fail on devices
that don't support the dynamic-SMEM opt-in; fix by querying the device opt-in
limit via cudaDeviceGetAttribute(&optin_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device*/0) and only call
cudaFuncSetAttribute for fusedDiTQKNormFullDimRopeKernel (and the analogous
split kernel) when optin_bytes > 0 and cfg.dynamicSmemBytes <=
(size_t)optin_bytes (or clamp to optin_bytes if you prefer); otherwise skip the
cudaFuncSetAttribute call so the subsequent cudaLaunchKernelEx is not given an
unsupported dynamic SMEM attribute.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 34ce75a0-4795-4b1f-baf7-b6dc5d5e8c27
📒 Files selected for processing (13)
benchmarks/bench_fused_dit_cross_head_qk_norm_rope.pycpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cucpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.hcpp/tensorrt_llm/kernels/fusedDiTSplitNormKernel.cucpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.cucpp/tensorrt_llm/thop/fusedDiTQKNormRopeOp.cppcpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpptensorrt_llm/_torch/compilation/utils.pytensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.pytensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.pytensorrt_llm/_torch/visual_gen/modules/attention.pytests/integration/defs/examples/test_visual_gen.pytests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py
💤 Files with no reviewable changes (4)
- tensorrt_llm/_torch/compilation/utils.py
- benchmarks/bench_fused_dit_cross_head_qk_norm_rope.py
- tests/integration/defs/examples/test_visual_gen.py
- tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py
…ape equality Both fused_dit_qk_norm_rope and fused_dit_split_norm_rope previously checked only rank parity and post-flatten 2D-shape equality. That permits silently-bad layouts (e.g. cos=[B,S,D] vs sin=[S,B,D]) to reshape to the same [B*S, D] buffer but pair the wrong sin row with each cos row, rotating the embedding incorrectly. Reject mismatched raw shapes before flatten so this fails fast with a clear error. Signed-off-by: Yiyun Lu <[email protected]>
|
PR_Github #50199 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
…tests
The fused split-norm kernel only supports head_dim in {64, 128}. Earlier in this PR I removed the Python-side 'head_dim in (64, 128)' guards from LTX2Attention.forward and project_kv on the assumption that 'production always uses 64 or 128, so the guard is dead code'. That assumption was wrong for the test suite: tests/unittest/_torch/visual_gen/test_ltx2_transformer.py builds mini LTX-2 models with head_dim=32 (TestLTX2AudioVideoModel, TestLTX2VideoOnlyModel, TestLTX2CUDAGraphCapture, TestLTX2TextCache, TestLTX2CacheDiTWrapperPassthrough — 6 failing tests) to keep unit-test runtime down. Without the Python-side guard those tests hit the kernel's clear TLLM_THROW. Restore the head_dim guard in both call sites so the eager path is reachable for head_dim != 64/128, and update _forward_unfused docstring to mention the mini-config path explicitly.
Signed-off-by: Yiyun Lu <[email protected]>
|
PR_Github #50231 [ run ] triggered by Bot. Commit: |
|
/bot run --disable-fail-fast |
|
PR_Github #50232 [ run ] triggered by Bot. Commit: |
|
PR_Github #50231 [ run ] completed with state |
|
PR_Github #50232 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #50271 [ run ] triggered by Bot. Commit: |
|
/bot run --disable-fail-fast |
|
PR_Github #50305 [ run ] triggered by Bot. Commit: |
|
PR_Github #50271 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #50314 [ run ] triggered by Bot. Commit: |
|
PR_Github #50305 [ run ] completed with state |
|
PR_Github #50314 [ run ] completed with state |
NVIDIA#14529) Signed-off-by: Yiyun Lu <[email protected]>
@coderabbitai summary
Description
Unify the visual_gen fused QK-Norm + RoPE kernel surface so the Python dispatcher has no per-model branching — the C++ op auto-dispatches by tensor shapes/dtypes — and delete the now-redundant
fused_dit_cross_head_qk_norm_ropeop +fusedDiTCrossHeadQKNormRopeKernel.Why
The cross-head kernel existed only because the full-dim template capped
num_heads ≤ 32and couldn't fit WAN-14B (40 heads). That cap is cheap to lift; once lifted, the cross-head op has no remaining caller, runs slower on every shape we measured, and outright fails on the CFG=2 case the full-dim template handles natively. Removing it also lets us delete the per-model routing (qk_norm_rope_kernelstring selector onAttention, WAN's explicit opt-in, LTX-2'shead_dim in (64, 128)gate).Key changes
num_headscap 32 → 64 on full-dim template + the two split-norm kernels (SMEM budget already fits at 64h × 128 on B200). Addcos_seq_per_batchruntime param to the per-head kernel so the Python side stops needing host-side.repeat(B, 1)for FLUX dual-stream CFG.fused_dit_qk_norm_ropeandfused_dit_split_norm_ropeaccept cos/sin with rank in [2, 4] and flatten internally. Delete thefused_dit_cross_head_qk_norm_ropeop (schema + impl + registration) and its kernel/launcher.qk_norm_rope_kernelselector from baseAttention;apply_packed_qk_norm_rope/apply_split_norm_rope/apply_split_normcollapse to ~6-line passthroughs. WAN drops the cross-head opt-in; LTX-2 drops the head_dim hardcode.Net diff: +131 / −980 lines, 13 files modified (1 fully removed).
Test Coverage
Microbenchmark — full-dim (this PR) vs cross-head (deleted)
Reproducing the shape set from the original cross-head PR (#13052) on B200, single GPU, B=1, 200 iters, CUDA events. Cross-head measured against the pre-deletion build (op still registered); full-dim measured against this PR's build (cap lifted to 64).
RoPE: interleave (WAN-1)
RoPE: rotate_half (WAN-2)
Full-dim is faster or equal in every shape; the practical shapes (≥ 4k tokens) are 2.0–3.9× faster. At 256 tokens both kernels are launch-overhead-bound. CFG=2 (B=2) cases additionally serve only on the full-dim path — the cross-head op has no batch broadcast and fails outright. Bit-similarity vs cross-head before removal:
max_abs_diff ≤ 1.6e-2(1 bf16 ULP at output ≈ 1).Unit + e2e tests
tests/unittest/_torch/visual_gen/test_ltx2_attention.py— 6/6 PASStests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py— 55 passed, 24 skippedFollow-ups (separate PRs)
compute_q/k/vclosure into a baseAttention.forward_async()— the closure's norm + rope step is oneapply_split_norm_ropecall after this PR.attn2) throughfused_dit_split_norminstead of eagerapply_qk_norm.PR Checklist
PR description clearly explains what and why.
PR Follows TRT-LLM CODING GUIDELINES.
Test cases are provided for new code paths.
Any new dependencies have been scanned for license and vulnerabilities.
CODEOWNERS updated if ownership changes.
Documentation updated as needed.
Update tava architecture diagram if significant design change.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.