diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index 82c2faabe1..eb76132f10 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -292,8 +292,10 @@ class ChromaImg2ImgPipeline(
         negative_prompt: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         do_classifier_free_guidance: bool = True,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
@@ -310,7 +312,7 @@ class ChromaImg2ImgPipeline(
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             lora_scale (`float`, *optional*):
@@ -335,7 +337,7 @@ class ChromaImg2ImgPipeline(
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
@@ -365,12 +367,13 @@ class ChromaImg2ImgPipeline(
                     " the batch size of `prompt`."
                 )

-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                 prompt=negative_prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+
             negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

         if self.text_encoder is not None:
@@ -378,7 +381,14 @@ class ChromaImg2ImgPipeline(
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder, lora_scale)

-        return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+        return (
+            prompt_embeds,
+            text_ids,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_text_ids,
+            negative_prompt_attention_mask,
+        )

     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
     def encode_image(self, image, device, num_images_per_prompt):
@@ -392,52 +402,44 @@ class ChromaImg2ImgPipeline(
         image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
         return image_embeds

-    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self,
-        ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-    ) -> torch.Tensor:
-        """Prepares image embeddings for use in the IP-Adapter.
-
-        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
-
-        Args:
-            ip_adapter_image (`PipelineImageInput`, *optional*):
-                The input image to extract features from for IP-Adapter.
-            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
-                Precomputed image embeddings.
-            device: (`torch.device`, *optional*):
-                Torch device.
-            num_images_per_prompt (`int`, defaults to 1):
-                Number of images that should be generated per prompt.
-            do_classifier_free_guidance (`bool`, defaults to True):
-                Whether to use classifier free guidance or not.
-        """
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
         device = device or self._execution_device

-        if ip_adapter_image_embeds is not None:
-            if do_classifier_free_guidance:
-                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
-            else:
-                single_image_embeds = ip_adapter_image_embeds
-        elif ip_adapter_image is not None:
-            single_image_embeds = self.encode_image(ip_adapter_image, device)
-            if do_classifier_free_guidance:
-                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
+
+            for single_ip_adapter_image in ip_adapter_image:
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+                image_embeds.append(single_image_embeds[None, :])
         else:
-            raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+            if not isinstance(ip_adapter_image_embeds, list):
+                ip_adapter_image_embeds = [ip_adapter_image_embeds]

-        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )

-        if do_classifier_free_guidance:
-            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
-            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)

-        return image_embeds.to(device=device)
+        ip_adapter_image_embeds = []
+        for single_image_embeds in image_embeds:
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds

     def check_inputs(
         self,
@@ -448,6 +450,8 @@ class ChromaImg2ImgPipeline(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -483,6 +487,15 @@ class ChromaImg2ImgPipeline(
                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
+        if prompt_attention_mask is not None and negative_prompt_attention_mask is None:
+            raise ValueError(
+                "Cannot provide `prompt_attention_mask` without also providing `negative_prompt_attention_mask`"
+            )
+
+        if negative_prompt_attention_mask is not None and prompt_attention_mask is None:
+            raise ValueError(
+                "Cannot provide `negative_prompt_attention_mask` without also providing `prompt_attention_mask`"
+            )

         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -591,7 +604,7 @@ class ChromaImg2ImgPipeline(
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype)

         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -617,6 +630,25 @@ class ChromaImg2ImgPipeline(
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids

+    def _prepare_attention_mask(
+        self,
+        batch_size,
+        sequence_length,
+        dtype,
+        attention_mask=None,
+    ):
+        if attention_mask is None:
+            return attention_mask
+
+        # Extend the prompt attention mask to account for image tokens in the final sequence
+        attention_mask = torch.cat(
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            dim=1,
+        )
+        attention_mask = attention_mask.to(dtype)
+
+        return attention_mask
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -656,13 +688,15 @@ class ChromaImg2ImgPipeline(
         strength: float = 0.8,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_ip_adapter_image: Optional[PipelineImageInput] = None,
         negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -703,11 +737,11 @@ class ChromaImg2ImgPipeline(
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -721,7 +755,7 @@ class ChromaImg2ImgPipeline(
                 Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                 provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -765,6 +799,8 @@ class ChromaImg2ImgPipeline(
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -794,13 +830,17 @@ class ChromaImg2ImgPipeline(
         (
             prompt_embeds,
             text_ids,
+            prompt_attention_mask,
             negative_prompt_embeds,
             negative_text_ids,
+            negative_prompt_attention_mask,
         ) = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
@@ -856,20 +896,55 @@ class ChromaImg2ImgPipeline(
             latents,
         )

+        attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            dtype=latents.dtype,
+            attention_mask=prompt_attention_mask,
+        )
+        if self.do_classifier_free_guidance and negative_prompt_attention_mask is not None:
+            negative_attention_mask = self._prepare_attention_mask(
+                batch_size=latents.shape[0],
+                sequence_length=image_seq_len,
+                dtype=latents.dtype,
+                attention_mask=negative_prompt_attention_mask,
+            )
+            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
+
         # 6. Prepare image embeddings
-        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
-            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
-                self.do_classifier_free_guidance,
             )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if self.do_classifier_free_guidance and image_embeds is not None and negative_image_embeds is not None:
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

-            if self.joint_attention_kwargs is None:
-                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
-            else:
-                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+        if image_embeds is not None:
+            self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -878,9 +953,6 @@ class ChromaImg2ImgPipeline(
                     continue

                 self._current_timestep = t
-                if ip_adapter_image_embeds is not None:
-                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = ip_adapter_image_embeds
-
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -892,6 +964,7 @@ class ChromaImg2ImgPipeline(
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
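For reviewers, a minimal standalone sketch of the mask handling this patch introduces: `_prepare_attention_mask` pads the T5 prompt mask with ones so the packed image tokens are always attended to, and under classifier-free guidance the negative and positive masks are stacked along the batch dimension, mirroring `torch.cat([latents] * 2)`. All tensor shapes here are illustrative, not taken from the PR.

import torch

# Illustrative sizes (assumptions, not from the PR): 2 prompts, 512 text tokens, 4096 packed image tokens
batch_size, text_len, image_seq_len = 2, 512, 4096

prompt_attention_mask = torch.ones(batch_size, text_len)
prompt_attention_mask[:, 100:] = 0  # pretend each prompt only used 100 real tokens
negative_prompt_attention_mask = torch.ones(batch_size, text_len)

def extend_mask(mask: torch.Tensor, batch_size: int, sequence_length: int) -> torch.Tensor:
    # Mirrors `_prepare_attention_mask`: append ones for the image-token positions
    return torch.cat([mask, torch.ones(batch_size, sequence_length)], dim=1)

attention_mask = extend_mask(prompt_attention_mask, batch_size, image_seq_len)
negative_attention_mask = extend_mask(negative_prompt_attention_mask, batch_size, image_seq_len)

# Under classifier-free guidance the pipeline concatenates [negative, positive]
# along the batch dimension before handing the mask to the transformer.
attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
assert attention_mask.shape == (2 * batch_size, text_len + image_seq_len)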
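And a hedged usage sketch of the new arguments, assuming the diff above is merged as-is; the checkpoint id, image URL, and prompts are placeholders, not from the PR. `encode_prompt` now returns a 6-tuple that includes both attention masks, and `check_inputs` requires the two masks to be passed together.

import torch
from diffusers import ChromaImg2ImgPipeline
from diffusers.utils import load_image

# Placeholder checkpoint id and image URL, for illustration only.
pipe = ChromaImg2ImgPipeline.from_pretrained("<chroma-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")
init_image = load_image("<url-to-init-image>")

# Precompute embeddings and masks once, then reuse them across calls.
(
    prompt_embeds,
    text_ids,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_text_ids,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    prompt="a photo of a cabin in a forest at dusk",
    negative_prompt="blurry, low quality",
    device=pipe.device,
)

# Both masks must be supplied together, or check_inputs raises a ValueError.
image = pipe(
    image=init_image,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    strength=0.8,
).images[0]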