
Conversation

@ayushtues
Contributor

@ayushtues ayushtues commented Jul 31, 2023

This PR implements BLIP Diffusion as discussed in #4274

Notion for tracking progress/brainstorming: link

Model/Pipeline Description

BLIP diffusion (Salesforce): https://dxli94.github.io/BLIP-Diffusion-website/

BLIP-Diffusion enables zero-shot subject-driven image generation, which is probably its strongest selling point.

Code with pre-trained weights: https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion
Paper: https://arxiv.org/abs/2305.14720

Abstract:

Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task, called prompted context generation, which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.

From the official website mentioned above:
image

TODO

  • Implement BLIPDiffusionPipeline
  • Script to convert pretrained weights into diffusers checkpoints
  • Add model cards for checkpoints and move checkpoints to appropriate repositories
  • Write tests
  • Add docstrings for new classes
  • Create documentation
  • Add usage example(s)

HF Model Link: https://huggingface.co/ayushtues/blipdiffusion/tree/main

Usage Examples

Zero-Shot Subject-Driven Generation

from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.utils import load_image


blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained("ayushtues/blipdiffusion")
blip_diffusion_pipe.to("cuda")

cond_subject = ["dog"]
tgt_subject = ["dog"]
text_prompt_input = ["swimming underwater"]


cond_image = load_image("https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg")
num_output = 1

iter_seed = 88888
guidance_scale = 7.5
num_inference_steps = 50
negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"

for i in range(num_output):
    output = blip_diffusion_pipe(
        text_prompt_input,
        cond_image,
        cond_subject,
        tgt_subject,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        neg_prompt=negative_prompt,
        height=512,
        width=512,
    )
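
The loop above discards output. Assuming the pipeline returns the standard diffusers ImagePipelineOutput with a list of PIL images under .images (the revised examples later in this thread rely on exactly that), the results could be saved like so; the filename is only illustrative:

# Hedged sketch: persist the generated image(s) from the last call above.
for idx, image in enumerate(output.images):
    image.save(f"dog_underwater_{idx}.png")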

Input
image
Output
image

Controlled Subject-Driven Generation (Canny Edge)

from diffusers.pipelines import BlipDiffusionControlNetPipeline
from diffusers.utils import load_image
from controlnet_aux import CannyDetector

blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained("ayushtues/blipdiffusion-controlnet")
blip_diffusion_pipe.to("cuda")

style_subject = ["flower"]  # subject that defines the style
tgt_subject = ["teapot"]  # subject to generate
text_prompt = ["on a marble table"]
cldm_cond_image = load_image("https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg").resize((512, 512))
canny = CannyDetector()
cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
cldm_cond_image = [cldm_cond_image]

style_image = load_image("https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg")


num_output = 1
iter_seed = 88888
guidance_scale = 7.5
num_inference_steps = 50
negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"

for i in range(num_output):
    output = blip_diffusion_pipe(
        text_prompt,
        style_image,
        cldm_cond_image,
        style_subject,
        tgt_subject,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        neg_prompt=negative_prompt,
        height=512,
        width=512,
    )

Canny edge-based ControlNet example -
Input
image
Conditioning image for Canny Edge
image
Output
image

Controlled Subject-Driven Generation (Scribble)

from diffusers.pipelines import BlipDiffusionControlNetPipeline
from diffusers import ControlNetModel
from diffusers.utils import load_image
from controlnet_aux import HEDdetector

blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained("ayushtues/blipdiffusion-controlnet")
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-scribble")
blip_diffusion_pipe.controlnet = controlnet
blip_diffusion_pipe.to("cuda")

style_subject = ["flower"]  # subject that defines the style
tgt_subject = ["bag"]  # subject to generate
text_prompt = ["on a table"]
cldm_cond_image = load_image("https://huggingface.co/lllyasviel/sd-controlnet-scribble/resolve/main/images/bag.png").resize((512, 512))
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
cldm_cond_image = hed(cldm_cond_image)
cldm_cond_image = [cldm_cond_image]

style_image = load_image("https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg")


num_output = 1
iter_seed = 88888
guidance_scale = 7.5
num_inference_steps = 50
negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"

for i in range(num_output):
    output = blip_diffusion_pipe(
        text_prompt,
        style_image,
        cldm_cond_image,
        style_subject,
        tgt_subject,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        neg_prompt=negative_prompt,
        height=512,
        width=512,
    )

Scribble example -
Input
image
Conditioning image for Scribble
image
Output
image

CC

@sayakpaul

@oumad

oumad commented Jul 31, 2023

It has been 2 months since this came out; I can't believe almost no one has mentioned it, let alone implemented it.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ayushtues
Contributor Author

ayushtues commented Aug 7, 2023

Ported the ViT visual encoder checkpoints used: https://huggingface.co/ayushtues/blipdiffusion/tree/main/vision_encoder

import torch

from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2VisionConfig, Blip2VisionModel

image_input_dummy = torch.zeros(1, 3, 224, 224)
visual_encoder = Blip2VisionModel.from_pretrained("ayushtues/blipdiffusion", subfolder="visual_encoder")
image_embed = visual_encoder(image_input_dummy, return_dict=True).last_hidden_state

Next step - Porting the Blip2QFormer

@NielsRogge

@ayushtues maybe you can directly use BLIP-2 in Transformers as a dependency, rather than reimplementing it in Diffusers (similar to how CLIP or T5 aren't reimplemented in diffusers for Stable Diffusion)?

You can also do from transformers.models.blip_2.modeling_blip2 import Blip2VisionModel for instance, in case you only need the vision encoder
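
For illustration, a minimal sketch of that route (the default Blip2VisionConfig and the dummy input below are assumptions to show the interface, not code from this PR):

import torch
from transformers.models.blip_2.modeling_blip2 import Blip2VisionConfig, Blip2VisionModel

# Build the BLIP-2 vision tower that already ships with Transformers,
# instead of re-porting the ViT inside Diffusers.
config = Blip2VisionConfig()  # randomly initialized, just to illustrate the call shape
vision_encoder = Blip2VisionModel(config)

dummy_image = torch.zeros(1, 3, config.image_size, config.image_size)
with torch.no_grad():
    image_embeds = vision_encoder(dummy_image, return_dict=True).last_hidden_state
# image_embeds has shape (1, num_patches + 1, hidden_size)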

@ayushtues
Contributor Author

Hey @NielsRogge, I originally intended to do that, but as mentioned in huggingface/transformers#25245, the Blip2 implementation in Transformers didn't support multimodal feature extraction, so I went ahead with a local implementation in diffusers.

If the feature gets added to the transformers implementation, we can switch to importing it directly.

@ayushtues
Contributor Author

Update: I was able to port the model to diffusers, although a lot of the code still needs refactoring/reuse and better integration.

Colab link: https://colab.research.google.com/drive/1PDlO8-1kPnhTUOmQBv5a2cIBTdYp_7Pi?usp=sharing
HF Model link: https://huggingface.co/ayushtues/blipdiffusion/tree/main

cond_subject = "dog"
tgt_subject = "dog"
text_prompt_input = "swimming underwater"

Input image -
OIP

Output image -
download

@ayushtues
Contributor Author

ayushtues commented Aug 13, 2023

Hey @sayakpaul can you please do a review of this PR?

@sayakpaul
Member

@ayushtues hopefully the final set of comments from my end before we can merge:

  • Resolve the open comments.
  • Change the examples to reflect the commonly followed practices:

Zero-shot:

from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.utils import load_image
import torch

blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
    "ayushtues/blipdiffusion", torch_dtype=torch.float16
).to("cuda")

cond_subject = "dog"
tgt_subject = "dog"
text_prompt_input = "swimming underwater"

cond_image = load_image(
    "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
)

iter_seed = 88888
guidance_scale = 7.5
num_inference_steps = 25
negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"

output = blip_diffusion_pipe(
    text_prompt_input,
    cond_image,
    cond_subject,
    tgt_subject,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    neg_prompt=negative_prompt,
    height=512,
    width=512,
).images
output[0].save("image.png")

Control-guided (Canny):

from diffusers.pipelines import BlipDiffusionControlNetPipeline
from diffusers.utils import load_image
from controlnet_aux import CannyDetector
import torch

blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "ayushtues/blipdiffusion-controlnet", torch_dtype=torch.float16
).to("cuda")

style_subject = "flower"  # subject that defines the style
tgt_subject = "teapot"  # subject to generate.
text_prompt = "on a marble table"

cldm_cond_image = load_image(
    "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
).resize((512, 512))
canny = CannyDetector()
cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
style_image = load_image(
    "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
)

guidance_scale = 7.5
num_inference_steps = 50
negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"

output = blip_diffusion_pipe(
    text_prompt,
    style_image,
    cldm_cond_image,
    style_subject,
    tgt_subject,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    neg_prompt=negative_prompt,
    height=512,
    width=512,
).images
output[0].save("image.png")

Control-guided (scribble):

from diffusers.pipelines import BlipDiffusionControlNetPipeline
from diffusers import ControlNetModel
from diffusers.utils import load_image
from controlnet_aux import HEDdetector

blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "ayushtues/blipdiffusion-controlnet"
)
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-scribble")
blip_diffusion_pipe.controlnet = controlnet
blip_diffusion_pipe.to("cuda")

style_subject = "flower"  # subject that defines the style
tgt_subject = "bag"  # subject to generate.
text_prompt = "on a table"
cldm_cond_image = load_image(
    "https://huggingface.co/lllyasviel/sd-controlnet-scribble/resolve/main/images/bag.png"
).resize((512, 512))
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
cldm_cond_image = hed(cldm_cond_image)
style_image = load_image(
    "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
)

guidance_scale = 7.5
num_inference_steps = 50
negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"

output = blip_diffusion_pipe(
    text_prompt,
    style_image,
    cldm_cond_image,
    style_subject,
    tgt_subject,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    neg_prompt=negative_prompt,
    height=512,
    width=512,
).images
output[0].save("image.png")

Would be great if you could update the documentation to reflect this and also the model cards.

Then we can edit the checkpoint paths to reflect Salesforce, transfer the checkpoints, and finally merge the PR.

Let me know if anything's unclear.

@sayakpaul
Member

Another point: have you tried using the pipeline on various subjects to see if it's able to faithfully render them in the outputs?

For example, I tried the zero-shot rendition pipeline on this image with the following parameters:

cond_subject = "backpack"
tgt_subject = "backpack"
text_prompt_input = "in a busy street"

But it didn't faithfully render the subject:

image

Is this expected?

@ayushtues
Contributor Author

ayushtues commented Sep 20, 2023

image
image

These are some examples I'm getting; I think they are okay? (Multiple samples, since I am not fixing the generator.)
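
For reproducible side-by-side comparisons, the runs could be seeded explicitly. A minimal sketch, assuming the pipeline accepts the usual diffusers generator argument:

import torch

# Fix the random generator so repeated runs produce the same sample.
generator = torch.Generator(device="cuda").manual_seed(88888)
output = blip_diffusion_pipe(
    text_prompt_input,
    cond_image,
    cond_subject,
    tgt_subject,
    guidance_scale=7.5,
    num_inference_steps=50,
    neg_prompt=negative_prompt,
    generator=generator,
    height=512,
    width=512,
)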

@sayakpaul
Member

But aren't they deviating from the subject a bit, or is that expected?

@dxli94

dxli94 commented Sep 20, 2023

Thanks @ayushtues @sayakpaul for the addition. The results look as expected. In zero-shot inference, the subject appearance does deviate a bit, as suggested by the reported metrics (CLIP-I, DINO). More faithful results can be obtained with few-step fine-tuning.

@ayushtues
Contributor Author

^ @sayakpaul, if the original author says so :P

Contributor

@patrickvonplaten patrickvonplaten left a comment


Good to merge from my side

@sayakpaul
Member

Alright. Really, great work here Ayush!

We have asked internally to move the checkpoints to hf.co/Salesforce. Once this is done and the checkpoint paths have been updated in the model cards (should be done before we transfer actually) and the docs, I will merge the PR :)

@ayushtues
Contributor Author

Alright. Really, great work here Ayush!

We have asked internally to move the checkpoints to hf.co/Salesforce. Once this is done and the checkpoint paths have been updated in the model cards (should be done before we transfer actually) and the docs, I will merge the PR :)

Let me know when the transfer is done, I'll do the other changes

@sayakpaul
Member

@ayushtues we have got approval from @dxli94 to do the transfer. Please update the checkpoint paths and once done, let me know here. Will transfer and merge.

@ayushtues
Contributor Author

Hey @sayakpaul, updated to Salesforce/blipdiffusion & Salesforce/blipdiffusion-controlnet. Let me know if anything else is needed.

Member

@sayakpaul sayakpaul left a comment


Amazing work! Thanks so much for your patience and for iterating!

As soon as the transfer is done, will merge and ship this 🚀

@sayakpaul
Member

Transfer complete! Merging!

@sayakpaul sayakpaul merged commit 157c901 into huggingface:main Sep 21, 2023
@yanchaoguo

Transfer complete! Merging!

Excellent! I have used it.
image
image

@yanchaoguo

How do I load a LoRA file using BLIP Diffusion? @sayakpaul

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Add BLIP Diffusion skeleton

* Add other model components

* Add BLIP2, need to change it for now

* Fix pipeline imports

* Load pretrained ViT

* Make qformer fwd pass same

* Replicate fwd passes

* Fix device bug

* Add accelerate functions

* Remove extra functions from Blip2

* Minor bug

* Integrate initial review changes

* Refactoring

* Refactoring

* Refactor

* Add controlnet

* Refactor

* Update conversion script

* Add image processor

* Shift postprocessing to ImageProcessor

* Refactor

* Fix device

* Add fast tests

* Update conversion script

* Fix checkpoint conversion script

* Integrate review changes

* Integrate reivew changes

* Remove unused functions from test

* Reuse HF image processor in Cond image

* Create new BlipImageProcessor based on transfomers

* Fix image preprocessor

* Minor

* Minor

* Add canny preprocessing

* Fix controlnet preprocessing

* Fix blip diffusion test

* Add controlnet test

* Add initial doc strings

* Integrate review changes

* Refactor

* Update examples

* Remove DDIM comments

* Add copied from for prepare_latents

* Add type anotations

* Add docstrings

* Do black formatting

* Add batch support

* Make tests pass

* Make controlnet tests pass

* Black formatting

* Fix progress bar

* Fix some licensing comments

* Fix imports

* Refactor controlnet

* Make tests faster

* Edit examples

* Black formatting/Ruff

* Add doc

* Minor

Co-authored-by: Patrick von Platen <[email protected]>

* Move controlnet pipeline

* Make tests faster

* Fix imports

* Fix formatting

* Fix make errors

* Fix make errors

* Minor

* Add suggested doc changes

Co-authored-by: Sayak Paul <[email protected]>

* Edit docs

* Fix 16 bit loading

* Update examples

* Edit toctree

* Update docs/source/en/api/pipelines/blip_diffusion.md

Co-authored-by: Sayak Paul <[email protected]>

* Minor

* Add tips

* Edit examples

* Update model paths

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024