Skip to content

[TRTLLM-12949][refactor] visual_gen: unify fused QK-norm+rope dispatch#14529

Merged
luyiyun1021 merged 3 commits into
NVIDIA:mainfrom
luyiyun1021:dev/qk-norm-rope-fusion-survey
May 26, 2026
Merged

[TRTLLM-12949][refactor] visual_gen: unify fused QK-norm+rope dispatch#14529
luyiyun1021 merged 3 commits into
NVIDIA:mainfrom
luyiyun1021:dev/qk-norm-rope-fusion-survey

Conversation

@luyiyun1021
Copy link
Copy Markdown
Collaborator

@luyiyun1021 luyiyun1021 commented May 25, 2026

@coderabbitai summary

Description

Unify the visual_gen fused QK-Norm + RoPE kernel surface so the Python dispatcher has no per-model branching — the C++ op auto-dispatches by tensor shapes/dtypes — and delete the now-redundant fused_dit_cross_head_qk_norm_rope op + fusedDiTCrossHeadQKNormRopeKernel.

Why

The cross-head kernel existed only because the full-dim template capped num_heads ≤ 32 and couldn't fit WAN-14B (40 heads). That cap is cheap to lift; once lifted, the cross-head op has no remaining caller, runs slower on every shape we measured, and outright fails on the CFG=2 case the full-dim template handles natively. Removing it also lets us delete the per-model routing (qk_norm_rope_kernel string selector on Attention, WAN's explicit opt-in, LTX-2's head_dim in (64, 128) gate).

Key changes

  • Kernels: lift num_heads cap 32 → 64 on full-dim template + the two split-norm kernels (SMEM budget already fits at 64h × 128 on B200). Add cos_seq_per_batch runtime param to the per-head kernel so the Python side stops needing host-side .repeat(B, 1) for FLUX dual-stream CFG.
  • C++ ops: fused_dit_qk_norm_rope and fused_dit_split_norm_rope accept cos/sin with rank in [2, 4] and flatten internally. Delete the fused_dit_cross_head_qk_norm_rope op (schema + impl + registration) and its kernel/launcher.
  • Python: drop qk_norm_rope_kernel selector from base Attention; apply_packed_qk_norm_rope / apply_split_norm_rope / apply_split_norm collapse to ~6-line passthroughs. WAN drops the cross-head opt-in; LTX-2 drops the head_dim hardcode.
  • Test / utility hygiene: drop the op from the torch.compile allowlist, delete the cross-head unit tests, delete the standalone benchmark, delete the runtime fallback helper in the integration tests.

Net diff: +131 / −980 lines, 13 files modified (1 fully removed).

Test Coverage

Microbenchmark — full-dim (this PR) vs cross-head (deleted)

