Conversation
Removed dataset snapshot download function call.
Summary of Changes

Hello @Artiprocher, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the video generation capabilities by integrating the "Video as Prompt" (VAP) feature. It introduces a new model, ByteDance/Video-As-Prompt-Wan2.1-14B, which allows generating videos whose motion and style are guided by a reference video and accompanying textual prompts. The changes extend the core model architecture, update the video generation pipeline to handle VAP inputs, and provide comprehensive examples for both inference and training.
Code Review
This pull request introduces support for "Video as Prompt" using the ByteDance/Video-As-Prompt-Wan2.1-14B model. It adds a new model MotWanModel to handle motion transfer from a reference video, integrates it into the WanVideoPipeline, and provides corresponding examples for inference and training. The changes are extensive and well-structured. My review focuses on improving code maintainability by addressing code duplication, redundant logic, and adherence to Python style conventions.
```python
def from_diffusers(self, state_dict):
    rename_dict = {
        "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
        "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
        "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
        "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
        "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
        "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
        "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
        "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
        "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
        "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
        "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
        "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
        "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
        "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
        "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
        "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
        "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
        "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
        "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
        "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
        "blocks.0.attn2.add_k_proj.bias": "blocks.0.cross_attn.k_img.bias",
        "blocks.0.attn2.add_k_proj.weight": "blocks.0.cross_attn.k_img.weight",
        "blocks.0.attn2.add_v_proj.bias": "blocks.0.cross_attn.v_img.bias",
        "blocks.0.attn2.add_v_proj.weight": "blocks.0.cross_attn.v_img.weight",
        "blocks.0.attn2.norm_added_k.weight": "blocks.0.cross_attn.norm_k_img.weight",
        "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
        "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
        "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
        "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
        "blocks.0.norm2.bias": "blocks.0.norm3.bias",
        "blocks.0.norm2.weight": "blocks.0.norm3.weight",
        "blocks.0.scale_shift_table": "blocks.0.modulation",
        "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
        "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
        "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
        "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
        "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
        "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
        "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
        "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
        "condition_embedder.time_proj.bias": "time_projection.1.bias",
        "condition_embedder.time_proj.weight": "time_projection.1.weight",
        "condition_embedder.image_embedder.ff.net.0.proj.bias": "img_emb.proj.1.bias",
        "condition_embedder.image_embedder.ff.net.0.proj.weight": "img_emb.proj.1.weight",
        "condition_embedder.image_embedder.ff.net.2.bias": "img_emb.proj.3.bias",
        "condition_embedder.image_embedder.ff.net.2.weight": "img_emb.proj.3.weight",
        "condition_embedder.image_embedder.norm1.bias": "img_emb.proj.0.bias",
        "condition_embedder.image_embedder.norm1.weight": "img_emb.proj.0.weight",
        "condition_embedder.image_embedder.norm2.bias": "img_emb.proj.4.bias",
        "condition_embedder.image_embedder.norm2.weight": "img_emb.proj.4.weight",
        "patch_embedding.bias": "patch_embedding.bias",
        "patch_embedding.weight": "patch_embedding.weight",
        "scale_shift_table": "head.modulation",
        "proj_out.bias": "head.head.bias",
        "proj_out.weight": "head.head.weight",
    }
    state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name}
    if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
        mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
    else:
        mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
    mot_layers_mapping = {i: n for n, i in enumerate(mot_layers)}

    state_dict_ = {}
    for name, param in state_dict.items():
        name = name.replace("_mot_ref", "")
        if name in rename_dict:
            state_dict_[rename_dict[name]] = param
        else:
            if name.split(".")[1].isdigit():
                block_id = int(name.split(".")[1])
                name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
                name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
                if name_ in rename_dict:
                    name_ = rename_dict[name_]
                    name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
                    state_dict_[name_] = param

    if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0':  # vap 14B
        config = {
            "mot_layers": (0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
            "has_image_input": True,
            "patch_size": [1, 2, 2],
            "in_dim": 36,
            "dim": 5120,
            "ffn_dim": 13824,
            "freq_dim": 256,
            "text_dim": 4096,
            "num_heads": 40,
            "eps": 1e-6,
        }
    else:
        config = {}
    return state_dict_, config
```
Both branches of this conditional assign the same `mot_layers` tuple, so the hash check is currently redundant. If other checkpoint layouts are expected in the future, the `else` branch should carry a different tuple; otherwise the conditional can be removed.

```python
if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
    mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
else:
    mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
```
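For context, `hash_state_dict_keys` fingerprints a checkpoint by its tensor names so the loader can recognize model variants without inspecting weights. A rough standalone sketch of that idea (assumed behavior; the actual DiffSynth-Studio helper may compute its hash differently):

```python
import hashlib

def hash_keys_sketch(state_dict: dict) -> str:
    # Hash only the sorted key names, not tensor values: checkpoints with
    # the same architecture yield the same fingerprint regardless of weights.
    joined = ",".join(sorted(state_dict))
    return hashlib.md5(joined.encode("utf-8")).hexdigest()
```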
```python
pipe.load_models_to_device(["vae", "image_encoder"])
vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image")

num_frames, height, width, mot_num = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num", 1)
```
The variable mot_num is assigned but never used. It should be removed to improve code clarity.
```diff
- num_frames, height, width, mot_num = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num", 1)
+ num_frames, height, width = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width")
```
```python
for block_id, block in enumerate(dit.blocks):
    # Block
    if vap is not None and block_id in vap.mot_layers_mapping:
        if use_gradient_checkpointing_offload:
            with torch.autograd.graph.save_on_cpu():
                x, x_vap = torch.utils.checkpoint.checkpoint(
                    create_custom_forward_vap(block, vap),
                    x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
                    use_reentrant=False,
                )
        elif use_gradient_checkpointing:
            x, x_vap = torch.utils.checkpoint.checkpoint(
                create_custom_forward_vap(block, vap),
                x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
                use_reentrant=False,
            )
        else:
            x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
    else:
        if use_gradient_checkpointing_offload:
            with torch.autograd.graph.save_on_cpu():
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
        elif use_gradient_checkpointing:
            x = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                x, context, t_mod, freqs,
                use_reentrant=False,
            )
        else:
            x = block(x, context, t_mod, freqs)
```
The logic for handling gradient checkpointing is duplicated for both the VAP and non-VAP paths. This can be refactored to reduce code duplication and improve readability. You could determine the function to execute (vap or block) and its arguments first, then apply the checkpointing logic once.
```python
for block_id, block in enumerate(dit.blocks):
    # Block
    use_vap = vap is not None and block_id in vap.mot_layers_mapping
    if use_vap:
        forward_fn = create_custom_forward_vap(block, vap)
        forward_args = (x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
    else:
        forward_fn = create_custom_forward(block)
        forward_args = (x, context, t_mod, freqs)
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            output = torch.utils.checkpoint.checkpoint(
                forward_fn, *forward_args, use_reentrant=False
            )
    elif use_gradient_checkpointing:
        output = torch.utils.checkpoint.checkpoint(
            forward_fn, *forward_args, use_reentrant=False
        )
    else:
        output = forward_fn(*forward_args)
    if use_vap:
        x, x_vap = output
    else:
        x = output
```

```python
def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]:
    if len(video_frames) == 0:
        return []
    if mode == "first":
        return video_frames[:num]
    if mode == "evenly":
        import torch as _torch
        idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist()
        return [video_frames[i] for i in idx]
    if mode == "random":
        if len(video_frames) <= num:
            return video_frames
        import random as _random
        start = _random.randint(0, len(video_frames) - num)
        return video_frames[start:start+num]
    return video_frames
```
Imports should be at the top of the file, not inside functions, to adhere to PEP 8. Please move import torch as _torch and import random as _random to the top of the file. This improves readability and avoids repeated imports.
```python
def select_frames(video_frames: List[PIL.Image.Image], num: int, mode: str) -> List[PIL.Image.Image]:
    if len(video_frames) == 0:
        return []
    if mode == "first":
        return video_frames[:num]
    if mode == "evenly":
        idx = _torch.linspace(0, len(video_frames) - 1, num).long().tolist()
        return [video_frames[i] for i in idx]
    if mode == "random":
        if len(video_frames) <= num:
            return video_frames
        start = _random.randint(0, len(video_frames) - num)
        return video_frames[start:start+num]
    return video_frames
```

- Add DiffSynth-Studio integration news ([PR #1034](modelscope/DiffSynth-Studio#1034))
- Credit @Artiprocher in acknowledgements
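The `evenly` mode only uses `torch.linspace(...).long()` to compute integer indices. A dependency-free sketch of that index computation (hypothetical helper; written to mirror torch's truncating `.long()` cast, though float rounding could differ at the margins):

```python
def evenly_spaced_indices(n_frames: int, num: int) -> list:
    # Mirror torch.linspace(0, n_frames - 1, num).long(): `num` evenly
    # spaced floats over [0, n_frames - 1], truncated toward zero.
    if num <= 1:
        return [0] if n_frames > 0 else []
    step = (n_frames - 1) / (num - 1)
    return [int(i * step) for i in range(num)]
```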
Video as prompt
No description provided.