mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00

Commit: update (adds prompt attention-mask plumbing to ChromaImg2ImgPipeline and aligns IP-adapter embed preparation with the Flux pipeline)
@@ -292,8 +292,10 @@ class ChromaImg2ImgPipeline(
         negative_prompt: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         do_classifier_free_guidance: bool = True,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
@@ -310,7 +312,7 @@ class ChromaImg2ImgPipeline(
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             lora_scale (`float`, *optional*):
@@ -335,7 +337,7 @@ class ChromaImg2ImgPipeline(
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
@@ -365,12 +367,13 @@ class ChromaImg2ImgPipeline(
                    " the batch size of `prompt`."
                )

-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )
+
            negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        if self.text_encoder is not None:
@@ -378,7 +381,14 @@ class ChromaImg2ImgPipeline(
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

-        return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+        return (
+            prompt_embeds,
+            text_ids,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_text_ids,
+            negative_prompt_attention_mask,
+        )

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt):
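
With this hunk, `encode_prompt` returns a six-tuple instead of four values: the two T5 attention masks are now surfaced to the caller. A minimal usage sketch; `pipe` and the prompts are illustrative assumptions, not part of the commit:

    # Unpack the new six-tuple; the two attention masks are new in this commit.
    (
        prompt_embeds,
        text_ids,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_text_ids,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(
        prompt="a photo of a forest",
        negative_prompt="blurry, low quality",
        do_classifier_free_guidance=True,
        device=pipe._execution_device,
        num_images_per_prompt=1,
    )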
@@ -392,52 +402,44 @@ class ChromaImg2ImgPipeline(
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return image_embeds

-    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
-        self,
-        ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-    ) -> torch.Tensor:
-        """Prepares image embeddings for use in the IP-Adapter.
-
-        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
-
-        Args:
-            ip_adapter_image (`PipelineImageInput`, *optional*):
-                The input image to extract features from for IP-Adapter.
-            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
-                Precomputed image embeddings.
-            device: (`torch.device`, *optional*):
-                Torch device.
-            num_images_per_prompt (`int`, defaults to 1):
-                Number of images that should be generated per prompt.
-            do_classifier_free_guidance (`bool`, defaults to True):
-                Whether to use classifier free guidance or not.
-        """
-        device = device or self._execution_device
-
-        if ip_adapter_image_embeds is not None:
-            if do_classifier_free_guidance:
-                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
-            else:
-                single_image_embeds = ip_adapter_image_embeds
-        elif ip_adapter_image is not None:
-            single_image_embeds = self.encode_image(ip_adapter_image, device)
-            if do_classifier_free_guidance:
-                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
+
+            for single_ip_adapter_image in ip_adapter_image:
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+                image_embeds.append(single_image_embeds[None, :])
        else:
-            raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+            if not isinstance(ip_adapter_image_embeds, list):
+                ip_adapter_image_embeds = [ip_adapter_image_embeds]

-        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )

-        if do_classifier_free_guidance:
-            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)

-        return image_embeds.to(device=device)
+        ip_adapter_image_embeds = []
+        for single_image_embeds in image_embeds:
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds

    def check_inputs(
        self,
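
The replacement method follows the Flux contract: it returns a list with one embedding tensor per loaded IP adapter, and classifier-free-guidance handling moves out of this helper into `__call__`. A hedged sketch of the new calling convention; `pipe` and `image_a` are assumed inputs, not part of the commit:

    # One input image per loaded IP adapter; the helper validates the count
    # against transformer.encoder_hid_proj.num_ip_adapters.
    embeds_per_adapter = pipe.prepare_ip_adapter_image_embeds(
        ip_adapter_image=[image_a],
        ip_adapter_image_embeds=None,
        device=pipe._execution_device,
        num_images_per_prompt=2,
    )
    # A list of tensors, one per adapter, each repeated
    # num_images_per_prompt times along dim 0.
    assert isinstance(embeds_per_adapter, list)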
@@ -448,6 +450,8 @@ class ChromaImg2ImgPipeline(
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
@@ -483,6 +487,15 @@ class ChromaImg2ImgPipeline(
                    f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                    f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
                )
+        if prompt_attention_mask is not None and negative_prompt_attention_mask is None:
+            raise ValueError(
+                "Cannot provide `prompt_attention_mask` without also providing `negative_prompt_attention_mask`"
+            )
+
+        if negative_prompt_attention_mask is not None and prompt_attention_mask is None:
+            raise ValueError(
+                "Cannot provide `negative_prompt_attention_mask` without also providing `prompt_attention_mask`"
+            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -591,7 +604,7 @@ class ChromaImg2ImgPipeline(
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype)

        if latents is not None:
            return latents.to(device=device, dtype=dtype), latent_image_ids
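
The `batch_size` argument is dropped because the positional ids are shared across the batch and broadcast by the transformer. A rough sketch of what the helper computes, assuming the Flux-style implementation this pipeline mirrors:

    import torch

    height, width = 32, 32  # packed latent grid; illustrative values
    ids = torch.zeros(height, width, 3)
    ids[..., 1] += torch.arange(height)[:, None]  # row coordinate
    ids[..., 2] += torch.arange(width)[None, :]   # column coordinate
    ids = ids.reshape(height * width, 3)          # note: no batch dimension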
@@ -617,6 +630,25 @@ class ChromaImg2ImgPipeline(
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        return latents, latent_image_ids

+    def _prepare_attention_mask(
+        self,
+        batch_size,
+        sequence_length,
+        dtype,
+        attention_mask=None,
+    ):
+        if attention_mask is None:
+            return attention_mask
+
+        # Extend the prompt attention mask to account for image tokens in the final sequence
+        attention_mask = torch.cat(
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            dim=1,
+        )
+        attention_mask = attention_mask.to(dtype)
+
+        return attention_mask
+
    @property
    def guidance_scale(self):
        return self._guidance_scale
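
The new helper extends the text attention mask with ones over the packed image tokens, so padding text tokens stay masked while every latent position can attend. A standalone illustration of the same operation with toy shapes, independent of the pipeline:

    import torch

    text_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])  # 3 real tokens, 2 pad
    image_seq_len = 4                                       # packed latent tokens
    full_mask = torch.cat(
        [text_mask, torch.ones(text_mask.shape[0], image_seq_len)], dim=1
    )
    print(full_mask)  # tensor([[1., 1., 1., 0., 0., 1., 1., 1., 1.]])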
@@ -656,13 +688,15 @@ class ChromaImg2ImgPipeline(
        strength: float = 0.8,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
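
With the signature above, `__call__` also accepts the two attention masks, which matters when passing precomputed embeddings. A hypothetical call sketch; `pipe`, `init_image`, and the precomputed tensors are assumed to exist and are not part of the commit:

    image = pipe(
        image=init_image,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        strength=0.8,
    ).images[0]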
@@ -703,11 +737,11 @@ class ChromaImg2ImgPipeline(
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -721,7 +755,7 @@ class ChromaImg2ImgPipeline(
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
@@ -765,6 +799,8 @@ class ChromaImg2ImgPipeline(
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )
@@ -794,13 +830,17 @@ class ChromaImg2ImgPipeline(
        (
            prompt_embeds,
            text_ids,
+            prompt_attention_mask,
            negative_prompt_embeds,
            negative_text_ids,
+            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
@@ -856,20 +896,55 @@ class ChromaImg2ImgPipeline(
            latents,
        )

+        attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            dtype=latents.dtype,
+            attention_mask=prompt_attention_mask,
+        )
+        if self.do_classifier_free_guidance and negative_prompt_attention_mask is not None:
+            negative_attention_mask = self._prepare_attention_mask(
+                batch_size=latents.shape[0],
+                sequence_length=image_seq_len,
+                dtype=latents.dtype,
+                attention_mask=negative_prompt_attention_mask,
+            )
+            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
+
        # 6. Prepare image embeddings
-        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
-            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
-                self.do_classifier_free_guidance,
            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if self.do_classifier_free_guidance and image_embeds is not None and negative_image_embeds is not None:
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

-        if self.joint_attention_kwargs is None:
-            self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
-        else:
-            self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+        if image_embeds is not None:
+            self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
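
When only one side of the IP-adapter conditioning is supplied, the other side is filled with black placeholder images so both classifier-free-guidance branches receive image conditioning; the combined embeds are then stored once in `joint_attention_kwargs` instead of being reassigned every denoising step. A toy illustration of the placeholder construction, with concrete sizes assumed for the sketch:

    import numpy as np

    width = height = 1024  # illustrative resolution
    num_ip_adapters = 1    # from transformer.encoder_hid_proj in the pipeline
    negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)  # black image
    negative_ip_adapter_image = [negative_ip_adapter_image] * num_ip_adapters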
@@ -878,9 +953,6 @@ class ChromaImg2ImgPipeline(
                    continue

                self._current_timestep = t
-                if ip_adapter_image_embeds is not None:
-                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = ip_adapter_image_embeds
-
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -892,6 +964,7 @@ class ChromaImg2ImgPipeline(
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
+                    attention_mask=attention_mask,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
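
Taken together, an end-to-end sketch of the updated pipeline; the checkpoint id and image URL are placeholders, not taken from the commit:

    import torch
    from diffusers import ChromaImg2ImgPipeline
    from diffusers.utils import load_image

    pipe = ChromaImg2ImgPipeline.from_pretrained(
        "lodestones/Chroma1-HD", torch_dtype=torch.bfloat16  # assumed checkpoint id
    ).to("cuda")

    init_image = load_image("https://example.com/input.png")  # placeholder URL
    image = pipe(
        prompt="a watercolor landscape at dusk",
        image=init_image,
        strength=0.8,
        guidance_scale=3.0,
    ).images[0]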