Reproducing the shape set from the original cross-head PR (#13052) on B200, single GPU, B=1, 200 iters, CUDA events. Cross-head measured against the pre-deletion build (op still registered); full-dim measured against this PR's build (cap lifted to 64).

RoPE: interleave (WAN-1)

Config cross-head (ms) full-dim (ms) speedup
WAN-1.3B 12h × 128, 256 tok 0.0066 0.0066 1.00×
WAN-1.3B 12h × 128, 4096 tok 0.0728 0.0202 3.60×
WAN-14B 40h × 128, 256 tok 0.0095 0.0066 1.44×
WAN-14B 40h × 128, 4096 tok 0.1251 0.0607 2.06×
WAN-14B 40h × 128, 16384 tok 0.5473 0.2443 2.24×

RoPE: rotate_half (WAN-2)

Config cross-head (ms) full-dim (ms) speedup
WAN-1.3B 12h × 128, 256 tok 0.0078 0.0065 1.20×
WAN-1.3B 12h × 128, 4096 tok 0.0901 0.0232 3.88×
WAN-14B 40h × 128, 256 tok 0.0120 0.0067 1.79×
WAN-14B 40h × 128, 4096 tok 0.1561 0.0732 2.13×
WAN-14B 40h × 128, 16384 tok 0.6585 0.2913 2.26×

Full-dim is faster or equal in every shape; the practical shapes (≥ 4k tokens) are 2.0–3.9× faster. At 256 tokens both kernels are launch-overhead-bound. CFG=2 (B=2) cases additionally serve only on the full-dim path — the cross-head op has no batch broadcast and fails outright. Bit-similarity vs cross-head before removal: max_abs_diff ≤ 1.6e-2 (1 bf16 ULP at output ≈ 1).

Unit + e2e tests

  • tests/unittest/_torch/visual_gen/test_ltx2_attention.py6/6 PASS
  • tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py55 passed, 24 skipped

Follow-ups (separate PRs)

  1. PR [TRTLLM-11457][feat] Async Ulysses pipeline (Enabled for LTX-2 + WAN) #13978 (async Ulysses, LTX-2 + WAN) can now move the per-model compute_q/k/v closure into a base Attention.forward_async() — the closure's norm + rope step is one apply_split_norm_rope call after this PR.
  2. Wire WAN cross-attn (attn2) through fused_dit_split_norm instead of eager apply_qk_norm.

PR Checklist

  • PR description clearly explains what and why.

  • PR Follows TRT-LLM CODING GUIDELINES.

  • Test cases are provided for new code paths.

  • Any new dependencies have been scanned for license and vulnerabilities.

  • CODEOWNERS updated if ownership changes.

  • Documentation updated as needed.

  • Update tava architecture diagram if significant design change.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

…e cross_head op

Goal: collapse the Python-side fused QK-norm + RoPE dispatch in visual_gen
to a single op call with no per-model branching, and remove the orphaned
fused_dit_cross_head_qk_norm_rope op once the full-dim template covers
its envelope.

Kernel changes
- fusedDiTQKNormRopeKernel.cu / .h:
  • Full-dim template: lift num_heads cap 32 -> 64 (MAX_N = 64 * HEAD_DIM)
    so WAN-14B (40 heads * 128) fits. SMEM budget stays well within B200's
    227 KB dynamic SMEM cap.
  • Per-head template: add cos_seq_per_batch runtime parameter for
    kernel-side cos broadcast over B (matching what the full-dim template
    already did). Python dispatcher no longer needs host-side .repeat(B, 1)
    for FLUX dual-stream CFG paths.
  • Delete fusedDiTCrossHeadQKNormRopeKernel + launchFusedDiTCrossHeadQKNormRope
    (kernel + launcher + .h declaration).
- fusedDiTSplitNormKernel.cu + fusedDiTSplitQKNormRopeKernel.cu:
  • Lift num_heads cap 32 -> 64 (these caps were SMEM-budget driven, not
    block-size driven as the misleading comments suggested -- comments
    corrected).

C++ op surface
- fusedDiTQKNormRopeOp.cpp:
  • Accept cos_emb / sin_emb with rank in [2, 4]; the op flattens internally
    based on whether shape[-2] == num_heads, so callers can pass raw cos
    tensors without reshaping.
  • Drop the per-head launcher's cos_seq_per_batch == 0 reject (kernel
    now supports broadcast).
  • Remove the fused_dit_cross_head_qk_norm_rope op entirely (schema +
    impl + registration).
- fusedDiTSplitQKNormRopeOp.cpp: same rank-flexible cos handling.

Python -- modules/attention.py
- Remove the qk_norm_rope_kernel: str = ... selector parameter and its
  validation assert; the C++ op now auto-dispatches by tensor shape/dtype.
- apply_packed_qk_norm_rope / apply_split_norm_rope / apply_split_norm
  collapse to ~6-line passthroughs -- no cos shape inference, no
  host-side broadcast tile, no kernel-name branching.

Python -- model files
- models/wan/transformer_wan.py: drop the
  qk_norm_rope_kernel="fused_dit_cross_head_qk_norm_rope" arg from
  attn1's Attention(...). All WAN variants (12 / 24 / 40 heads * 128)
  now route through the full-dim template via the default op.
- models/ltx2/transformer_ltx2.py: remove the head_dim in (64, 128)
  hardcoded check from both LTX2Attention.forward's guard and
  project_kv. The kernel templates already enforce the head_dim
  envelope; if a caller constructs LTX2Attention with an unsupported
  head_dim the kernel throws a clear TLLM_THROW.

