[TRTLLM-11408][feat] Add VisualGen TP Support#13614
Conversation
9266296 to
8c39ab3
Compare
📝 WalkthroughWalkthroughThis PR adds comprehensive tensor-parallel (TP) distributed training support to the visual generation model pipeline. It introduces CLI configuration for TP group size, updates distributed AllReduce operations to use local process groups, implements a new multi-backend RMSNorm module with quantization support, updates device mesh topology, and applies TP-aware configurations throughout the attention and transformer layers. ChangesTensor-Parallel Visual Generation Training
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes 🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/visual_gen/pipeline_loader.py (1)
120-130:⚠️ Potential issue | 🟠 Major | ⚡ Quick winInitialize
ws/rkbefore theargsbranch.If
PipelineLoaderis created withoutargson a distributed run, line 126 readswsbefore it is assigned, so_setup_visual_gen_mapping()crashes withUnboundLocalErrorinstead of taking the fallback path.Suggested fix
def _setup_visual_gen_mapping(self, config: DiffusionModelConfig) -> None: + ws = dist.get_world_size() if dist.is_initialized() else 1 + rk = dist.get_rank() if dist.is_initialized() else 0 if self.args is not None: - ws = dist.get_world_size() if dist.is_initialized() else 1 - rk = dist.get_rank() if dist.is_initialized() else 0 vgm = VisualGenMapping( ws, rk,🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/visual_gen/pipeline_loader.py` around lines 120 - 130, _pipeline_loader's _setup_visual_gen_mapping uses local variables ws/rk only set inside the args branch, causing UnboundLocalError when args is None on a distributed run; initialize ws and rk (or derive from dist if initialized) before the args conditional so the fallback path can read them. Update PipelineLoader._setup_visual_gen_mapping to set ws and rk to sensible defaults at top (e.g., ws=1, rk=0) and, if dist.is_initialized(), override from torch.distributed (or compute from args when present) before creating VisualGenMapping and calling init_pg; ensure you reference VisualGenMapping, config.visual_gen_mapping, config.mapping, and init_pg in the fix so those uses no longer read uninitialized ws/rk.tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py (1)
289-299:⚠️ Potential issue | 🟠 Major | ⚡ Quick winThis change disables WAN sequence-parallel self-attention.
WanTransformer3DModelstill shards the sequence for Ulysses/Attention2D, butAttentiononly wraps those backends whenqkv_mode != SEPARATE_QKV. After this switch, each rank attends only to its local chunk, so existingulysses_size > 1and Attention2D runs regress even whentp_size == 1.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py` around lines 289 - 299, The current change constructs Attention with qkv_mode=QKVMode.SEPARATE_QKV which prevents Attention from wrapping sequence-parallel backends (Ulysses/Attention2D), causing each rank to only attend its local chunk when WanTransformer3DModel shards sequence; modify the Attention construction so it still enables/wraps sequence-parallel backends when the model is configured for sequence sharding: either remove the conditional that skips those backends for QKVMode.SEPARATE_QKV in Attention, or add an explicit flag (e.g., enable_sequence_parallel=True) when creating Attention in WanTransformer3DModel so that Attention will use Ulysses/Attention2D whenever ulysses_size>1 (or model_config indicates sequence-parallel) regardless of qkv_mode.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/distributed/ops.py`:
- Around line 780-782: The init assigns self.tp_group and self.pg only inside a
narrow branch, causing forward() to sometimes hit AttributeError;
unconditionally cache both group attributes during construction by always
setting self.tp_group = self.mapping.tp_group and self.pg =
getattr(self.mapping, "tp_group_pg", None) (or assigning mapping.tp_group_pg
without condition) so forward() can safely reference them; retain existing logic
that checks self._disable_mpi and self.mnnvl_allreduce when choosing all-reduce
behavior, and apply the same unconditional assignment fix for the other
constructor section covering the code around the forward-related logic (the
region referenced by 880-918) so both branches always define tp_group and pg.
In `@tensorrt_llm/_torch/visual_gen/config.py`:
- Around line 145-146: The n_workers method currently returns only dit_cfg_size
* dit_ulysses_size * dit_tp_size and omits attention2D/ring (CP) factors; update
n_workers to align with total_parallel_size (which already includes
attn2d_row_size/attn2d_col_size or other CP dimensions) by returning or
delegating to total_parallel_size so supported configs like attn2d_row_size=2,
attn2d_col_size=2 are counted correctly.
In `@tensorrt_llm/_torch/visual_gen/mapping.py`:
- Around line 141-147: The dims tuple currently appends a trailing "cp" even
though self._dim_names already contains "cp", producing duplicate names and
breaking device mesh lookups; change the construction of dims and the
corresponding shape before calling init_device_mesh (used when setting
cls.device_mesh) to avoid adding a duplicate "cp" — e.g., build dims as ("pp",)
+ tuple(self._dim_names) and ensure shape mirrors dims (use (1,) +
tuple(self._dim_sizes[d] for d in self._dim_names) + (1,) only if the trailing
wrapper dim is actually not already in self._dim_names), or conditionally append
the wrapper "cp" and its size only when it is missing from self._dim_names so
init_device_mesh receives matching names and sizes.
---
Outside diff comments:
In `@tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py`:
- Around line 289-299: The current change constructs Attention with
qkv_mode=QKVMode.SEPARATE_QKV which prevents Attention from wrapping
sequence-parallel backends (Ulysses/Attention2D), causing each rank to only
attend its local chunk when WanTransformer3DModel shards sequence; modify the
Attention construction so it still enables/wraps sequence-parallel backends when
the model is configured for sequence sharding: either remove the conditional
that skips those backends for QKVMode.SEPARATE_QKV in Attention, or add an
explicit flag (e.g., enable_sequence_parallel=True) when creating Attention in
WanTransformer3DModel so that Attention will use Ulysses/Attention2D whenever
ulysses_size>1 (or model_config indicates sequence-parallel) regardless of
qkv_mode.
In `@tensorrt_llm/_torch/visual_gen/pipeline_loader.py`:
- Around line 120-130: _pipeline_loader's _setup_visual_gen_mapping uses local
variables ws/rk only set inside the args branch, causing UnboundLocalError when
args is None on a distributed run; initialize ws and rk (or derive from dist if
initialized) before the args conditional so the fallback path can read them.
Update PipelineLoader._setup_visual_gen_mapping to set ws and rk to sensible
defaults at top (e.g., ws=1, rk=0) and, if dist.is_initialized(), override from
torch.distributed (or compute from args when present) before creating
VisualGenMapping and calling init_pg; ensure you reference VisualGenMapping,
config.visual_gen_mapping, config.mapping, and init_pg in the fix so those uses
no longer read uninitialized ws/rk.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 516f2b2f-6ee2-4085-a514-1954b64c0475
📒 Files selected for processing (8)
examples/visual_gen/visual_gen_wan_t2v.pytensorrt_llm/_torch/distributed/ops.pytensorrt_llm/_torch/visual_gen/config.pytensorrt_llm/_torch/visual_gen/mapping.pytensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.pytensorrt_llm/_torch/visual_gen/modules/attention.pytensorrt_llm/_torch/visual_gen/modules/rms_norm.pytensorrt_llm/_torch/visual_gen/pipeline_loader.py
99e0768 to
657ab61
Compare
c0a414a to
557f483
Compare
7b747ef to
447fac6
Compare
|
@karljang / @yibinl-nvidia could you guys please review from FLUX / LTX perspective? cc @chang-l for review |
karljang
left a comment
There was a problem hiding this comment.
Thank you for the PR.
My AI suggests that we should test the TP=1 case. Could you please take a look?
|
PR_Github #50869 [ run ] completed with state |
|
PR_Github #50884 [ kill ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #50886 [ run ] triggered by Bot. Commit: |
|
PR_Github #50886 [ run ] completed with state
|
Signed-off-by: belgarten-nv <[email protected]>
|
/bot run --disable-fail-fast |
|
PR_Github #50985 [ run ] triggered by Bot. Commit: |
|
PR_Github #50985 [ run ] completed with state
|
|
PR_Github #51061 [ kill ] triggered by Bot. Commit: |
|
PR_Github #51061 [ kill ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #51062 [ run ] triggered by Bot. Commit: |
|
PR_Github #51062 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #51120 [ run ] triggered by Bot. Commit: |
|
PR_Github #51120 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #51130 [ run ] triggered by Bot. Commit: |
|
PR_Github #51130 [ run ] completed with state
|
|
/bot skip --comment "2 failing tests from AutoDeploy. Both spurious failures that are unrelated to code changes. Will submit separate MR to waive them." |
|
PR_Github #51207 [ skip ] triggered by Bot. Commit: |
|
PR_Github #51207 [ skip ] completed with state |
|
/bot skip --comment "2 failing tests from AutoDeploy. Both spurious failures that are unrelated to code changes. Will submit separate MR to waive them." |
|
PR_Github #51210 [ skip ] triggered by Bot. Commit: |
|
PR_Github #51210 [ skip ] completed with state |
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Refactor
Description
torch.compilearounddist.AllReduceby factoring out device mesh slicing to constructor.--tp_sizeflag to VisualGen example scripts.Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.