Unsuccessful cross-attention weight loading in Custom Diffusion #7261

@Rbrq03

Describe the bug

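If PEFT is installed in the environment, Custom Diffusion does not load the cross-attention parameters, which leads to poor generation results. In the code below, the self.set_attn_processor(attn_processors) call for Custom Diffusion sits inside the "if not USE_PEFT_BACKEND:" branch, so it appears to be skipped whenever the PEFT backend is active. Given how long this issue takes to troubleshoot, the documentation should at least state that the current implementation is incompatible with PEFT. The code that causes the problem is: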

        if not USE_PEFT_BACKEND:
            if _pipeline is not None:
                for _, component in _pipeline.components.items():
                    if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)

                        logger.info(
                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                        )
                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

            # only custom diffusion needs to set attn processors
            if is_custom_diffusion:
                self.set_attn_processor(attn_processors)

            # set lora layers
            for target_module, lora_layer in lora_layers_list:
                target_module.set_lora_layer(lora_layer)

            self.to(dtype=self.dtype, device=self.device)

            # Offload back.
            if is_model_cpu_offload:
                _pipeline.enable_model_cpu_offload()
            elif is_sequential_cpu_offload:
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />
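
To make the control-flow problem concrete, here is a minimal toy sketch. It does not use diffusers at all: ToyUNet, load_guarded, and load_unguarded are hypothetical names invented for illustration, and the "unguarded" variant is only one possible restructuring, assuming that setting the Custom Diffusion attention processors does not depend on the LoRA backend.

```python
# Toy stand-in for the UNet. All names here are hypothetical and only mirror
# the shape of the excerpt above; this is not the diffusers implementation.
class ToyUNet:
    def __init__(self):
        # Pretend the model starts with its default attention processor.
        self.attn_processors = {"attn2": "default"}

    def set_attn_processor(self, processors):
        self.attn_processors = dict(processors)

    def load_guarded(self, processors, is_custom_diffusion, use_peft_backend):
        # Same structure as the excerpt: the custom diffusion branch is
        # nested inside the "not PEFT" guard.
        if not use_peft_backend:
            if is_custom_diffusion:
                self.set_attn_processor(processors)

    def load_unguarded(self, processors, is_custom_diffusion, use_peft_backend):
        # Possible restructuring (an assumption, not the library's code):
        # handle custom diffusion regardless of the backend flag.
        if is_custom_diffusion:
            self.set_attn_processor(processors)
        if not use_peft_backend:
            pass  # legacy LoRA layer handling would stay behind the guard


trained = {"attn2": "custom_diffusion"}

broken = ToyUNet()
broken.load_guarded(trained, is_custom_diffusion=True, use_peft_backend=True)

fixed = ToyUNet()
fixed.load_unguarded(trained, is_custom_diffusion=True, use_peft_backend=True)

print("guarded:  ", broken.attn_processors)
print("unguarded:", fixed.attn_processors)
```

With the backend flag set, the guarded variant leaves the default processor in place, which matches the symptom described in this report; the unguarded variant applies the trained processor.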

Reproduction

pip install peft 

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="./ckpt/cat"
export INSTANCE_DIR="./data/cat"

accelerate launch train_custom_diffusion.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="photo of a <new1> cat"  \
  --resolution=512  \
  --train_batch_size=2  \
  --learning_rate=1e-5  \
  --lr_warmup_steps=0 \
  --max_train_steps=1000 \
  --scale_lr \
  --hflip  \
  --modifier_token "<new1>" \
  --no_safe_serialization \
  --validation_steps=50 \
  --validation_prompt="<new1> cat sitting in a bucket" \
  --report_to="wandb" 

After training, load the saved weights and generate an image:


import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16,
).to("cuda")
pipeline.unet.load_attn_procs("./ckpt/cat", weight_name="pytorch_custom_diffusion_weights.bin")
pipeline.load_textual_inversion("./ckpt/cat", weight_name="<new1>.bin")

image = pipeline(
    "<new1> cat sitting in a bucket",
    num_inference_steps=100,
    guidance_scale=6.0,
    eta=1.0,
).images[0]
image.save("cat.png")
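
A quick way to confirm whether the weights were actually applied is to look at which attention processor classes the UNet ends up with after load_attn_procs. The helper below is a small sketch with hypothetical function names; it only assumes that the Custom Diffusion processor classes have names starting with "CustomDiffusion" and that pipeline.unet.attn_processors returns a dict of processor objects.

```python
from collections import Counter


def summarize_attn_processors(processors):
    """Count attention processors by class name.

    `processors` is a dict mapping layer names to processor objects, such as
    the one exposed by `pipeline.unet.attn_processors`.
    """
    return Counter(type(proc).__name__ for proc in processors.values())


def custom_diffusion_loaded(processors):
    """Return True if any processor class name starts with 'CustomDiffusion'."""
    summary = summarize_attn_processors(processors)
    return any(name.startswith("CustomDiffusion") for name in summary)


# Intended usage with the pipeline from the script above (not executed here):
#   print(summarize_attn_processors(pipeline.unet.attn_processors))
#   print(custom_diffusion_loaded(pipeline.unet.attn_processors))
```

If the second call returns False right after load_attn_procs, the cross-attention weights were most likely not set on the model.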

Logs

No response

System Info

torch 2.2.1
diffusers 0.27.0.dev0
peft 0.9.0
transformers 4.38.2
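
For anyone comparing environments, the version list above can be collected with a short standard-library script. This is only a sketch; installed_versions is a hypothetical helper name. A missing peft entry means the PEFT code path cannot be active. To my understanding the flag itself can also be read from diffusers.utils.USE_PEFT_BACKEND.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(["torch", "diffusers", "peft", "transformers"])
for name, version in report.items():
    print(f"{name} {version if version is not None else 'not installed'}")
```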

Who can help?

@sayakpaul @yiyixuxu @DN6

Labels

bug, stale