Test / utility hygiene
- _torch/compilation/utils.py: drop fused_dit_cross_head_qk_norm_rope
  from the torch.compile op-allowlist mapping.
- tests/integration/defs/examples/test_visual_gen.py: delete the
  _disable_wan_fused_qk_norm_rope_if_unavailable runtime fallback helper
  and its call site (cross-head op no longer exists, so the
  graceful-fallback wrapping is moot).
- tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py:
  delete the cross-head test block.
- benchmarks/bench_fused_dit_cross_head_qk_norm_rope.py: delete entire file.

Validation
- Microbenchmark (B200, 30-40 iters, min): the full-dim template is the
  fastest fused QK-norm + RoPE kernel for every WAN num_heads value we
  measured (12 / 24 / 32 / 40 / 48 / 56 / 64 heads x 128). Direct vs the
  deleted cross-head op on representative WAN shapes:
    WAN-1.3B 12h, 32K tok, B=1:      0.166 ms vs 0.674 ms -> 4.06x faster
    WAN-TI2V-5B 24h, 32K tok, B=1:   0.233 ms vs 0.872 ms -> 3.75x faster
    WAN-14B family 40h, 32K tok, B=1: 0.492 ms vs 1.093 ms -> 2.22x faster
  CFG=2 (B=2) shapes: the cross-head op fails outright (no batch broadcast
  path); the full-dim path handles it natively via cos_seq_per_batch.
- Bit-similarity full-dim vs cross-head on 5 WAN shapes (12 / 24 / 40 / 48
  / 64 heads * 128): max_abs_diff <= 1.6e-2 (1 bf16 ULP at output ~ 1),
  mean ~ 1.24e-8.
- tests/unittest/_torch/visual_gen/test_ltx2_attention.py: 6/6 PASS.
- tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py:
  55 passed, 24 skipped (per-head + full-dim suites; cross-head suite
  deleted with the op).
- LTX-2 nvfp4 single-stage e2e smoke earlier on this branch (40 steps,
  768x1280 x 121 frames, 1 GPU): 12.5 s, 12.9 MB mp4. No regression.

Follow-ups (deferred to keep diff focused):
- PR NVIDIA#13978 (async Ulysses, LTX-2 + WAN): move per-model compute_q/k/v
  closure into a base Attention.forward_async() now that the closure's
  norm+rope step is a single apply_split_norm_rope call.
- Merge fused_dit_split_norm + fused_dit_split_norm_rope into one op.
- Wire WAN cross-attn (attn2) through fused_dit_split_norm instead of
  eager apply_qk_norm.

Signed-off-by: Yiyun Lu <[email protected]>
@luyiyun1021 luyiyun1021 requested review from a team as code owners May 25, 2026 09:33
@luyiyun1021 luyiyun1021 requested a review from liji-nv May 25, 2026 09:33
@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50199 [ run ] triggered by Bot. Commit: 6be3718 Link to invocation

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 25, 2026

📝 Walkthrough

Walkthrough

This PR consolidates separate per-head and cross-head fused QK Norm+RoPE kernel paths into a single unified kernel with broadcast-aware RoPE embedding support. The per-head kernel gains cos_seq_per_batch parameter for batch-wise embedding broadcasting, the cross-head launcher is removed entirely, kernel MAX_N capacity increases from 32 to 64 HEAD_DIM, and Python APIs are simplified to rely on kernel auto-dispatch instead of host-side kernel selection and tensor reshaping.

Changes

QK Norm+RoPE Kernel Unification and Consolidation

