
“unet.enable_xformers_memory_efficient_attention()” doesn't work in train_text_to_image_lora.py #2459

@ZihaoW123

Description


Describe the bug

unet.enable_xformers_memory_efficient_attention() should be called after unet.set_attn_processor(lora_attn_procs); otherwise the LoRA attention processors will not use xFormers memory-efficient attention.
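For reference, this is roughly the ordering in the current script (a simplified sketch; the argument handling and version check are omitted). Since set_attn_processor replaces every attention processor on the UNet, the xFormers conversion done by the earlier call is discarded:

    unet.enable_xformers_memory_efficient_attention()  # converts the existing processors to xFormers variants
    ...
    unet.set_attn_processor(lora_attn_procs)  # installs plain LoRA processors, so the xFormers setting is lost
    lora_layers = AttnProcsLayers(unet.attn_processors)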

With the order flipped as below, enable_xformers_memory_efficient_attention is applied correctly to the LoRA layers:

    # Reordered excerpt from train_text_to_image_lora.py: the LoRA processors are
    # installed first, and xFormers is enabled afterwards so that the call converts
    # them to their memory-efficient variants.
    # (is_xformers_available, version, logger, and lora_attn_procs are defined
    # earlier in the script.)
    unet.set_attn_processor(lora_attn_procs)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    lora_layers = AttnProcsLayers(unet.attn_processors)
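As a quick sanity check (not part of the script; class names assume diffusers 0.14.x), the installed processors can be inspected after both calls. With the corrected ordering they should all be the xFormers LoRA variant:

    # Hypothetical check: list the processor classes actually installed on the UNet.
    for name, proc in unet.attn_processors.items():
        print(name, type(proc).__name__)
    # With the corrected ordering this prints LoRAXFormersCrossAttnProcessor for every
    # attention block; with the original ordering it prints LoRACrossAttnProcessor,
    # i.e. LoRA attention without memory-efficient attention.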

Reproduction

None

Logs

No response

System Info

  • diffusers version: 0.14.0.dev0
  • Platform: Linux-5.4.56.bsk.10-amd64-x86_64-with-glibc2.31
  • Python version: 3.9.2
  • PyTorch version (GPU?): 1.13.1 (True)
  • Huggingface_hub version: 0.12.1
  • Transformers version: 4.26.1
  • Accelerate version: 0.16.0
  • xFormers version: 0.0.17.dev461
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:
