
Video as prompt#1034

Merged
Artiprocher merged 7 commits into main from video_as_prompt
Nov 4, 2025

Conversation

@Artiprocher
Collaborator

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Artiprocher, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the video generation capabilities by integrating the "Video as Prompt" (VAP) feature. It introduces a new model, ByteDance/Video-As-Prompt-Wan2.1-14B, which allows for generating videos where the motion and style are guided by a reference video and accompanying textual prompts. The changes involve extending the core model architecture, updating the video generation pipeline to handle VAP inputs, and providing comprehensive examples to facilitate its use for both inference and training.

Highlights

  • New Model Integration: Added support for the ByteDance/Video-As-Prompt-Wan2.1-14B model, which enables motion generation conditioned on reference videos.
  • Video-as-Prompt (VAP) Functionality: Introduced core VAP capabilities, allowing users to provide a reference video and associated prompts to guide the motion and content of generated videos.
  • Pipeline Enhancements: Integrated the VAP model into the WanVideoPipeline with a new WanVideoUnit_VAP for processing VAP inputs (prompts, CLIP embeddings, latents) and modifying the diffusion process.
  • Model Architecture Extension: Implemented MotWanModel and MotWanAttentionBlock to handle motion-specific self-attention and cross-attention within the DiT architecture, including specialized positional embeddings.
  • Comprehensive Examples: Provided new inference, full training, LoRA training, and validation scripts for the Video-As-Prompt-Wan2.1-14B model, demonstrating its usage.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature               Command                Description
Code Review           /gemini review         Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary        Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist    Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                  /gemini help           Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.

@Artiprocher Artiprocher merged commit 8332ece into main Nov 4, 2025
@Artiprocher Artiprocher deleted the video_as_prompt branch November 4, 2025 09:33
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for "Video as Prompt" using the ByteDance/Video-As-Prompt-Wan2.1-14B model. It adds a new model MotWanModel to handle motion transfer from a reference video, integrates it into the WanVideoPipeline, and provides corresponding examples for inference and training. The changes are extensive and well-structured. My review focuses on improving code maintainability by addressing code duplication, redundant logic, and adherence to Python style conventions.

Comment on lines +181 to +278
def from_diffusers(self, state_dict):
    rename_dict = {
        "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
        "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
        "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
        "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
        "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
        "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
        "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
        "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
        "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
        "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
        "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
        "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
        "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
        "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
        "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
        "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
        "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
        "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
        "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
        "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
        "blocks.0.attn2.add_k_proj.bias": "blocks.0.cross_attn.k_img.bias",
        "blocks.0.attn2.add_k_proj.weight": "blocks.0.cross_attn.k_img.weight",
        "blocks.0.attn2.add_v_proj.bias": "blocks.0.cross_attn.v_img.bias",
        "blocks.0.attn2.add_v_proj.weight": "blocks.0.cross_attn.v_img.weight",
        "blocks.0.attn2.norm_added_k.weight": "blocks.0.cross_attn.norm_k_img.weight",
        "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
        "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
        "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
        "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
        "blocks.0.norm2.bias": "blocks.0.norm3.bias",
        "blocks.0.norm2.weight": "blocks.0.norm3.weight",
        "blocks.0.scale_shift_table": "blocks.0.modulation",
        "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
        "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
        "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
        "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
        "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
        "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
        "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
        "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
        "condition_embedder.time_proj.bias": "time_projection.1.bias",
        "condition_embedder.time_proj.weight": "time_projection.1.weight",
        "condition_embedder.image_embedder.ff.net.0.proj.bias": "img_emb.proj.1.bias",
        "condition_embedder.image_embedder.ff.net.0.proj.weight": "img_emb.proj.1.weight",
        "condition_embedder.image_embedder.ff.net.2.bias": "img_emb.proj.3.bias",
        "condition_embedder.image_embedder.ff.net.2.weight": "img_emb.proj.3.weight",
        "condition_embedder.image_embedder.norm1.bias": "img_emb.proj.0.bias",
        "condition_embedder.image_embedder.norm1.weight": "img_emb.proj.0.weight",
        "condition_embedder.image_embedder.norm2.bias": "img_emb.proj.4.bias",
        "condition_embedder.image_embedder.norm2.weight": "img_emb.proj.4.weight",
        "patch_embedding.bias": "patch_embedding.bias",
        "patch_embedding.weight": "patch_embedding.weight",
        "scale_shift_table": "head.modulation",
        "proj_out.bias": "head.head.bias",
        "proj_out.weight": "head.head.weight",
    }
    state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name}
    if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
        mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
    else:
        mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
    mot_layers_mapping = {i: n for n, i in enumerate(mot_layers)}

    state_dict_ = {}
    for name, param in state_dict.items():
        name = name.replace("_mot_ref", "")
        if name in rename_dict:
            state_dict_[rename_dict[name]] = param
        else:
            if name.split(".")[1].isdigit():
                block_id = int(name.split(".")[1])
                name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
                name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
                if name_ in rename_dict:
                    name_ = rename_dict[name_]
                    name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
                    state_dict_[name_] = param

    if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0':  # vap 14B
        config = {
            "mot_layers": (0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 36,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "num_heads": 40,
            "eps": 1e-6
        }
    else:
        config = {}
    return state_dict_, config
Contributor

medium

The rename_dict in this method is very similar to the one in WanModelStateDictConverter in diffsynth/models/wan_video_dit.py. This significant code duplication can make maintenance harder. Consider refactoring this into a shared utility function or class to improve code reuse and maintainability.
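One way the suggested refactor could look is sketched below. The helper name build_block_rename_dict and its placement are hypothetical, not part of the PR; it covers only the attention/FFN/norm keys that the two converters share, and each converter would merge in its own model-specific entries.

```python
def build_block_rename_dict(prefix: str = "blocks.0") -> dict:
    """Map diffusers-style parameter names to DiffSynth-style names
    for a single transformer block (shared attn/ffn/norm subset)."""
    # attn1 is self-attention, attn2 is cross-attention in diffusers naming.
    attn_pairs = {"attn1": "self_attn", "attn2": "cross_attn"}
    rename = {}
    for src, dst in attn_pairs.items():
        rename.update({
            f"{prefix}.{src}.norm_q.weight": f"{prefix}.{dst}.norm_q.weight",
            f"{prefix}.{src}.norm_k.weight": f"{prefix}.{dst}.norm_k.weight",
            f"{prefix}.{src}.to_q.weight": f"{prefix}.{dst}.q.weight",
            f"{prefix}.{src}.to_q.bias": f"{prefix}.{dst}.q.bias",
            f"{prefix}.{src}.to_k.weight": f"{prefix}.{dst}.k.weight",
            f"{prefix}.{src}.to_k.bias": f"{prefix}.{dst}.k.bias",
            f"{prefix}.{src}.to_v.weight": f"{prefix}.{dst}.v.weight",
            f"{prefix}.{src}.to_v.bias": f"{prefix}.{dst}.v.bias",
            f"{prefix}.{src}.to_out.0.weight": f"{prefix}.{dst}.o.weight",
            f"{prefix}.{src}.to_out.0.bias": f"{prefix}.{dst}.o.bias",
        })
    # FFN, norm, and modulation keys shared by both converters.
    rename.update({
        f"{prefix}.ffn.net.0.proj.weight": f"{prefix}.ffn.0.weight",
        f"{prefix}.ffn.net.0.proj.bias": f"{prefix}.ffn.0.bias",
        f"{prefix}.ffn.net.2.weight": f"{prefix}.ffn.2.weight",
        f"{prefix}.ffn.net.2.bias": f"{prefix}.ffn.2.bias",
        f"{prefix}.norm2.weight": f"{prefix}.norm3.weight",
        f"{prefix}.norm2.bias": f"{prefix}.norm3.bias",
        f"{prefix}.scale_shift_table": f"{prefix}.modulation",
    })
    return rename
```

Both WanModelStateDictConverter and the VAP converter could then call this helper instead of maintaining two near-identical literal dictionaries.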

Comment on lines +241 to +244
if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
else:
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
Contributor

medium

This if/else block assigns the same value to mot_layers in both branches. It can be simplified to a single assignment.

        mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)

return state_dict_, config



No newline at end of file
Contributor

medium

The file is missing a final newline. It's a good practice to end files with a newline character.

pipe.load_models_to_device(["vae", "image_encoder"])
vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image")

num_frames, height, width, mot_num = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num",1)
Contributor

medium

The variable mot_num is assigned but never used. It should be removed to improve code clarity.

Suggested change
num_frames, height, width, mot_num = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num",1)
num_frames, height, width = inputs_shared.get("num_frames"),inputs_shared.get("height"), inputs_shared.get("width")

Comment on lines 1567 to 1601
for block_id, block in enumerate(dit.blocks):
    # Block
    if vap is not None and block_id in vap.mot_layers_mapping:
        if use_gradient_checkpointing_offload:
            with torch.autograd.graph.save_on_cpu():
                x, x_vap = torch.utils.checkpoint.checkpoint(
                    create_custom_forward_vap(block, vap),
                    x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
                    use_reentrant=False,
                )
        elif use_gradient_checkpointing:
            x, x_vap = torch.utils.checkpoint.checkpoint(
                create_custom_forward_vap(block, vap),
                x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
                use_reentrant=False,
            )
        else:
            x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
    else:
        if use_gradient_checkpointing_offload:
            with torch.autograd.graph.save_on_cpu():
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
        elif use_gradient_checkpointing:
            x = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                x, context, t_mod, freqs,
                use_reentrant=False,
            )
        else:
            x = block(x, context, t_mod, freqs)

Contributor

medium

The logic for handling gradient checkpointing is duplicated for both the VAP and non-VAP paths. This can be refactored to reduce code duplication and improve readability. You could determine the function to execute (vap or block) and its arguments first, then apply the checkpointing logic once.

        for block_id, block in enumerate(dit.blocks):
            # Block
            use_vap = vap is not None and block_id in vap.mot_layers_mapping

            if use_vap:
                forward_fn = create_custom_forward_vap(block, vap)
                forward_args = (x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
            else:
                forward_fn = create_custom_forward(block)
                forward_args = (x, context, t_mod, freqs)

            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    output = torch.utils.checkpoint.checkpoint(
                        forward_fn, *forward_args, use_reentrant=False
                    )
            elif use_gradient_checkpointing:
                output = torch.utils.checkpoint.checkpoint(
                    forward_fn, *forward_args, use_reentrant=False
                )
            else:
                output = forward_fn(*forward_args)

            if use_vap:
                x, x_vap = output
            else:
                x = output

Comment on lines +9 to +24
def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]:
    if len(video_frames) == 0:
        return []
    if mode == "first":
        return video_frames[:num]
    if mode == "evenly":
        import torch as _torch
        idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist()
        return [video_frames[i] for i in idx]
    if mode == "random":
        if len(video_frames) <= num:
            return video_frames
        import random as _random
        start = _random.randint(0, len(video_frames) - num)
        return video_frames[start:start+num]
    return video_frames
Contributor

medium

Imports should be at the top of the file, not inside functions, to adhere to PEP 8. Please move import torch as _torch and import random as _random to the top of the file. This improves readability and avoids repeated imports.

def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]:
    if len(video_frames) == 0:
        return []
    if mode == "first":
        return video_frames[:num]
    if mode == "evenly":
        idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist()
        return [video_frames[i] for i in idx]
    if mode == "random":
        if len(video_frames) <= num:
            return video_frames
        start = _random.randint(0, len(video_frames) - num)
        return video_frames[start:start+num]
    return video_frames

yxbian23 added a commit to bytedance/Video-As-Prompt that referenced this pull request Jan 31, 2026
- Add DiffSynth-Studio integration news ([PR #1034](modelscope/DiffSynth-Studio#1034))
- Credit @Artiprocher in acknowledgements
LePao1 pushed a commit to LePao1/DiffSynth-Studio that referenced this pull request Feb 22, 2026