Layer / File(s) Summary
Kernel broadcast-aware RoPE embedding support
cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu
Per-head kernel parameter list extends to accept tokens_per_batch and cos_seq_per_batch for broadcast control; RoPE embedding indexing computes cos_tokenIdx using cos_seq_per_batch when enabled, supporting broadcasted embeddings across batches.
Launcher signature and kernel launch updates
cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.h, cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu
launchFusedDiTQKNormRope declaration and implementation updated with cos_seq_per_batch parameter and adjusted stream parameter ordering; kernel launch arguments wired to pass the new broadcast-control parameters.
Remove cross-head launcher and update full-dim constraints
cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu
Cross-head launcher launchFusedDiTCrossHeadQKNormRope removed; fusedDiTQKNormFullDimRopeKernel MAX_N increased from 32 * HEAD_DIM to 64 * HEAD_DIM; per-head validation constraint relaxed from num_heads_q <= 32 to num_heads_q <= 64.
Expand MAX_N across split norm kernels
cpp/tensorrt_llm/kernels/fusedDiTSplitNormKernel.cu, cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.cu
MAX_N doubled from 32 * HEAD_DIM to 64 * HEAD_DIM in both split norm kernel templates; launcher constraints updated to enforce num_heads <= 64 alignment with expanded kernel capacity.
Torch operator unification with multi-rank reshape support
cpp/tensorrt_llm/thop/fusedDiTQKNormRopeOp.cpp, cpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpp
fused_dit_qk_norm_rope and fused_dit_split_norm_rope now accept cos/sin with rank 2–4, internally flatten to 2D, auto-detect broadcast layout, compute cos_seq_per_batch, and wire reshaped tensors to kernel. Removes fused_dit_cross_head_qk_norm_rope operator registration.
Simplify Attention module API and remove kernel selection
tensorrt_llm/_torch/visual_gen/modules/attention.py
Removes qk_norm_rope_kernel parameter from Attention.__init__. Simplifies apply_packed_qk_norm_rope by removing cross-head branching and pre-tiling, passing raw embeddings directly to torch op for auto-dispatch. Simplifies apply_split_norm_rope by removing host-side reshape; inlines flattening in apply_split_norm.
Model integration updates for unified kernel
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py, tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
LTX2 removes head_dim-gated fallback and always uses fused path when qk_norm enabled; fallback routing simplified to depend only on fuse_qk_norm_rope flag. WAN removes explicit qk_norm_rope_kernel override. Updated comments reflect new head-dimension assumptions.
Remove cross-head test cases and benchmarks
benchmarks/bench_fused_dit_cross_head_qk_norm_rope.py, tests/integration/defs/examples/test_visual_gen.py, tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py, tensorrt_llm/_torch/compilation/utils.py
Deletes entire cross-head benchmark script; removes cross-head reference/wrapper functions and five test cases from unit tests; removes conditional fusion-disable helper from WAN integration tests; deletes cross-head operator entry from inplace compilation map.

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly Related PRs

  • NVIDIA/TensorRT-LLM#13985: Modifies fusedDiTQKNormRopeKernel.cu to extend fused RMSNorm+RoPE kernel with cos_seq_per_batch broadcast-control parameter; directly supports the main PR's unified broadcast-aware kernel interface.

Suggested Labels

VisualGen

Suggested Reviewers

  • liji-nv
  • yibinl-nvidia
  • karljang
  • hyukn
  • tomeras91
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 51.85% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Title check ✅ Passed The PR title accurately summarizes the main refactoring objective: unifying fused QK-norm+rope dispatch in the visual_gen module and removing the redundant cross_head operator.
Description check ✅ Passed The pull request description provides comprehensive detail on the refactoring scope, rationale, technical changes, test coverage, and validation results with detailed microbenchmark tables.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu (1)

546-575: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add device opt-in dynamic-SMEM guard before full-dim kernel launch

In cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu, cfg.dynamicSmemBytes is unconditionally applied via cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, cfg.dynamicSmemBytes) before cudaLaunchKernelEx, with no check against cudaDevAttrMaxSharedMemoryPerBlockOptin; this can cause launch/attribute failures on GPUs that can’t opt in to that dynamic SMEM. Same pattern also appears in cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.cu.

