Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Nice work!! Just a heads up: this PR might have conflicts with #11040 if that one is merged first.
Wonderful work. Since SANA-Sprint and SANA-1.5 share the same architecture, this PR would make SANA-1.5 work as well.
* update conversion script for SANA-1.5 and add a conversion script for SANA-Sprint
* separate SANA and SANA-Sprint conversion scripts
* update for upstream
* fix the `}` bug
* add a doc for SanaSprintPipeline
* minor updates
* make style && make quality
@bot /style |
Style fixes have been applied. View the workflow run here. |
cc @lawrence-cj, can you do a review?
a-r-r-o-w left a comment:
Really amazing work! Can't wait for the release ❤️
```python
>>> from diffusers import SanaPipeline

>>> pipe = SanaPipeline.from_pretrained(
```
The example should be updated to use SanaSprintPipeline.
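For reference, a minimal sketch of what the updated example could look like; the checkpoint id below is a placeholder, not necessarily the released repo name:

```python
>>> import torch
>>> from diffusers import SanaSprintPipeline

>>> # placeholder checkpoint id; substitute the released SANA-Sprint weights
>>> pipe = SanaSprintPipeline.from_pretrained(
...     "Efficient-Large-Model/SANA-Sprint-placeholder", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon").images[0]
```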
```python
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
```
In other recent models we found that an attention mask with shape [B, 1, 1, N] is faster: the tensor is smaller overall, and PyTorch's broadcasting handles the head and query dimensions. Something to look into; if we see a benefit, all occurrences of this code can be updated.
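To make the idea concrete, a minimal sketch (shapes picked arbitrarily for illustration) showing that scaled_dot_product_attention accepts a broadcastable [B, 1, 1, N] mask:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 128, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# Keep the additive mask at (batch, 1, 1, seq_len) instead of expanding it to
# (batch, heads, seq_len, seq_len); SDPA broadcasts over heads and queries.
attention_mask = torch.zeros(batch, 1, 1, seq_len)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
```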
```python
latents = latents.to(self.vae.dtype)
try:
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
```
Nice!
For XPU we need to use torch.OutOfMemoryError, and it looks like that will also work on CUDA:
https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/test/test_xpu.py#L446
https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/test/test_cuda.py#L3950
https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/test/test_cuda.py#L4154
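A minimal sketch of the device-agnostic variant, assuming a PyTorch version that exposes torch.OutOfMemoryError (on CUDA it is the same class as torch.cuda.OutOfMemoryError):

```python
latents = latents.to(self.vae.dtype)
try:
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.OutOfMemoryError:
    # the PR's existing warning/fallback logic goes here unchanged
    ...
```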
add a note about max_timesteps
Co-authored-by: Aryan <[email protected]>
```python
try:
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
    warnings.warn(
```
Should this be logger.warning() instead of warnings.warn()?
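If so, a minimal sketch of the standard diffusers logging pattern; the helper name and warning message are placeholders for illustration:

```python
import torch
from diffusers.utils import logging

logger = logging.get_logger(__name__)

def decode_latents_safely(vae, latents):
    # hypothetical helper mirroring the quoted try/except, with
    # torch.OutOfMemoryError per the earlier suggestion
    try:
        return vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    except torch.OutOfMemoryError:
        logger.warning("VAE decoding ran out of memory (placeholder message).")
        raise
```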
```python
else:
    # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
    self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
    print(f"Set timesteps: {self.timesteps}")
```
Suggested change: remove the debug print line `print(f"Set timesteps: {self.timesteps}")`.
Ran vibe tests with different timestep settings.
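For anyone repeating those vibe tests, a minimal sketch of how the schedule responds to different max_timesteps values (the non-default loop values are arbitrary picks, not recommendations):

```python
import math
import torch

# arctan(80 / 0.5) ≈ 1.56454 is the sCM-paper default (sigma_max=80, sigma_data=0.5)
num_inference_steps = 2
for max_timesteps in (math.atan(80 / 0.5), 1.3, 1.0):
    timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1).float()
    print(f"max_timesteps={max_timesteps:.5f} -> {timesteps.tolist()}")
```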