mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-06-17 11:35:14 +05:30
parent 03165b9269
commit 43d041adf4


@@ -292,8 +292,10 @@ class ChromaImg2ImgPipeline(
         negative_prompt: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         do_classifier_free_guidance: bool = True,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
@@ -310,7 +312,7 @@ class ChromaImg2ImgPipeline(
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             lora_scale (`float`, *optional*):
@@ -335,7 +337,7 @@ class ChromaImg2ImgPipeline(
             batch_size = prompt_embeds.shape[0]
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
@@ -365,12 +367,13 @@ class ChromaImg2ImgPipeline(
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
if self.text_encoder is not None:
@@ -378,7 +381,14 @@ class ChromaImg2ImgPipeline(
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
-        return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+        return (
+            prompt_embeds,
+            text_ids,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_text_ids,
+            negative_prompt_attention_mask,
+        )
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
     def encode_image(self, image, device, num_images_per_prompt):
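Note: `encode_prompt` now returns the two attention masks alongside the embeddings and ids. A minimal caller-side sketch of the new six-tuple (assuming an already-instantiated pipeline `pipe`; prompts are illustrative):

# Hypothetical caller; `pipe` is an instantiated ChromaImg2ImgPipeline.
(
    prompt_embeds,
    text_ids,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_text_ids,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    prompt="a photo of a cat",
    negative_prompt="blurry, low quality",
    do_classifier_free_guidance=True,
    device=pipe._execution_device,
    num_images_per_prompt=1,
)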
@@ -392,52 +402,44 @@ class ChromaImg2ImgPipeline(
         image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
         return image_embeds
-    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self,
-        ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-    ) -> torch.Tensor:
-        """Prepares image embeddings for use in the IP-Adapter.
-        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
-        Args:
-            ip_adapter_image (`PipelineImageInput`, *optional*):
-                The input image to extract features from for IP-Adapter.
-            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
-                Precomputed image embeddings.
-            device: (`torch.device`, *optional*):
-                Torch device.
-            num_images_per_prompt (`int`, defaults to 1):
-                Number of images that should be generated per prompt.
-            do_classifier_free_guidance (`bool`, defaults to True):
-                Whether to use classifier free guidance or not.
-        """
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
-        device = device or self._execution_device
-        if ip_adapter_image_embeds is not None:
-            if do_classifier_free_guidance:
-                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
-            else:
-                single_image_embeds = ip_adapter_image_embeds
-        elif ip_adapter_image is not None:
-            single_image_embeds = self.encode_image(ip_adapter_image, device)
-            if do_classifier_free_guidance:
-                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
+            for single_ip_adapter_image in ip_adapter_image:
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+                image_embeds.append(single_image_embeds[None, :])
         else:
-            raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+            if not isinstance(ip_adapter_image_embeds, list):
+                ip_adapter_image_embeds = [ip_adapter_image_embeds]
-        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
-        if do_classifier_free_guidance:
-            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
-        return image_embeds.to(device=device)
+        ip_adapter_image_embeds = []
+        for single_image_embeds in image_embeds:
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+        return ip_adapter_image_embeds
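Note: unlike the old SDXL-style helper, which chunked a single tensor for classifier-free guidance, the Flux-style version returns one tensor per loaded IP-Adapter, each tiled `num_images_per_prompt` times along the batch dimension; CFG pairing now happens later in `__call__`. A usage sketch, assuming exactly one IP-Adapter is loaded and `style_image` is a placeholder PIL image:

# Hypothetical caller of the rewritten helper.
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=[style_image],  # list length must match encoder_hid_proj.num_ip_adapters
    ip_adapter_image_embeds=None,
    device=pipe._execution_device,
    num_images_per_prompt=2,  # each per-adapter tensor is tiled twice along dim 0
)
# image_embeds is a list with one tensor per adapter, not a single concatenated tensor.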
def check_inputs(
self,
@@ -448,6 +450,8 @@ class ChromaImg2ImgPipeline(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -483,6 +487,15 @@ class ChromaImg2ImgPipeline(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_attention_mask is not None and negative_prompt_attention_mask is None:
raise ValueError(
"Cannot provide `prompt_attention_mask` without also providing `negative_prompt_attention_mask`"
)
if negative_prompt_attention_mask is not None and prompt_attention_mask is None:
raise ValueError(
"Cannot provide `negative_prompt_attention_mask` without also providing `prompt_attention_mask`"
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
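Note: the two new checks enforce that the attention masks arrive as a pair whenever embeddings are supplied directly. The invariant can be restated compactly (illustrative helper, not part of the pipeline):

def masks_are_paired(prompt_attention_mask, negative_prompt_attention_mask):
    # Either both masks are supplied or neither is.
    return (prompt_attention_mask is None) == (negative_prompt_attention_mask is None)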
@@ -591,7 +604,7 @@ class ChromaImg2ImgPipeline(
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype)
         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -617,6 +630,25 @@ class ChromaImg2ImgPipeline(
             latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids
+    def _prepare_attention_mask(
+        self,
+        batch_size,
+        sequence_length,
+        dtype,
+        attention_mask=None,
+    ):
+        if attention_mask is None:
+            return attention_mask
+        # Extend the prompt attention mask to account for image tokens in the final sequence
+        attention_mask = torch.cat(
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            dim=1,
+        )
+        attention_mask = attention_mask.to(dtype)
+        return attention_mask
     @property
     def guidance_scale(self):
         return self._guidance_scale
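Note: `_prepare_attention_mask` appends one always-on position per packed image token, so a `(batch, text_len)` text mask becomes `(batch, text_len + image_seq_len)`. A standalone shape check with made-up sizes:

import torch

batch_size, text_len, image_seq_len = 2, 512, 1024
prompt_attention_mask = torch.ones(batch_size, text_len)

# Image tokens are always attended to, mirroring the helper above.
attention_mask = torch.cat(
    [prompt_attention_mask, torch.ones(batch_size, image_seq_len)], dim=1
)
assert attention_mask.shape == (batch_size, text_len + image_seq_len)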
@@ -656,13 +688,15 @@ class ChromaImg2ImgPipeline(
         strength: float = 0.8,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -703,11 +737,11 @@ class ChromaImg2ImgPipeline(
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -721,7 +755,7 @@ class ChromaImg2ImgPipeline(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -765,6 +799,8 @@ class ChromaImg2ImgPipeline(
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -794,13 +830,17 @@ class ChromaImg2ImgPipeline(
         (
             prompt_embeds,
             text_ids,
+            prompt_attention_mask,
             negative_prompt_embeds,
             negative_text_ids,
+            negative_prompt_attention_mask,
         ) = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
@@ -856,20 +896,55 @@ class ChromaImg2ImgPipeline(
             latents,
         )
+        attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            dtype=latents.dtype,
+            attention_mask=prompt_attention_mask,
+        )
+        if self.do_classifier_free_guidance and negative_prompt_attention_mask is not None:
+            negative_attention_mask = self._prepare_attention_mask(
+                batch_size=latents.shape[0],
+                sequence_length=image_seq_len,
+                dtype=latents.dtype,
+                attention_mask=negative_prompt_attention_mask,
+            )
+            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
         # 6. Prepare image embeddings
-        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
-            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
-                self.do_classifier_free_guidance,
             )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if self.do_classifier_free_guidance and image_embeds is not None and negative_image_embeds is not None:
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
-        if self.joint_attention_kwargs is None:
-            self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
-        else:
-            self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+        if image_embeds is not None:
+            self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
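Note: when only the positive (or only the negative) IP-Adapter inputs are supplied, the other side is now back-filled with black placeholder images so the CFG batch stays symmetric. The placeholder construction in isolation (sizes are illustrative):

import numpy as np

width, height, num_ip_adapters = 1024, 1024, 1
# One black RGB image per loaded adapter, matching the back-fill above.
negative_ip_adapter_image = [np.zeros((width, height, 3), dtype=np.uint8)] * num_ip_adapters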
@@ -878,9 +953,6 @@ class ChromaImg2ImgPipeline(
                     continue
                 self._current_timestep = t
-                if ip_adapter_image_embeds is not None:
-                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = ip_adapter_image_embeds
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -892,6 +964,7 @@ class ChromaImg2ImgPipeline(
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
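Note: taken together, these changes thread the T5 attention masks from `encode_prompt` through to the transformer's `attention_mask` argument. A hedged end-to-end sketch (the checkpoint id and image path are placeholders, not tested values):

import torch
from diffusers import ChromaImg2ImgPipeline
from diffusers.utils import load_image

# Placeholder checkpoint id; substitute a real Chroma checkpoint.
pipe = ChromaImg2ImgPipeline.from_pretrained("<chroma-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")
init_image = load_image("<url-or-path-to-input-image>")

image = pipe(
    prompt="a watercolor landscape",
    negative_prompt="blurry, low quality",
    image=init_image,
    strength=0.8,
    # prompt_attention_mask / negative_prompt_attention_mask are only needed
    # when passing precomputed prompt_embeds / negative_prompt_embeds.
).images[0]
image.save("chroma_img2img.png")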