Suggested fix
+    int device = 0;
+    cudaGetDevice(&device);
+    int max_optin_smem = 0;
+    cudaDeviceGetAttribute(&max_optin_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+    TLLM_CHECK_WITH_INFO(static_cast<int>(cfg.dynamicSmemBytes) <= max_optin_smem,
+        "fusedDiTQKNormRopeFullDim: requested dynamic SMEM (%zu) exceeds device opt-in limit (%d)",
+        cfg.dynamicSmemBytes, max_optin_smem);
+
 `#define` LAUNCH(HEAD_DIM, INTERLEAVE, PER_HEAD, COS_T)                                                                  \
     do                                                                                                                 \
     {                                                                                                                  \
         auto* kptr = fusedDiTQKNormFullDimRopeKernel<HEAD_DIM, INTERLEAVE, PER_HEAD, COS_T>;                           \
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu` around lines 546 - 575,
The kernel currently unconditionally calls cudaFuncSetAttribute(...) with
cfg.dynamicSmemBytes which can fail on devices that don't support the
dynamic-SMEM opt-in; fix by querying the device opt-in limit via
cudaDeviceGetAttribute(&optin_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin,
/*device*/0) and only call cudaFuncSetAttribute for
fusedDiTQKNormFullDimRopeKernel (and the analogous split kernel) when
optin_bytes > 0 and cfg.dynamicSmemBytes <= (size_t)optin_bytes (or clamp to
optin_bytes if you prefer); otherwise skip the cudaFuncSetAttribute call so the
subsequent cudaLaunchKernelEx is not given an unsupported dynamic SMEM
attribute.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/tensorrt_llm/thop/fusedDiTQKNormRopeOp.cpp`:
- Around line 49-59: Before flattening, ensure the raw shapes of cos_emb and
sin_emb are identical so we don't accidentally pair mismatched rows after
reshape: add a TORCH_CHECK comparing their full sizes (e.g., cos_emb.sizes() ==
sin_emb.sizes()) just after the existing dim checks and before computing
cos_last_raw/fold_last_two/cos_new_last; if you need to allow the two supported
layouts, explicitly validate that corresponding leading dimensions match (and if
fold_last_two logic is used, ensure both size(-2) and size(-1) match
expectations on cos_emb and sin_emb) so reshape({-1, cos_new_last}) cannot
silently misalign rows. Use the same error style/messages as existing
TORCH_CHECKs and reference cos_emb, sin_emb, fold_last_two, cos_last_raw, and
cos_new_last in your checks.

In `@cpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpp`:
- Around line 37-47: Add a raw-shape equality check for cos_emb/sin_emb before
flattening so differently ordered layouts cannot slip through: verify that
sin_emb.sizes() == cos_emb.sizes() (or at minimum that every dimension except
the flattened-leading dims matches when computing fold_last_two) and reject with
TORCH_CHECK if they differ; then proceed to compute fold_last_two and produce
cos_2d and sin_2d as before. Reference cos_emb, sin_emb, fold_last_two,
cos_last_raw, cos_new_last, cos_2d and sin_2d to locate where to insert this
validation.

---

Outside diff comments:
In `@cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu`:
- Around line 546-575: The kernel currently unconditionally calls
cudaFuncSetAttribute(...) with cfg.dynamicSmemBytes which can fail on devices
that don't support the dynamic-SMEM opt-in; fix by querying the device opt-in
limit via cudaDeviceGetAttribute(&optin_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device*/0) and only call
cudaFuncSetAttribute for fusedDiTQKNormFullDimRopeKernel (and the analogous
split kernel) when optin_bytes > 0 and cfg.dynamicSmemBytes <=
(size_t)optin_bytes (or clamp to optin_bytes if you prefer); otherwise skip the
cudaFuncSetAttribute call so the subsequent cudaLaunchKernelEx is not given an
unsupported dynamic SMEM attribute.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 34ce75a0-4795-4b1f-baf7-b6dc5d5e8c27

📥 Commits

Reviewing files that changed from the base of the PR and between 2e3a75c and 6be3718.

📒 Files selected for processing (13)
  • benchmarks/bench_fused_dit_cross_head_qk_norm_rope.py
  • cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu
  • cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.h
  • cpp/tensorrt_llm/kernels/fusedDiTSplitNormKernel.cu
  • cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.cu
  • cpp/tensorrt_llm/thop/fusedDiTQKNormRopeOp.cpp
  • cpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpp
  • tensorrt_llm/_torch/compilation/utils.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
  • tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
  • tensorrt_llm/_torch/visual_gen/modules/attention.py
  • tests/integration/defs/examples/test_visual_gen.py
  • tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py
💤 Files with no reviewable changes (4)
  • tensorrt_llm/_torch/compilation/utils.py
  • benchmarks/bench_fused_dit_cross_head_qk_norm_rope.py
  • tests/integration/defs/examples/test_visual_gen.py
  • tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py

Comment thread cpp/tensorrt_llm/thop/fusedDiTQKNormRopeOp.cpp
Comment thread cpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpp
@luyiyun1021 luyiyun1021 changed the title [None][refactor] visual_gen: unify fused QK-norm+rope dispatch; remove cross_head op [TRTLLM-12949][refactor] visual_gen: unify fused QK-norm+rope dispatch; remove cross_head op May 25, 2026
@luyiyun1021 luyiyun1021 changed the title [TRTLLM-12949][refactor] visual_gen: unify fused QK-norm+rope dispatch; remove cross_head op [TRTLLM-12949][refactor] visual_gen: unify fused QK-norm+rope dispatch May 25, 2026
…ape equality

Both fused_dit_qk_norm_rope and fused_dit_split_norm_rope previously checked only rank parity and post-flatten 2D-shape equality. That permits silently-bad layouts (e.g. cos=[B,S,D] vs sin=[S,B,D]) to reshape to the same [B*S, D] buffer but pair the wrong sin row with each cos row, rotating the embedding incorrectly. Reject mismatched raw shapes before flatten so this fails fast with a clear error.

Signed-off-by: Yiyun Lu <[email protected]>
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50199 [ run ] completed with state SUCCESS. Commit: 6be3718
/LLM/main/L0_MergeRequest_PR pipeline #39738 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

…tests

The fused split-norm kernel only supports head_dim in {64, 128}. Earlier in this PR I removed the Python-side 'head_dim in (64, 128)' guards from LTX2Attention.forward and project_kv on the assumption that 'production always uses 64 or 128, so the guard is dead code'. That assumption was wrong for the test suite: tests/unittest/_torch/visual_gen/test_ltx2_transformer.py builds mini LTX-2 models with head_dim=32 (TestLTX2AudioVideoModel, TestLTX2VideoOnlyModel, TestLTX2CUDAGraphCapture, TestLTX2TextCache, TestLTX2CacheDiTWrapperPassthrough — 6 failing tests) to keep unit-test runtime down. Without the Python-side guard those tests hit the kernel's clear TLLM_THROW. Restore the head_dim guard in both call sites so the eager path is reachable for head_dim != 64/128, and update _forward_unfused docstring to mention the mini-config path explicitly.

Signed-off-by: Yiyun Lu <[email protected]>
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50231 [ run ] triggered by Bot. Commit: 0ed611b Link to invocation

@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50232 [ run ] triggered by Bot. Commit: 342a044 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50231 [ run ] completed with state ABORTED. Commit: 0ed611b

Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50232 [ run ] completed with state SUCCESS. Commit: 342a044
/LLM/main/L0_MergeRequest_PR pipeline #39768 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50271 [ run ] triggered by Bot. Commit: 342a044 Link to invocation

Copy link
Copy Markdown
Member

@zhenhuaw-me zhenhuaw-me left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luyiyun1021 luyiyun1021 requested a review from hyukn May 26, 2026 05:59
@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50305 [ run ] triggered by Bot. Commit: 342a044 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50271 [ run ] completed with state ABORTED. Commit: 342a044
/LLM/main/L0_MergeRequest_PR pipeline #39801 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50314 [ run ] triggered by Bot. Commit: 342a044 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50305 [ run ] completed with state ABORTED. Commit: 342a044

Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #50314 [ run ] completed with state SUCCESS. Commit: 342a044
/LLM/main/L0_MergeRequest_PR pipeline #39843 completed with status: 'SUCCESS'

CI Report

Link to invocation

@luyiyun1021 luyiyun1021 merged commit 1f8312d into NVIDIA:main May 26, 2026
7 checks passed
bmarimuthu-nv pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request May 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants