Describe the bug
If PEFT is installed in your environment, custom_diffusion will not load the cross-attention parameters, which leads to poor generation results. Given how much time it takes to troubleshoot this, the documentation should state that the current implementation is incompatible with peft. The code that causes the problem is:
if not USE_PEFT_BACKEND:
    if _pipeline is not None:
        for _, component in _pipeline.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)

                logger.info(
                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                )
                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

    # only custom diffusion needs to set attn processors
    if is_custom_diffusion:
        self.set_attn_processor(attn_processors)

    # set lora layers
    for target_module, lora_layer in lora_layers_list:
        target_module.set_lora_layer(lora_layer)

    self.to(dtype=self.dtype, device=self.device)

    # Offload back.
    if is_model_cpu_offload:
        _pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        _pipeline.enable_sequential_cpu_offload()
    # Unsafe code />
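
Because custom diffusion relies on the set_attn_processor(attn_processors) call inside this `if not USE_PEFT_BACKEND:` branch, the whole branch is skipped as soon as peft is importable. A quick way to confirm which path your environment takes is to print the flag; here is a minimal sketch, assuming USE_PEFT_BACKEND is exposed from diffusers.utils as in 0.27.0.dev0:

import diffusers
from diffusers.utils import USE_PEFT_BACKEND

# True when a compatible `peft` (and `transformers`) install is detected.
# If this prints True, the branch quoted above is never executed, so the
# custom-diffusion attention processors are not set on the UNet.
print(diffusers.__version__, USE_PEFT_BACKEND)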
Reproduction
pip install peft
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="./ckpt/cat"
export INSTANCE_DIR="./data/cat"
accelerate launch train_custom_diffusion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="photo of a <new1> cat" \
  --resolution=512 \
  --train_batch_size=2 \
  --learning_rate=1e-5 \
  --lr_warmup_steps=0 \
  --max_train_steps=1000 \
  --scale_lr \
  --hflip \
  --modifier_token "<new1>" \
  --no_safe_serialization \
  --validation_steps=50 \
  --validation_prompt="<new1> cat sitting in a bucket" \
  --report_to="wandb"
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16,
).to("cuda")
pipeline.unet.load_attn_procs("./ckpt/cat", weight_name="pytorch_custom_diffusion_weights.bin")
pipeline.load_textual_inversion("./ckpt/cat", weight_name="<new1>.bin")
image = pipeline(
    "<new1> cat sitting in a bucket",
    num_inference_steps=100,
    guidance_scale=6.0,
    eta=1.0,
).images[0]
image.save("cat.png")
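
After calling load_attn_procs, you can sanity-check whether any custom-diffusion processors were actually installed on the UNet. This is a rough check, assuming the relevant processor class names start with CustomDiffusion as they do in current diffusers:

# Count attention processors that were replaced by custom-diffusion variants.
# With peft installed, this stays at 0 in my environment, which matches the poor results.
custom = {
    name: proc
    for name, proc in pipeline.unet.attn_processors.items()
    if type(proc).__name__.startswith("CustomDiffusion")
}
print(f"{len(custom)} / {len(pipeline.unet.attn_processors)} custom-diffusion attention processors loaded")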
Logs
No response
System Info
torch 2.2.1
diffusers 0.27.0.dev0
peft 0.9.0
transformers 4.38.2
Who can help?