[ControlNet] Adds controlnet for SanaTransformer #11040
a-r-r-o-w merged 24 commits into huggingface:main
Conversation
@ishan-modi Sorry about the slow review here. The team is at an offsite this week and taking a break, but we'll try to merge ASAP once we're back next week. Thanks for the awesome work!

All good, man! Enjoy your offsite.

Gentle ping @a-r-r-o-w

Hi, sorry for the delay. Testing now and hopefully can merge soon 🤗
a-r-r-o-w
left a comment
Thanks for the awesome work! Looks very close to merge except for a few more changes. LMK if I can help with any 🤗
    timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

    # controlnet(s) inference
    controlnet_block_samples = self.controlnet(
The inference example in the docstring errors out for me here. This is because the ControlNet is loaded in bf16, but latent_model_input is moved to self.transformer.dtype, which is fp16, due to the following code in encode_prompt:

    if self.transformer is not None:
        dtype = self.transformer.dtype
    elif self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    else:
        dtype = None
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

Let's do this:
- Remove the above code occurrence for determining dtype in encode_prompt
- Return the prompt_embeds in the same dtype as the text encoder
- Perform any dtype casting within the __call__ method based on controlnet_dtype and transformer_dtype (create these variables similar to how it's done in Wan)
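The failure and the proposed fix can be sketched in plain torch. The dtypes come from this thread; the tensor shapes and the weight matrix are arbitrary stand-ins, not the real Sana internals:

```python
import torch

# dtypes reported in the thread: transformer in fp16, controlnet in bf16
transformer_dtype = torch.float16
controlnet_dtype = torch.bfloat16

# Arbitrary stand-ins; not the real Sana tensor shapes.
latents = torch.randn(1, 8, 8)
controlnet_weight = torch.randn(8, 8, dtype=controlnet_dtype)

# Reproduce the bug: fp16 activations hitting a bf16 weight raise a
# dtype-mismatch RuntimeError.
try:
    _ = latents.to(transformer_dtype) @ controlnet_weight
    mixed_dtypes_ok = True
except RuntimeError:
    mixed_dtypes_ok = False

# Proposed fix: cast inputs to each component's dtype right before the call.
controlnet_out = latents.to(controlnet_dtype) @ controlnet_weight
```

In the pipeline itself, `controlnet_dtype` and `transformer_dtype` would be read off the loaded modules inside `__call__`, as in the Wan pipeline.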
Also, the following are the dtypes of each component:
- vae: bf16
- text_encoder: bf16
- transformer: fp16
- controlnet: bf16
Just so that I'm up to speed, is this expected?
Can you point me to the docstring that loads the controlnet in bf16?
I think I overlooked the following:
- controlnet for SANA_600M is supposed to use fp16 here
- controlnet for SANA_1600M is supposed to use bf16 here
Most of the docs load the controlnet in fp16, but I guess it needs to be more generic, as you mentioned.
I'm referring to the example code that's at the top of the file:
    import torch
    from diffusers import SanaControlNetModel, SanaControlNetPipeline
    from diffusers.utils import load_image

    controlnet = SanaControlNetModel.from_pretrained(
        "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
    )
    pipe = SanaControlNetPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_600M_1024px_diffusers",
        variant="fp16",
        torch_dtype=torch.float16,
        controlnet=controlnet,
    )
    pipe.to("cuda")
    pipe.vae.to(torch.bfloat16)
    pipe.text_encoder.to(torch.bfloat16)

    cond_image = load_image(
        "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
    )
    prompt = 'a cat with a neon sign that says "Sana"'
    image = pipe(
        prompt,
        control_image=cond_image,
    ).images[0]
    image.save("output.png")
I think:
- we should update the docstring example to use fp16 for both the controlnet and the transformer, unless there is a special reason to do it this way, i.e. controlnet in bf16 while the transformer is in fp16
- the changes @a-r-r-o-w proposed in [ControlNet] Adds controlnet for SanaTransformer #11040 (comment) sound good. I think all the Sana pipelines should have the same encode_prompt method, no? If so, let's not remove the # Copied from in encode_prompt; update the one in pipeline_sana.py and make sure the changes are applied to all Sana pipelines.
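The first point above can be illustrated with stub modules: keeping both components in one dtype means no cross-dtype activations are possible. The nn.Linear layers here are hypothetical stand-ins, not the real diffusers components:

```python
import torch
import torch.nn as nn

# One shared dtype for both components (bf16 here; fp16 behaves the same on GPU).
shared_dtype = torch.bfloat16

# Hypothetical stand-ins for the transformer and controlnet.
transformer = nn.Linear(8, 8).to(shared_dtype)
controlnet = nn.Linear(8, 8).to(shared_dtype)

x = torch.randn(1, 8, dtype=shared_dtype)
# With matching dtypes, the output of one module feeds the other directly.
out = transformer(controlnet(x))
```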
    if self.transformer is not None:
        dtype = self.transformer.dtype
    elif self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    else:
        dtype = None
text_encoder cannot be None. transformer can be None since we should be able to run encode_prompt without loading the transformer.
Let's make sure this method returns embeds in the same dtype as text encoder and do casting in __call__
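A minimal sketch of that contract, using a hypothetical stub in place of the real text encoder: embeddings come back in the text encoder's dtype, and the transformer is never consulted.

```python
import torch

# Hypothetical stand-in for the real text encoder module.
class StubTextEncoder:
    dtype = torch.bfloat16

def embeds_dtype(text_encoder):
    # The transformer is deliberately not consulted; it may not even be
    # loaded when encode_prompt runs.
    return text_encoder.dtype if text_encoder is not None else None

prompt_embeds = torch.randn(1, 16, 32).to(embeds_dtype(StubTextEncoder()))
```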
Ohh, I think we need to make sure the prompt_embeds is not None code path can work without the text encoder loaded too, no?
In modular, we started to make it so that you only run encode_prompt when you actually need to encode a prompt; that's not the case here yet.
@a-r-r-o-w let me know if I should change the current version to

    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    else:
        dtype = None

Once confirmed, I will make similar changes to pipeline_sana.py.
Oh okay, based on YiYi's comment, let's do this for both pipelines
    >>> from diffusers.utils import load_image

    >>> controlnet = SanaControlNetModel.from_pretrained(
    ...     "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.bfloat16
@lawrence-cj Could we host the controlnet checkpoint in the Efficient-Large-Model org? We generally don't merge without officially hosted weights unless it's necessary for a quick release (which we then update later anyway).
@ishan-modi Please feel free to mention your hosted controlnet model in the docs 🤗
Yah, I would like to do it and test the PR at the same time. @a-r-r-o-w
Awesome, thank you @lawrence-cj! I'll run some final tests and merge the PR in a few hours
@lawrence-cj The checkpoint does not seem accessible yet and I get a 404. I'll go ahead and merge this PR for now, and we can update the docs/examples with the official checkpoint in a follow up
a-r-r-o-w
left a comment
Thanks, LGTM! Going to try testing out the model again and we can merge once we have the official hosted checkpoint by Junsong 🤗
LMK if you need any help with the failing tests. They seem to be because of a misplaced …
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
a-r-r-o-w
left a comment
Awesome work @ishan-modi, thanks a lot!
Just some final changes
Congrats! Thank you so much @ishan-modi
What does this PR do?
Fixes #10772, #11019, #11116
Who can review?
@yiyixuxu