Mirror of https://github.com/huggingface/diffusers.git
[Flux Redux] add prompt & multiple image input (#10056)
* add multiple prompts to flux redux

Co-authored-by: hlky <hlky@hlky.ac>
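For context, a minimal usage sketch of the feature this commit adds (the model ids, file paths, and encoder wiring below are illustrative assumptions, not part of the diff): `prompt` is only honored when text encoders are explicitly loaded into the Redux prior pipeline, and list inputs pair one prompt and one weight with each image.

    import torch
    from diffusers import FluxPipeline, FluxPriorReduxPipeline
    from diffusers.utils import load_image

    # Assumed repo ids; the Redux prior checkpoint ships without text encoders,
    # so they are borrowed from the base pipeline here to enable `prompt`.
    base = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe_prior = FluxPriorReduxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Redux-dev",
        text_encoder=base.text_encoder,
        tokenizer=base.tokenizer,
        text_encoder_2=base.text_encoder_2,
        tokenizer_2=base.tokenizer_2,
        torch_dtype=torch.bfloat16,
    )

    images = [load_image("style.png"), load_image("subject.png")]  # placeholder paths
    prior_out = pipe_prior(
        image=images,
        prompt=["an oil painting", "a red fox"],  # one prompt per image
        prompt_embeds_scale=[1.0, 0.8],           # one weight per image
    )
    # prior_out.prompt_embeds / prior_out.pooled_prompt_embeds then condition the base pipeline.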
@@ -142,6 +142,45 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
+
+    def check_inputs(
+        self,
+        image,
+        prompt,
+        prompt_2,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        prompt_embeds_scale=1.0,
+        pooled_prompt_embeds_scale=1.0,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
+            raise ValueError(
+                f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
+            )
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+        if isinstance(prompt_embeds_scale, list) and (
+            isinstance(image, list) and len(prompt_embeds_scale) != len(image)
+        ):
+            raise ValueError(
+                f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
+            )

     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
         image = self.feature_extractor.preprocess(
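The mutual-exclusion checks above largely mirror the validation in other Flux pipelines; the genuinely new rules are the pairing checks between list-valued inputs. A standalone toy sketch of the pairing rule (stand-in strings instead of PIL images, no models needed):

    prompt = ["a cat", "a dog", "a bird"]  # three prompts...
    image = ["img0.png", "img1.png"]       # ...but only two images

    # Same condition the new check_inputs raises on:
    if isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image):
        raise ValueError(
            f"number of prompts must be equal to number of images, "
            f"but {len(prompt)} prompts were provided and {len(image)} images"
        )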
@@ -334,6 +373,12 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
     def __call__(
         self,
         image: PipelineImageInput,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+        pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
         return_dict: bool = True,
     ):
         r"""
@@ -345,6 +390,16 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
                 numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
                 list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array
                 or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
+                make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders
+                are not loaded.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.

@@ -356,6 +411,17 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
             returning a tuple, the first element is a list with the generated images.
         """

+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            image,
+            prompt,
+            prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt_embeds_scale=prompt_embeds_scale,
+            pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
+        )
         # 2. Define call parameters
         if image is not None and isinstance(image, Image.Image):
             batch_size = 1
@@ -363,6 +429,13 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
             batch_size = len(image)
         else:
             batch_size = image.shape[0]
+        if prompt is not None and isinstance(prompt, str):
+            prompt = batch_size * [prompt]
+        if isinstance(prompt_embeds_scale, float):
+            prompt_embeds_scale = batch_size * [prompt_embeds_scale]
+        if isinstance(pooled_prompt_embeds_scale, float):
+            pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
+
         device = self._execution_device

         # 3. Prepare image embeddings
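A quick sketch of the broadcasting step above (toy values): scalar scales are expanded to one weight per image, so the later per-image scaling can treat scalar and list inputs uniformly.

    batch_size = 3
    prompt_embeds_scale = 0.8  # scalar passed by the caller
    if isinstance(prompt_embeds_scale, float):
        prompt_embeds_scale = batch_size * [prompt_embeds_scale]
    print(prompt_embeds_scale)  # [0.8, 0.8, 0.8]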
@@ -378,24 +451,38 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
                 pooled_prompt_embeds,
                 _,
             ) = self.encode_prompt(
-                prompt=[""] * batch_size,
-                prompt_2=None,
-                prompt_embeds=None,
-                pooled_prompt_embeds=None,
+                prompt=prompt,
+                prompt_2=prompt_2,
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
                 device=device,
                 num_images_per_prompt=1,
                 max_sequence_length=512,
                 lora_scale=None,
             )
         else:
+            if prompt is not None:
+                logger.warning(
+                    "prompt input is ignored when text encoders are not loaded to the pipeline. "
+                    "Make sure to explicitly load the text encoders to enable prompt input. "
+                )
             # max_sequence_length is 512, t5 encoder hidden size is 4096
             prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
             # pooled_prompt_embeds is 768, clip text encoder hidden size
             pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)

-        # Concatenate image and text embeddings
+        # scale & concatenate image and text embeddings
         prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
+
+        prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
+        pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
+            :, None
+        ]
+
+        # weighted sum
+        prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
+        pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)

         # Offload all models
         self.maybe_free_model_hooks()
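A minimal sketch (toy shapes, assumed semantics) of the scale-and-sum above: each image's concatenated embedding is weighted by its scale, then the batch dimension is summed away, so multiple reference images blend into a single conditioning for the base Flux pipeline.

    import torch

    batch_size, seq_len, dim = 2, 4, 8  # toy sizes; the real tensors are (B, 512 + image tokens, 4096)
    prompt_embeds = torch.randn(batch_size, seq_len, dim)
    prompt_embeds_scale = [0.7, 0.3]    # one weight per input image

    # Same broadcasting pattern as the pipeline: (B,) -> (B, 1, 1), scale, then sum over B.
    scale = torch.tensor(prompt_embeds_scale)[:, None, None]
    blended = (prompt_embeds * scale).sum(dim=0, keepdim=True)
    print(blended.shape)  # torch.Size([1, 4, 8]) -- one blended conditioning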