[core] reuse AttentionMixin for compatible classes #12463
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
for name, module in self.named_children():
    fn_recursive_attn_processor(name, module, processor)

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
```
Perhaps it's out of scope for this PR, but I see that a lot of models additionally have a `set_default_attn_processor` method, usually marked `# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor`. Do you think it makes sense to add this method to `AttentionMixin`?
IMO, not yet, since `AttentionMixin` is fairly agnostic to the model type, while `set_default_attn_processor` relies on some custom attention processor types. For `UNet2DConditionModel`, we have:
`diffusers/src/diffusers/models/unets/unet_2d_condition.py`, lines 762 to 769 in `fa468c5`:

```python
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
    processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
    processor = AttnProcessor()
else:
    raise ValueError(
        f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
    )
```
However, for `AutoencoderKLTemporalDecoder`:
`diffusers/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py`, lines 269 to 274 in `fa468c5`:

```python
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
    processor = AttnProcessor()
else:
    raise ValueError(
        f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
    )
```
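For what it's worth, one hypothetical way to make it generic would be to let each model declare a mapping from processor families to their default, and keep a single implementation in the mixin. A minimal sketch, assuming an invented `_default_processor_map` class attribute (not existing diffusers API):

```python
# Hypothetical sketch only: _default_processor_map is an invented attribute,
# not part of diffusers. Each model class would override it, e.g. for
# UNet2DConditionModel:
#   _default_processor_map = [
#       (ADDED_KV_ATTENTION_PROCESSORS, AttnAddedKVProcessor),
#       (CROSS_ATTENTION_PROCESSORS, AttnProcessor),
#   ]
class AttentionMixin:
    _default_processor_map = []

    def set_default_attn_processor(self):
        # Pick the first default whose processor family covers every
        # processor currently set on the model.
        for processor_family, default_cls in self._default_processor_map:
            if all(proc.__class__ in processor_family for proc in self.attn_processors.values()):
                self.set_attn_processor(default_cls())
                return
        raise ValueError(
            "Cannot call `set_default_attn_processor` when attention processors are of type "
            f"{next(iter(self.attn_processors.values()))}"
        )
```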
I'd be down for the refactoring, though. Cc: @DN6
dg845 left a comment:
Looks good to me! I think `AuraFlowTransformer2DModel` and `AudioLDM2UNet2DConditionModel` have their `attn_processors`/`set_attn_processor` methods deleted but are missing the corresponding change to inherit from `AttentionMixin`.
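Concretely, the missing piece would be along these lines (the base-class list here is illustrative, not copied from the source):

```python
# Illustrative only: after the copied-over attn_processors /
# set_attn_processor implementations are deleted, the model must pick
# them up by adding AttentionMixin to its bases.
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    ...
```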
Thanks for those catches, @dg845. They should be fixed now.
LGTM :)
@DN6 okay to go?
@DN6 a gentle ping.
What does this PR do?
Many models use `# Copied from ...` implementations of `attn_processors` and `set_attn_processor`. They are basically the same as what we have implemented in `diffusers/src/diffusers/models/attention.py` (line 39 in `693d8a3`).

This PR makes those models inherit from `AttentionMixin` and removes the copied-over implementations.

I decided to leave `fuse_qkv_projections` and `unfuse_qkv_projections` out of this PR because some models don't have attention processors implemented in a way that would make this seamless. But the removals made in this PR should be entirely harmless.
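For context, the shared implementation being reused is roughly the following pattern (a simplified sketch from memory, not the verbatim `diffusers` source):

```python
# Simplified sketch of the AttentionMixin pattern this PR reuses
# (approximate, not the verbatim diffusers implementation).
class AttentionMixin:
    @property
    def attn_processors(self):
        # Recursively collect every attention processor, keyed by module path.
        processors = {}

        def fn_recursive_add_processors(name, module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor):
        # Accepts either a single processor applied everywhere, or a dict
        # mapping module paths to per-module processors.
        def fn_recursive_attn_processor(name, module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
```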