Skip to content

[TRTLLM-11457][feat] Async Ulysses pipeline (Enabled for LTX-2 + WAN)#13978

Open
luyiyun1021 wants to merge 20 commits into
NVIDIA:mainfrom
luyiyun1021:dev-ltx2-ulysses-async-a2a-pipeline
Open

[TRTLLM-11457][feat] Async Ulysses pipeline (Enabled for LTX-2 + WAN)#13978
luyiyun1021 wants to merge 20 commits into
NVIDIA:mainfrom
luyiyun1021:dev-ltx2-ulysses-async-a2a-pipeline

Conversation

@luyiyun1021
Copy link
Copy Markdown
Collaborator

@luyiyun1021 luyiyun1021 commented May 11, 2026

@coderabbitai summary

Description

Adds an opt-in async Ulysses pipeline for sequence parallelism on diffusion video models (LTX-2 + WAN 2.2). The pipeline overlaps per-rank V/Q/K compute (GEMM → RMSNorm → RoPE) with cross-rank data movement (PyTorch _SymmetricMemory CUDA-IPC peer access + symm-mem barrier release fence) on a single dedicated per-device side stream. SMs stay free for the next V/Q/K while the GPU Copy Engine drains the prior V/Q/K's peer push.

Opt-in gate: parallel_config.async_ulysses (default False). Both LTX-2 (LTX2Attention.__init__(async_ulysses=...)) and WAN 2.2 self-attn read the same gate. Non-Ulysses, audio, cross-attn, and ulysses_size == 1 paths are unchanged.

Requires PyTorch ≥ 2.5 for torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p / rendezvous / barrier. No NCCL device-API dependency.

Enabling

YAML config (trtllm-serve --config foo.yaml):

parallel_config:
  ulysses_size: 4         # any P > 1
  async_ulysses: true     # default False — flip to enable the async pipeline

Python API (VisualGenArgs):

from tensorrt_llm._torch.visual_gen.config import VisualGenArgs, VisualGenParallelConfig

args = VisualGenArgs(
    parallel_config=VisualGenParallelConfig(
        ulysses_size=4,
        async_ulysses=True,
    ),
    ...
)

A/B comparison against the baseline blocking-A2A path needs zero code change — flip the bool.

Design

Stream lanes per attention call

default      [V_GEMM+norm+rope+permute]   [Q_GEMM+norm+rope+permute]   [K_GEMM+norm+rope+permute]                            [SDPA]
                      │ev.record                  │ev.record                  │ev.record
                      ▼                           ▼                           ▼                                                ▲
side_stream      [V_push]                    [Q_push]                    [K_push]   [barrier × 3] ─► ev_done.record ──────────┘
                                                                                                     default waits ev_done

Issue order V → Q → K. V's CE push overlaps with Q's compute; Q's push overlaps with K's compute. Three barriers fire only at _join_async, after which the default stream waits the tail event before SDPA. Side stream is a per-device singleton (UlyssesAttention._side_stream_by_device) so all transformer layers on the same device share one comm lane — avoids per-layer stream proliferation under cuda_graph capture / inductor.

C++ ops surface

// Phase 1 — default stream: lazy-alloc slot ring + fused permute+scatter kernel.
std::tuple<Tensor, SendHandle> ulysses_a2a_async_prepare(Tensor input, ProcessGroup pg);

// Phase 2a — side stream: CE peer push only (no barrier).
void ulysses_a2a_async_push(SendHandle send_h, ProcessGroup pg);

// Phase 2b — side stream: symm-mem barrier release fence (channel=0).
//   Called N times at join, once per deferred push.
void ulysses_a2a_async_barrier(ProcessGroup pg);

AsyncUlyssesOp caches a canonical SymmetricMemory handle on the first slot allocation (mCanonicalHandle); emitBarrier() uses it directly instead of scanning slots each call.

forward_async API (closure-based)

@torch.compiler.disable(recursive=False)
def _issue_async(self, perm_4d):
    recv, send_h = torch.ops.trtllm.ulysses_a2a_async_prepare(perm_4d, self._pg_boxed)
    ev = torch.cuda.Event(); ev.record()
    with torch.cuda.stream(self._async_side_stream):
        ev.wait()
        torch.ops.trtllm.ulysses_a2a_async_push(send_h, self._pg_boxed)
    self._pending_barriers += 1
    return recv

@torch.compiler.disable(recursive=False)
def _join_async(self):
    with torch.cuda.stream(self._async_side_stream):
        for _ in range(self._pending_barriers):
            torch.ops.trtllm.ulysses_a2a_async_barrier(self._pg_boxed)
        ev_done = torch.cuda.Event(); ev_done.record()
    self._pending_barriers = 0
    torch.cuda.current_stream().wait_event(ev_done)

def forward_async(self, compute_q, compute_k, compute_v, **attn_kwargs):
    v_4d = compute_v(); v_5d = self._issue_async(v_4d)
    q_4d = compute_q(); q_5d = self._issue_async(q_4d)
    k_4d = compute_k(); k_5d = self._issue_async(k_4d)
    self._join_async()
    # ulyssesPostUnscatterKernel ▶ SDPA ▶ reverse a2a — all in the outer compile region.

@torch.compiler.disable(recursive=False) on _issue_async / _join_async is the only stream-switch boundary. The caller's compute_q/k/v closures execute on the default stream between successive _issue_async graph breaks, so their bodies sit inside the outer block's inductor graph and get full GEMM+norm+RoPE fusion (verified nvjet_sm100_ootst_..._Avec16UE4M3_Bvec16UE4M3_TNT kernels emitted under NVFP4).

