Error in `num_images_per_prompt` handling in the Flux pipeline #9215

@rardz

Description

In `_get_clip_prompt_embeds` in the Flux pipeline, the code is:

```python
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
```
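`prompt_embeds` should be repeated in an interleaved fashion along the batch dimension, with each prompt's embedding repeated `num_images_per_prompt` times in a row. However, `pooler_output` is a 2D tensor of shape `(batch_size, hidden_dim)`, so `repeat(1, num_images_per_prompt, 1)` treats it as `(1, batch_size, hidden_dim)` and tiles the whole batch instead. After the `view`, the embeddings end up repeated but not interleaved, which is incorrect.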
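The snippet below is a small standalone sketch in plain PyTorch (not the actual pipeline) that reproduces the behavior with a toy tensor standing in for `pooler_output`. The sizes are arbitrary assumptions, and each row is filled with its prompt index so the ordering is easy to read. It contrasts the current `repeat(1, num_images_per_prompt, 1)` call with two possible interleaving alternatives, `repeat(1, num_images_per_prompt)` followed by `view`, and `repeat_interleave(..., dim=0)`. These are suggestions, not a confirmed upstream fix. It also shows that the same repeat/view pattern does interleave when applied to a 3D tensor.

```python
import torch

batch_size, num_images_per_prompt, hidden_dim = 2, 3, 4

# Toy stand-in for the pooled CLIP output, shape (batch_size, hidden_dim).
# Row i is filled with the value i so the prompt index is readable.
pooled = torch.arange(batch_size, dtype=torch.float32)[:, None].expand(batch_size, hidden_dim).clone()

# Current code: three repeat sizes on a 2D tensor treat it as (1, batch_size, hidden_dim),
# so the whole batch is tiled: prompt order [0, 1, 0, 1, 0, 1].
tiled = pooled.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, -1)

# Possible fix: repeat along the feature dim, then reshape: prompt order [0, 0, 0, 1, 1, 1].
fixed = pooled.repeat(1, num_images_per_prompt).view(batch_size * num_images_per_prompt, -1)

# Equivalent alternative: repeat each row in place along the batch dimension.
interleaved = pooled.repeat_interleave(num_images_per_prompt, dim=0)

# For comparison, the same repeat/view pattern on a 3D (batch, seq_len, hidden) tensor interleaves.
seq_len = 5
tokens = pooled[:, None, :].expand(batch_size, seq_len, hidden_dim).clone()
tokens_rep = tokens.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)

print(tiled[:, 0].tolist())            # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print(fixed[:, 0].tolist())            # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
print(torch.equal(fixed, interleaved)) # True
print(tokens_rep[:, 0, 0].tolist())    # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

Whenever both `batch_size` and `num_images_per_prompt` exceed 1, row `i` of the tiled tensor no longer belongs to prompt `i // num_images_per_prompt`, which is the order both interleaved variants produce.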

Metadata

Assignees: No one assigned
Labels: stale (Issues that haven't received updates)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests