Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
08a1828
add ip-adapter
okotaku Sep 8, 2023
c4646f8
modularize.
sayakpaul Nov 2, 2023
f3755d4
add to inits.
sayakpaul Nov 2, 2023
5887af0
fix
sayakpaul Nov 2, 2023
f9aaa54
fix
sayakpaul Nov 2, 2023
a45292b
fix
sayakpaul Nov 2, 2023
023c2b7
fix
sayakpaul Nov 2, 2023
8fe3064
fix
sayakpaul Nov 2, 2023
dded7c4
fix
sayakpaul Nov 2, 2023
651302b
fix
sayakpaul Nov 2, 2023
f10eb25
fix
sayakpaul Nov 2, 2023
f051c9e
fix
sayakpaul Nov 2, 2023
6031383
device placement
sayakpaul Nov 2, 2023
95e38ac
device placement
sayakpaul Nov 2, 2023
3d69688
device placement fix.
sayakpaul Nov 2, 2023
cacee6d
Merge branch 'main' into feat/ip_adapter
okotaku Nov 3, 2023
351180f
fix import
okotaku Nov 3, 2023
2e83d6c
composable ip adapter module
sayakpaul Nov 3, 2023
1d64cb8
add image_encoder to sd as optional components
Nov 7, 2023
70fae5c
add image_prompt arg
Nov 7, 2023
3aaaa23
move image_projection to unet, refactor
Nov 8, 2023
2154d01
update comments
Nov 8, 2023
2807ee3
fix
sayakpaul Nov 8, 2023
bc52810
make image_encoder default to None.
sayakpaul Nov 8, 2023
eaf94bb
fully delegate the image encoding logic.
sayakpaul Nov 8, 2023
c22cd90
Merge branch 'main' into feat/ip_adapter
sayakpaul Nov 8, 2023
7cf7f70
debug
sayakpaul Nov 8, 2023
03e2961
fix
sayakpaul Nov 8, 2023
982a557
fix
sayakpaul Nov 8, 2023
6059099
fix:
sayakpaul Nov 8, 2023
c56503b
fix
sayakpaul Nov 8, 2023
59c933a
separate the loacder.
sayakpaul Nov 8, 2023
7ece033
circular import problem
sayakpaul Nov 8, 2023
4cb0432
circular imports.
sayakpaul Nov 8, 2023
17223d4
added_cond_kwargs not needed now.
sayakpaul Nov 8, 2023
8001d24
remove save_ip_adapter.
sayakpaul Nov 8, 2023
7887ba7
remove ip adapter pipeline from the face of the earth
sayakpaul Nov 8, 2023
46c668b
refactor __call__
sayakpaul Nov 8, 2023
6e28231
fix init.
sayakpaul Nov 8, 2023
3241c96
Merge branch 'main' into feat/ip_adapter
sayakpaul Nov 8, 2023
ef937be
remove none
sayakpaul Nov 8, 2023
7043443
image_encoder
sayakpaul Nov 8, 2023
d7e390f
module registration
sayakpaul Nov 8, 2023
86b0e4a
does defaulting to None work for modules?
sayakpaul Nov 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,21 @@
import requests
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn

from . import __version__
from .models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterControlNetAttnProcessor,
IPAdapterControlNetAttnProcessor2_0,
)
Comment on lines +31 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates a circular dependency (regarding https://github.com/huggingface/diffusers/pull/4944/files#r1386535787). Let me know if there's a better way to address this.

from .models.embeddings import ImageProjection
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from .utils import (
DIFFUSERS_CACHE,
Expand Down Expand Up @@ -72,6 +82,9 @@
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

IP_ADAPTER_WEIGHT_NAME = "pytorch_ip_adapter_weights.bin"
IP_ADAPTER_WEIGHT_NAME_SAFE = "pytorch_ip_adapter_weights.safetensors"

LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."


Expand Down Expand Up @@ -3329,3 +3342,165 @@ def _remove_text_encoder_monkey_patch(self):
else:
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)


class IPAdapterMixin:
"""Mixin for handling IP Adapters."""

def set_ip_adapter(self):
unet = self.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).to(dtype=unet.dtype, device=unet.device)

unet.set_attn_processor(attn_procs)

if hasattr(self, "controlnet"):
attn_processor_class = (
IPAdapterControlNetAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else IPAdapterControlNetAttnProcessor
)
self.pipeline.controlnet.set_attn_processor(attn_processor_class())

def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:

- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

self.set_ip_adapter()

# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
# TODO (sayakpaul): incorporate safetensors

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}

if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict

keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

# Handle image projection layers.
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4

image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
)
image_projection.to(dtype=self.unet.dtype, device=self.unet.device)

diffusers_state_dict = {}

diffusers_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)

image_projection.load_state_dict(diffusers_state_dict)
self.image_projection = image_projection.to(device=self.unet.device, dtype=self.unet.dtype)

# Handle IP-Adapter cross-attention layers.
ip_layers = torch.nn.ModuleList(
[
module if isinstance(module, nn.Module) else nn.Identity()
for module in self.unet.attn_processors.values()
]
)
ip_layers.load_state_dict(state_dict["ip_adapter"])

def set_ip_adapter_scale(self, scale):
for attn_processor in self.unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = scale
Loading