_SymmetricMemory slot ring

AsyncUlyssesOp owns a 3-slot ring (V/Q/K per layer — intra-layer slot-reuse hazard requires ≥ 3, cross-layer reuse is safe because _join_async syncs the side stream before next layer's V starts). First call at (numel, dtype) triggers empty_strided_p2p + rendezvous; subsequent calls reuse the cached (handle, peer_ptrs). Ops cached by pg->getGroupName() so multiple PGs (e.g. ulysses subgroup × cfg subgroup) coexist cleanly. Allocation is not graph-capture-safe (rendezvous / cudaMalloc inside) — forward_async requires one out-of-capture warmup pass before cuda_graph capture starts.

Key Optimizations

The four mechanisms below are what make this PR's async path actually faster than the baseline blocking A2A. Listed bottom-up from hardware-level resource choice to host-side launch reduction.

1. Copy-Engine peer push (avoid SM contention)

V/Q/K peer-to-peer data movement runs entirely on the GPU Copy Engines, never on SM-resident NCCL kernels:

  • Eager path: one cudaMemcpyBatchAsync over the P-1 per-peer chunks (driver fans out across CEs).
  • Capture path: per-peer cudaMemcpyAsync loop (because cudaMemcpyBatchAsync isn't graph-capture-safe until CUDA 13.5 — see Follow-ups Bump onnx from 1.12.0 to 1.13.0 #1).

The Copy Engine handles all P-1 peer pushes, leaving SMs fully free for the next V/Q/K GEMM → RMSNorm → RoPE. Compared to NCCL all-to-all (Send/Recv kernels on SMs competing for the same compute resource as the GEMMs), the CE-based push has disjoint hardware resources from compute — no channel multiplexing or ring-topology hacks needed.

2. Fuse permute with self-chunk scatter (bypass CE local-D2D bandwidth cap)

ulyssesPermuteScatterKernel does two things in one launch:

  • Permute [B, S, H, D] → [P, B, S, H/P, D]
  • Write the local rank's self chunk directly to the recv slot's self position (slot.basePtr + my_rank * chunk_bytes); only the P-1 peer chunks go to slot.sendBuf.

Without the fuse, the alternative is: SM permute → sendBuf; then CE runs P memcpys (P-1 peer pushes + 1 local sendBuf → recv self-slot D2D copy). The CE's local D2D channel is hardware-capped at ~100 GB/s (independent from the NVLink/NVSwitch peer channel at ~400+ GB/s — they don't share bandwidth, but local D2D's own ceiling is low). Since the self chunk is the same size as one peer chunk, the local copy would become the critical path of the entire a2a — the P-1 peer pushes finish 4× faster while we wait for one local D2D.

Fusing moves the self-chunk write into the SM kernel, using HBM bandwidth (~5–10 TB/s) — one to two orders of magnitude above CE local-D2D. The alternative of green-context isolation for the local copy (separate CE queue) requires complex multi-stream/queue coordination; fusing bypasses that whole subsystem.

3. Deferred barriers (push×N FIFO, barrier×N drain at join)

V/Q/K's a2a is split into two C++ ops, ulysses_a2a_async_push (CE push only) and ulysses_a2a_async_barrier (symm-mem barrier release fence). Side-stream sequence:

[push_V, push_Q, push_K, barrier, barrier, barrier]

Originally the op was push+barrier fused, producing [push, barrier, push, barrier, push, barrier] where each barrier kernel cut the CE FIFO and forced the next push to wait. Each barrier kernel drops from ~164 µs → ~93 µs (-43%) under the deferred scheme, and per-step CE PTOP throughput improves by ~4 ms (WAN 8-GPU step capture). Python state tracks _pending_barriers on UlyssesAttention; _join_async drains exactly N barriers on the side stream before the default stream waits the tail event.

4. Dedup fp4_quantize across QKV self-attn

Under async ulysses the video self-attn is forced to SEPARATE_QKV (so V/Q/K can stream-pipeline independently), but with NVFP4 static quant the three Linears all share input_scale:

# Verified bit-equal across WAN 36+36 / LTX-2 42 self-attn layers (modelopt invariant)
to_q.input_scale == to_k.input_scale == to_v.input_scale

Pre-quantize hidden_states once and share the resulting Fp4QuantizedTensor across all three Linears via the existing fp4-shortcut path in NVFP4LinearMethod._input_prepare (which just unpacks .fp4_tensor / .scaling_factor for GEMM — no internal swizzle so the SF layout is the same as Linear's own quantize path). Eliminates 2 of 3 quantize launches per self-attn block (verified −368 launches in WAN A14B 8-GPU step capture; identical kernel count between async-OFF FUSE_QKV and async-ON SEPARATE_QKV+dedup).

Gating:

  • Structural (_maybe_share_qkv_quantize, set in __init__): SEPARATE_QKV + NVFP4 + not force_dynamic_quantization.
  • Runtime: getattr(to_q, "input_scale", None) is not None — guards against checkpoints that exclude individual Linears from NVFP4 (e.g. LTX-2 transformer_blocks.10.attn1, bf16-only — falls back to per-Linear quantize).

Benchmarks

Isolated all-to-all (B200, single rank)

bench_alltoall_ab.py — single-tensor a2a latency (ms), B=2, S_total=6144, H=32, D=128, warmup=30, bench=100. naive_fused = dist.all_to_all_single on 5D fused QKV; naive_split = 3 separate 4D a2a; symm_ce = the new symm-mem CE pipeline (per V/Q/K call).

P mode naive_fused naive_split symm_ce
2 eager 0.805 (1.00×) 0.872 (0.92×) 0.641 (1.26×)
2 graph 0.862 (1.00×) 0.889 (0.97×) 0.685 (1.26×)
4 eager 0.419 (1.00×) 0.483 (0.87×) 0.339 (1.24×)
4 graph 0.431 (1.00×) 0.496 (0.87×) 0.370 (1.16×)
8 eager 0.251 (1.00×) 0.294 (0.85×) 0.194 (1.29×)
8 graph 0.262 (1.00×) 0.309 (0.85×) 0.287 (0.91×)

symm_ce is 1.16–1.29× faster than the fused baseline in every cell except P=8 graph (where cuda_graph capture amortizes Python launch overhead in the naive_fused 1-collective path enough that the merged baseline wins).

End-to-end on B200 (production scale)

model ws uly cfg OFF (s) ON (s) saved E2E
LTX-2 (768×1280, 121 f, 40 steps, NVFP4, torch.compile+cuda_graph) 2 2 1 21.453 20.767 0.686 −3.20%
4 2 2 13.154 12.884 0.270 −2.05%
8 4 2 9.486 9.338 0.148 −1.56%
WAN 2.2 T2V-A14B (720×1280, 81 f, 40 steps, NVFP4 static, torch.compile only) 2 2 1 206.731 198.311 8.420 −4.07%
4 2 2 103.031 100.504 2.527 −2.45%
8 4 2 55.444 54.114 1.330 −2.40%

3 timed E2E runs/cell, median reported. Per-cell range is well below the OFF→ON delta in every case (non-overlap). WAN's larger relative gain reflects more compute-per-A2A and a longer overlap window per layer (no cuda_graph overhead to amortize the baseline path).

Files

File Role
cpp/tensorrt_llm/kernels/ulyssesPermuteScatterKernel.{h,cu} (new) Fused pre-A2A permute + self-chunk-direct-write kernel.
cpp/tensorrt_llm/kernels/ulyssesPostUnscatterKernel.{h,cu} (new) Fused post-A2A unscatter (5D → 4D, joint Q/K/V); NHD-stride storage, HND via transpose-view.
cpp/tensorrt_llm/thop/asyncUlyssesOp.cpp (new) SendHandle + AsyncUlyssesOp (symm-mem slot ring + mCanonicalHandle cache + CE push + barrier). Ops: ulysses_a2a_async_prepare, ulysses_a2a_async_push, ulysses_a2a_async_barrier.
cpp/tensorrt_llm/thop/ulyssesPermuteScatterOp.cpp (new) Thop wrapper for permute+scatter kernel (unit test entry).
cpp/tensorrt_llm/thop/ulyssesPostUnscatterOp.cpp (new) Thop wrapper for post-unscatter kernel.
cpp/tensorrt_llm/thop/CMakeLists.txt Register new thop files.
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py register_fake for trtllm::ulysses_post_unscatter_qkv.
tensorrt_llm/_torch/visual_gen/config.py New async_ulysses Pydantic field (default False).
tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py UlyssesAttention(async_pipeline=...), side-stream singleton, _issue_async/_join_async deferred-barrier pipeline, V/Q/K rolling forward_async.
tensorrt_llm/_torch/visual_gen/modules/attention.py Plumb async_ulysses flag; QKV fp4_quantize dedup in Attention.forward_async (shared Fp4QuantizedTensor across to_q/to_k/to_v).
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py LTX-2 self-attn opts into async path; mirrored QKV dedup in LTX2Attention.forward_async override.
tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py Sigma-sync fix for cuda_graph compat.
tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py WAN self-attn (attn1) reads the gate; SEPARATE_QKV when on.
tests/unittest/_torch/thop/parallel_hw_agnostic/test_ulysses_permute_scatter.py (new) Unit test for permute+scatter kernel.
tests/unittest/_torch/thop/parallel_hw_agnostic/test_ulysses_post_unscatter.py (new) Unit test for post-unscatter kernel.
tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_async.py (new) Multi-rank ws=2 pipeline tests (slot-ring wraparound, cuda_graph capture, multi-PG).

Follow-ups

  1. cudaMemcpyBatchAsync under CUDA Graph capture (CUDA 13.5). Today the capture branch falls back to a P-1-iteration per-peer cudaMemcpyAsync loop because cudaMemcpyBatchAsync is not graph-capture-safe. Per NVBugs 5853376 and NVBugs 5760690, targeted at CUDA 13.5 (Q4 2026). We tried cudaGraphAddMemcpyNode1D with sibling-DAG independence as a stop-gap; nsys confirmed runtime serializes the sibling memcpys on the same HW queue (~4 µs gaps, no CE fanout), so the per-peer loop wins today.

  2. Collapse N barriers at join → 1 barrier. Current _join_async drains N symm-mem barriers in a loop (one per deferred push) — kept 1:1 with the original protocol for risk-free refactor. The barriers are semantically identical (all channel-0 PG-level fences); a single barrier after [push_V, push_Q, push_K] should sync the recv buffers just as well (side-stream FIFO + cross-rank barrier covers all prior pushes). Saves 2 of 3 barrier_kernel per layer (~186 µs × N_layers per step on B200). Requires verifying that all ranks symmetrically issue 1 barrier per join (slot rotation is independent of barrier count — nextSlotIdx ticks in _prepare, not _barrier).

  3. Drop @torch.compiler.disable, native multi-stream torch.compile. Stream-switch boundaries in _issue_async/_join_async are guarded by @torch.compiler.disable(recursive=False) because inductor doesn't yet model with cuda.stream(...) + event.record/wait_event as first-class graph ops (4 graph breaks per layer × N layers). Inline once torch.compile's multi-stream support matures.

  4. Fuse pre-A2A permute into the RMSNorm+RoPE kernel. We already own the fused RMSNorm+RoPE kernel (see fuse_qk_norm_rope=True path). Passing the slot's send_buf pointer + permute index map directly into the kernel epilogue lets it write the permuted layout straight into the symm-mem slot, eliminating the separate ulyssesPermuteScatterKernel launch on the critical path. Leaves the post-unscatter kernel untouched.

Test Coverage

  • Kernel unit tests (single-GPU) under tests/unittest/_torch/thop/parallel_hw_agnostic/: test_ulysses_permute_scatter.py and test_ulysses_post_unscatter.py validate the two new fused CUDA kernels against torch references in isolation.
  • Multi-rank pipeline test (ws=2) under tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_async.py: three torch.multiprocessing.spawn tests covering the production prepare/push/barrier path end-to-end —
    • test_slot_ring_wraparound: eager loop > kNumSlots*2 iterations, byte-exact vs all_to_all_4d. Catches off-by-one slot-reuse bugs.
    • test_capture_smoke: warm slot out-of-capture, then torch.cuda.CUDAGraph replay 8× with fresh inputs, byte-exact vs all_to_all_4d. Exercises the per-peer cudaMemcpyAsync path used under production cuda_graph.
    • test_multi_pg: two distinct PGs spanning the same ranks, alternate calls, byte-exact vs each PG's all_to_all_4d. Exercises PG-name caching in getOrCreateOp and set_group_info re-registration across groups.
  • Recommended CI bot stages: DGX_B200-4_GPUs-PyTorch-1, DGX_B200-8_GPUs-PyTorch-1 (covers Ulysses ws=4 and ws=8 in both LTX-2 and WAN end-to-end).

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@luyiyun1021 luyiyun1021 requested review from a team as code owners May 11, 2026 07:57
@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47692 [ run ] triggered by Bot. Commit: 856e7fa Link to invocation

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 11, 2026

📝 Walkthrough

Walkthrough

This PR implements an asynchronous split-QKV pipeline for Ulysses sequence parallelism. It adds NCCL LSA barrier support, a hybrid all-to-all v11 service combining peer copies with self-copy streaming, Green Context stream partitioning, and integrates pipelined Q/K/V pre-computation into LTX2 attention with rank synchronization via LSA barriers.

Changes

Ulysses Split-QKV LSA Barrier Pipeline

Layer / File(s) Summary
LSA Barrier Contract & Implementation
cpp/tensorrt_llm/kernels/lsaBarrierKernel.h, cpp/tensorrt_llm/kernels/lsaBarrierKernel.cu
Declares and implements NCCL 2.28+ LSA barrier factory and emit methods with thread-safe slot rotation via atomic operations; provides fallback stub for older NCCL versions.
All-to-All V11 Hybrid Service
cpp/tensorrt_llm/thop/alltoallOp.cpp (lines 18–721)
Bootstraps NCCL communicators via Torch ProcessGroup, implements hybrid copy execution combining CUDA memcpy for peers with configurable self-copy streaming, adds window-based slot management and LSA barrier emission for rank synchronization.
All-to-All Torch Ops & Dispatch
cpp/tensorrt_llm/thop/alltoallOp.cpp (lines 741–1001), tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
Registers ulysses_alltoall_hybrid_symm and ulysses_lsa_barrier custom Torch ops with CUDA dispatch and CompositeExplicitAutograd routing; adds fake shape-inference stubs for torch.compile support.
Green Context Stream Management
tensorrt_llm/_torch/distributed/_ulysses_gc.py
Implements raw CUDA driver bindings to split SM resources, create two Green Context partitions with separate streams, and provides per-device singleton with fallback to standard Torch streams on creation failure or environment override.
Ulysses Attention Pipeline Implementation
tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py (lines 23–449)
Adds V/Q/K reshape/permute helpers for 5D slot layouts, caches streams/ProcessGroup/ranks on UlyssesAttention, implements barrier emission and hybrid alltoall issue methods, and schedules pre-alltoall compute across multiple streams with event-based synchronization and windowed peer copies.
LTX2 Transformer Model Integration
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py, tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py
Adds ulysses_async_a2a flag to enable split-QKV pipeline for self-attention, wraps inner attention with UlyssesAttention when active, pre-creates and attaches pipeline streams and compiled functions, routes forward calls to forward_with_pipeline with post-permute reshaping; fixes to_velocity to keep tensor sigma on-device and avoid synchronization deadlock.

Sequence Diagram(s)

sequenceDiagram
  participant Client as LTX2Attention
  participant Pipeline as Ulysses Pipeline
  participant VCompute as V Compute<br/>(gc_comp stream)
  participant QKCompute as Q/K Compute<br/>(pri_comm stream)
  participant SelfCopy as Self-Copy<br/>(gc_selfcopy stream)
  participant Peers as Peer Copy<br/>(memcpy)
  participant Barrier as LSA Barrier
  participant Backend as Inner Backend SDPA
  
  Client->>Pipeline: forward_with_pipeline(x, to_q, to_k, to_v, ...)
  par V Stream
    Pipeline->>VCompute: norm_v + to_v (5D slot layout)
    VCompute-->>Pipeline: v_5d
  and Q/K Stream
    Pipeline->>QKCompute: norm_q/k + rope + to_q/k (5D)
    QKCompute-->>Pipeline: q_5d, k_5d
  end
  par Self-Copy
    Pipeline->>SelfCopy: self CUDA memcpy (windowed)
  and Peer-Copy
    Pipeline->>Peers: peer memcpy (windowed)
  end
  Pipeline->>Barrier: emit() on pri_comm stream
  Barrier-->>Pipeline: release-ordered rank sync
  Pipeline->>Backend: forward(v_5d, q_5d, k_5d) inner SDPA
  Backend-->>Pipeline: sdpa_out_5d
  Pipeline->>Pipeline: post_permute (5D → 4D)
  Pipeline-->>Client: output tensor
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.72% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Title check ✅ Passed The title clearly describes the main feature being added: an async Ulysses pipeline for asynchronous all-to-all communication, enabled for LTX-2 model. It aligns with the primary objective and changes throughout the PR.
Description check ✅ Passed The PR description is comprehensive and detailed. It explains the async Ulysses pipeline feature, design rationale, key optimizations, benchmarks, files changed, follow-ups, and test coverage.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
tensorrt_llm/_torch/distributed/_ulysses_gc.py (1)

1-2: ⚡ Quick win

Update the copyright year on this new file.

This file is newly added in 2026, so the NVIDIA header should carry the current modification year as required by the repo rules.

As per coding guidelines "All C++, Python, and other source files must contain NVIDIA copyright header with current modification year".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/distributed/_ulysses_gc.py` around lines 1 - 2, The file
header comment currently shows "2025" as the copyright/modification year; update
the SPDX header at the top of tensorrt_llm/_torch/distributed/_ulysses_gc.py to
use the current year "2026" (both occurrences in the two header lines) so the
NVIDIA copyright header matches repo rules.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/tensorrt_llm/thop/alltoallOp.cpp`:
- Around line 401-418: The fast-path check of mLsaBarrier races with the
lazy-init under mLsaBarrierMutex; fix by synchronizing publication before
calling emit(): acquire mLsaBarrierMutex (or use std::call_once) to ensure
creation via tensorrt_llm::kernels::LsaBarrier::create completes and the pointer
is visible, then copy mLsaBarrier to a local variable, release the lock, check
it with TLLM_CHECK_WITH_INFO and call local->emit(stream). Ensure all uses of
mLsaBarrier->emit() use the locally-copied pointer so no emit runs on a
concurrently-written pointer.

In `@tensorrt_llm/_torch/distributed/_ulysses_gc.py`:
- Around line 264-277: The current broad "except Exception" around the
_create_two_partitions call hides unrelated bugs; replace it with catching only
the specific errors that indicate partition setup failure (e.g., RuntimeError
and any CUDA/OSError that _create_two_partitions can raise) and let all other
exceptions propagate. Concretely, change the except clause around
_create_two_partitions(...) to "except (RuntimeError, OSError) as e:" (or the
exact exception types _create_two_partitions documents), keep the warnings.warn
fallback logic, and do not swallow other exceptions so bugs in
_create_two_partitions or surrounding code fail loudly.
- Around line 291-297: The per-device singleton cache in
UlyssesPipelineStreams.get is racy; wrap the lookup/create path with a
class-level lock to ensure only one thread creates and publishes an instance for
a given device_id. Add a class attribute (e.g., _instances_lock =
threading.Lock()) on the UlyssesPipelineStreams class, initialize it alongside
_instances, and in the get(cls, device_id, ...) method acquire the lock before
checking cls._instances.get(device_id) and creating/storing a new instance, then
release the lock (use context manager) after the update so concurrent callers
cannot create duplicate instances.

---

Nitpick comments:
In `@tensorrt_llm/_torch/distributed/_ulysses_gc.py`:
- Around line 1-2: The file header comment currently shows "2025" as the
copyright/modification year; update the SPDX header at the top of
tensorrt_llm/_torch/distributed/_ulysses_gc.py to use the current year "2026"
(both occurrences in the two header lines) so the NVIDIA copyright header
matches repo rules.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 27c2f052-21cb-46e3-8eaf-41dbebc0318d

📥 Commits

Reviewing files that changed from the base of the PR and between 9547230 and 856e7fa.

📒 Files selected for processing (8)
  • cpp/tensorrt_llm/kernels/lsaBarrierKernel.cu
  • cpp/tensorrt_llm/kernels/lsaBarrierKernel.h
  • cpp/tensorrt_llm/thop/alltoallOp.cpp
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tensorrt_llm/_torch/distributed/_ulysses_gc.py
  • tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py

Comment thread cpp/tensorrt_llm/thop/alltoallOp.cpp Outdated
Comment thread tensorrt_llm/_torch/distributed/_ulysses_gc.py Outdated
Comment thread tensorrt_llm/_torch/distributed/_ulysses_gc.py Outdated
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47692 [ run ] completed with state SUCCESS. Commit: 856e7fa
/LLM/main/L0_MergeRequest_PR pipeline #37588 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@luyiyun1021 luyiyun1021 changed the title [None][feat] LTX2 Ulysses async A2A pipeline via NCCL window + LSA barrier [TRTLLM-11457][feat] LTX2 Ulysses async A2A pipeline via NCCL window + LSA barrier May 13, 2026
@luyiyun1021 luyiyun1021 changed the title [TRTLLM-11457][feat] LTX2 Ulysses async A2A pipeline via NCCL window + LSA barrier [TRTLLM-11457][feat] Async Ulysses Pipeline May 20, 2026
@luyiyun1021 luyiyun1021 changed the title [TRTLLM-11457][feat] Async Ulysses Pipeline [TRTLLM-11457][feat] Async Ulysses May 20, 2026
@luyiyun1021 luyiyun1021 requested a review from a team as a code owner May 20, 2026 12:26
@luyiyun1021 luyiyun1021 requested review from chang-l and venkywonka May 20, 2026 12:26
@luyiyun1021 luyiyun1021 changed the title [TRTLLM-11457][feat] Async Ulysses [TRTLLM-11457][feat] Async Ulysses pipeline (LTX-2 + WAN, NCCL window + LSA barrier) May 20, 2026
@luyiyun1021 luyiyun1021 force-pushed the dev-ltx2-ulysses-async-a2a-pipeline branch 3 times, most recently from 8ab04d2 to 881d06d Compare May 21, 2026 08:47
@luyiyun1021 luyiyun1021 requested review from NVShreyas and zhenhuaw-me and removed request for a team, dongxuy04, venkywonka and yizhang-nv May 21, 2026 08:52
@luyiyun1021 luyiyun1021 changed the title [TRTLLM-11457][feat] Async Ulysses pipeline (LTX-2 + WAN, NCCL window + LSA barrier) [TRTLLM-11457][feat] Async Ulysses pipeline (LTX-2 + WAN, PyTorch _SymmetricMemory CUDA-IPC) May 21, 2026
@luyiyun1021 luyiyun1021 force-pushed the dev-ltx2-ulysses-async-a2a-pipeline branch from a3933c6 to 6585bc6 Compare May 21, 2026 16:01
@luyiyun1021 luyiyun1021 force-pushed the dev-ltx2-ulysses-async-a2a-pipeline branch from 7f6e7e6 to b5a86a9 Compare June 6, 2026 09:28
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52496 [ run ] triggered by Bot. Commit: b5a86a9 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52468 [ run ] completed with state ABORTED. Commit: 7f6e7e6

Link to invocation

Copy link
Copy Markdown
Member

@zhenhuaw-me zhenhuaw-me left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left some comments about the non-kernel part. Feel free to address in the follow up PRs if applicable to unblock the PR merging as-is.

status="prototype",
description=("Ulysses head-sharding degree. Heads are sharded across ulysses_size GPUs."),
)
async_ulysses: bool = Field(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this feature is always on if Ulysses is enabled since it should not have any perf penalty. Did I miss any case? If that's the case, we don't need a knob.

Feel free to resolve this comment to unblock the PR merging. If we converged that we don't need the knob, we can address in upcoming PRs.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flux and cosmos not supported yet. Since different models has different attn module computation so we have to customize the pipeline. This may need some extra work we may leave it to future pr.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No perf penalty in our test cases but since currently we have cuda graph perf issues with cudaMemcpyBatchAsync. I think we may leave some flexibilty in case there comes perf penalty. How do you think?

Comment thread tensorrt_llm/_torch/visual_gen/modules/attention.py Outdated
Comment thread tensorrt_llm/_torch/visual_gen/modules/attention.py Outdated
@luyiyun1021 luyiyun1021 force-pushed the dev-ltx2-ulysses-async-a2a-pipeline branch from b5a86a9 to a861c81 Compare June 6, 2026 14:02
@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52513 [ run ] triggered by Bot. Commit: a861c81 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52496 [ run ] completed with state ABORTED. Commit: b5a86a9

Link to invocation

…ses=True conflict

Pre-fix, async_ulysses=True forces qkv_mode=SEPARATE_QKV (so V/Q/K
projections can stream independently), which bypassed the FUSE_QKV-only
Ring gate (`qkv_mode != SEPARATE_QKV`) — Ring was silently disabled.

Even at the wrap level it would not actually work end-to-end:
`UlyssesAttention.forward_async` dispatches straight to
`self.inner_backend.forward_async` and never invokes the wrapped
`RingAttention.forward`, so Ring would be no-op under async even if the
wrap had been allowed. Real Ring+async support requires
`RingAttention.forward_async` + pass-through in
`UlyssesAttention.forward_async` — out of scope for this PR.

For now, raise ValueError on the conflicting combo so users opting into
both don't silently lose Ring. Reported in PR review by NVShreyas.

Signed-off-by: Yiyun Lu <[email protected]>
…IPC constraint

Documents the async Ulysses A2A pipeline (`async_ulysses` in
`ParallelConfig`) added by this PR series across three places:

* docs/source/models/visual-generation.md — adds a sub-bullet under
  Ulysses Parallelism in the Multi-GPU section, noting the requirement
  on an NVLink-connected domain (PyTorch `_SymmetricMemory` + CUDA IPC
  for peer pushes; not currently supported across nodes without MNNVL).

* examples/visual_gen/README.md — mirrors the note under WAN's
  Multi-GPU Parallelism section and adds a new Multi-GPU Parallelism
  sub-section under LTX-2 (LTX-2 supports head-sharded Ulysses with the
  same divisibility constraint and benefits from async_ulysses the same
  way).

* examples/visual_gen/configs/wan2.2-t2v-fp4-4gpu.yaml — opts the
  4-GPU WAN reference serve config into `async_ulysses: true` (NVLink
  domain on a single 8-GPU node, the standard target for this config).

Uses the field name as it appears in code (`async_ulysses`); reviewer
flagged that the PR description still references the older
`dit_async_ulysses` name.

Signed-off-by: Yiyun Lu <[email protected]>
…lysses=true

Mirrors the wan2.2-t2v-fp4-4gpu.yaml structure but for LTX-2: cfg_size=2
+ ulysses_size=2 + async_ulysses=true. Precision-agnostic (no quant_config
block; checkpoint precision is selected via --model_path).

Signed-off-by: Yiyun Lu <[email protected]>
… enable torch_compile + cuda_graph in LTX-2 4-GPU config

* wan2.2-t2v-fp4-4gpu.yaml and ltx2-4gpu.yaml previously carried a
  3-line comment under `parallel_config` explaining the MNNVL/CUDA IPC
  requirement of `async_ulysses`. The same explanation lives in
  docs/source/models/visual-generation.md and examples/visual_gen/README.md
  already, so the inline yaml comment was redundant — dropped from both.

* ltx2-4gpu.yaml: enable torch_compile and cuda_graph (`enable: true`
  for both). These are standard prod knobs for the 4-GPU LTX-2 path;
  the WAN 4-GPU config is left as-is (`cuda_graph_config.enable: false`).

Signed-off-by: Yiyun Lu <[email protected]>
Address Codex/reviewer findings on the async Ulysses A2A pipeline.
No functional change to the hot path; existing ws=2 bit-exact parity
tests are unaffected.

C++ transport plane (asyncUlyssesOp.cpp):
- ulysses_a2a_async_prepare: reject non-CUDA input and install
  c10::cuda::CUDAGuard so slot allocation (cudaGetDevice) and kernel
  launch (input.get_device()) bind to the same device.
- SendHandle: add group_name field; ulysses_a2a_async validates
  handle.group_name == pg.getGroupName() to reject cross-PG handle
  reuse (two PGs of the same size would otherwise pass the
  peer-pointer-count check and silently push into the wrong group).
- AsyncUlyssesOp::getOrAllocSlot: build new slot state in local
  variables; commit (move-assign + free old sendBuf) only after all
  four allocation steps (symm-mem empty_strided_p2p + rendezvous +
  get_buffer_ptrs + cudaMalloc) succeed. Prior progressive mutation
  left a poisoned cached slot on partial-failure paths.

C++ op bindings:
- ulyssesPostUnscatterOp.cpp: add TORCH_CHECK(D % 8 == 0) at op level
  (mirrors sibling ulysses_permute_scatter); early-return on empty
  Q/K/V to avoid zero-grid kernel launches.
- ulyssesPermuteScatterOp.cpp: add TORCH_CHECK(P > 0) before H % P to
  avoid division-by-zero / SIGFPE on bad schema-level callers;
  early-return on empty input.

Python:
- modules/attention.py::Attention.forward_async: prescriptive
  ValueError when self.attn lacks forward_async (caller used
  async_ulysses=False at init), replacing the deep AttributeError.

Signed-off-by: Yiyun Lu <[email protected]>
…o NHD layout

The fused post-A2A unscatter kernel previously only emitted HND
[B, H, P*Sp, D], so async Ulysses with TRTLLM / FA4 backends (which
prefer NHD) fell back to the eager 3-copy permute path. Extend the
kernel and op binding to also emit NHD [B, P*Sp, H, D] in one launch,
so all bf16 async-Ulysses callers benefit regardless of backend
layout. HND fast-path callers are unaffected (default layout=0).

Changes:
- ulyssesPostUnscatterKernel.{cu,h}: add `bool IsHnd` template arg;
  if constexpr branch on out_base. Launcher dispatches two template
  instances based on a new `bool is_hnd` runtime arg.
- ulyssesPostUnscatterOp.cpp: schema adds `int layout=0` (0=HND,
  1=NHD); op validates layout, allocates outputs with the matching
  shape, and passes is_hnd to the kernel launcher. Default keeps
  backward compatibility with existing callers.
- cpp_custom_ops.py: register_fake takes the new layout=0 default
  and branches the fake output shape accordingly.
- visual_gen/attention_backend/parallel.py:
  - Rename helper _ulysses_post_unscatter_to_hnd to
    _ulysses_post_unscatter with is_hnd kwarg.
  - Rename gate flag use_fused_op to use_fused_post_unscatter and
    drop the HND-only condition; the fused kernel now covers both
    layouts so the gate is just bf16.
  - Eager fallback (non-bf16) keeps the existing NHD/HND post-permute
    logic unchanged.
- test_ulysses_post_unscatter.py: parametrize exact_match test on
  layout (HND + NHD); add an explicit reject test for invalid layout
  values; reference function branches on is_hnd to skip the final
  transpose for NHD.

Signed-off-by: Yiyun Lu <[email protected]>
… transpose-view

cudnn SDPA preserves the input's stride pattern in its output. The sync
`_forward_unfused` path passes HND-shape NHD-stride tensors (via
`q.transpose(1, 2)` without `.contiguous()`), so cudnn returns
HND-shape NHD-stride output and the downstream
`_output_a2a.transpose(1, 2).contiguous()` collapses to a no-op.

The async `forward_async` path used `_ulysses_post_unscatter(is_hnd=True)`
to allocate HND-contig storage. cudnn then returned HND-contig output, and
`_output_a2a`'s transpose+contiguous required a real memcpy — observable
in nsys as a 62us `triton_poi_fused_all_to_all_single_clone_permute_transpose_view_0`
kernel between SDPA and reverse NCCL (absent in OFF). At WAN A14B
720x1280/81f this is ~5 ms / step of avoidable layout cost.

Make the post-unscatter kernel always write NHD-contig storage
`[B, P*Sp, H, D]`; the op wrapper returns it as-is for NHD callers and as
a transpose-view `[B, H, P*Sp, D]` for HND callers (HND-shape, NHD-stride,
non-contig — mirrors the sync path). cudnn then preserves NHD-stride
through SDPA and the post-attention contiguous() is free.

- Kernel: drop `IsHnd` template param + `if constexpr` branch, single
  NHD output address calculation.
- Op: always alloc NHD storage; return `.transpose(1, 2)` view for layout=0.
- Fake op: mirror via `new_empty(NHD).transpose(1, 2)` so Inductor sees
  matching strides.
- Test: update contiguity assertion (HND output is now non-contig view).
  `max_diff == 0` exact-match still holds.

Signed-off-by: Yiyun Lu <[email protected]>
…self-attn

Under async ulysses, video self-attn forces ``QKVMode.SEPARATE_QKV`` so
the three projections (to_q/to_k/to_v) can stream-pipeline through the
A2A. With NVFP4 static quant this triggers three identical
``tunable_fp4_quantize`` launches on the same hidden_states (verified
bit-equal input_scale across all WAN-2.2 high-noise 36 / low-noise 36
and LTX-2 single-stage 42 self-attn layers).

Pre-quantize the shared input once and pass the resulting
``Fp4QuantizedTensor`` to each Linear via the existing fp4-shortcut
path in NVFP4LinearMethod._input_prepare. Two of three quantize launches
per self-attn block per layer are eliminated (verified -368 launches in
the WAN A14B 8-GPU step capture; identical kernel count between
async-OFF FUSE_QKV and async-ON SEPARATE_QKV in LTX-2 prod-shape).

Gating:
  - Structural at __init__: SEPARATE_QKV, NVFP4 layer mode, no
    force_dynamic_quantization. Stored on ``_maybe_share_qkv_quantize``.
  - Runtime on the forward path: ``getattr(to_q, "input_scale", None) is
    not None`` -- guards against checkpoints that exclude individual
    Linears from NVFP4 (e.g. LTX-2 transformer_blocks.10.attn1, where
    to_q/to_k/to_v are bf16-only and never load input_scale).

Applied in the two async-self-attn paths:
  - ``Attention.forward_async`` (base async self-attn path used by WAN).
  - ``LTX2Attention.forward_async`` (LTX-2 has its own override; mirrors
    the base path's pre-quantize + share contract).

Signed-off-by: Yiyun Lu <[email protected]>
…oin_async

The original ``ulysses_a2a_async`` op did the CE peer push and the
symm-mem barrier back-to-back on the side stream, so V/Q/K each landed
``[push, barrier, push, barrier, push, barrier]`` on the comm stream.
Each barrier serialized with the next push, even though no cross-rank
ordering is required between issues -- only the final wait before SDPA
reads the recv buffer needs the fence.

Split the op into ``ulysses_a2a_async_push`` (CE push only) and
``ulysses_a2a_async_barrier`` (emit barrier on channel 0), keep the
original ``ulysses_a2a_async`` as a thin push+barrier alias for
backward compatibility, and rewire ``UlyssesAttention``:

  - ``_issue_async`` calls ``_push`` and bumps ``_pending_barriers``.
  - ``_join_async`` drains exactly ``_pending_barriers`` barriers on the
    side stream and then records the tail event for the default-stream
    wait.

Comm-stream FIFO preserves [push V, push Q, push K, barrier, barrier,
barrier]; channel-0 barriers all have identical semantics, so the
default-stream wait sees a fully-synced recv buffer. WAN 8-GPU step
capture shows the per-call barrier_kernel time drops from ~164us to
~93us (-43%) and consecutive V/Q/K pushes run through CE without the
intervening fence kernel cutting them apart (PTOP throughput -4ms /
step). End-to-end wall is unchanged in WAN because barrier and push
already sat on the side stream; the win is on energy / GPU work and
opens room for follow-up overlap tuning.

Signed-off-by: Yiyun Lu <[email protected]>
…Kernel

The original SPDX-only headers (`SPDX-FileCopyrightText` + `SPDX-License-Identifier`) did not parse cleanly under the `license_checker` used by the C++ CI stage, which expects the full Apache-2.0 boilerplate that the rest of `cpp/tensorrt_llm/kernels/` already uses. Switch the two ulyssesPermuteScatterKernel files to that format; no code change.

Signed-off-by: Yiyun Lu <[email protected]>
… rename async_pipeline kwarg to async_ulysses for consistency

The user-facing ParallelConfig flag, Attention.__init__, LTX2Attention, and WanBlock all use async_ulysses. UlyssesAttention's constructor and wrap_parallel_attention used async_pipeline (introduced by the Option B refactor in the symm-mem CUDA-IPC commit). The split was theoretically motivated (wrapper-internal vs user-feature) but in practice every caller mapped 1:1, and the divergent name hurt grep + reviewer comprehension. This renames the 5 occurrences for consistency.

Signed-off-by: Yiyun Lu <[email protected]>
@luyiyun1021 luyiyun1021 force-pushed the dev-ltx2-ulysses-async-a2a-pipeline branch from a861c81 to 0b03cb6 Compare June 6, 2026 14:22
@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52515 [ run ] triggered by Bot. Commit: 0b03cb6 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52513 [ run ] completed with state ABORTED. Commit: a861c81

Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52515 [ run ] completed with state FAILURE. Commit: 0b03cb6
/LLM/main/L0_MergeRequest_PR pipeline #41805 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52540 [ run ] triggered by Bot. Commit: 0b03cb6 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52540 [ run ] completed with state SUCCESS. Commit: 0b03cb6
/LLM/main/L0_MergeRequest_PR pipeline #41825 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@luyiyun1021
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52576 [ run ] triggered by Bot. Commit: 0b03cb6 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #52576 [ run ] completed with state SUCCESS. Commit: 0b03cb6
/LLM/main/L0_MergeRequest_PR pipeline #41858 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants