diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index 3032449b81..378f757d11 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -217,6 +217,8 @@ class ChromaImg2ImgPipeline(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
@@ -852,10 +854,6 @@ class ChromaImg2ImgPipeline(
             lora_scale=lora_scale,
         )
 
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            text_ids = torch.cat([negative_text_ids, text_ids], dim=0)
-
         # 4. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
@@ -913,7 +911,6 @@ class ChromaImg2ImgPipeline(
                 dtype=latents.dtype,
                 attention_mask=negative_prompt_attention_mask,
             )
-            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)
 
         # 6. Prepare image embeddings
         if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
@@ -928,6 +925,9 @@ class ChromaImg2ImgPipeline(
                 ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
                 ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
 
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
         image_embeds = None
         negative_image_embeds = None
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
@@ -944,11 +944,6 @@ class ChromaImg2ImgPipeline(
                 device,
                 batch_size * num_images_per_prompt,
             )
-            if self.do_classifier_free_guidance and image_embeds is not None and negative_image_embeds is not None:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
-
-        if image_embeds is not None:
-            self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
 
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -957,13 +952,14 @@ class ChromaImg2ImgPipeline(
                     continue
 
                 self._current_timestep = t
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0])
+                timestep = t.expand(latents.shape[0])
+
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
 
                 noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
+                    hidden_states=latents,
                     timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
@@ -974,8 +970,20 @@ class ChromaImg2ImgPipeline(
                 )[0]
 
                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=negative_text_ids,
+                        img_ids=latent_image_ids,
+                        attention_mask=negative_attention_mask,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
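Taken together, these hunks replace the batched classifier-free guidance path (concatenating negative and positive embeddings and splitting the prediction with `.chunk(2)`) with two separate transformer forward passes per denoising step. Below is a minimal, self-contained sketch of that pattern only; `cfg_two_pass`, `toy_transformer`, `cond_embeds`, and `uncond_embeds` are placeholder names for illustration, not pipeline API.

```python
import torch


def cfg_two_pass(transformer, latents, timestep, cond_embeds, uncond_embeds, guidance_scale):
    # Conditional branch: positive prompt embeddings, run as its own forward pass.
    noise_cond = transformer(latents, timestep, cond_embeds)
    # Unconditional branch: negative prompt embeddings, run as a second forward pass.
    noise_uncond = transformer(latents, timestep, uncond_embeds)
    # Same guidance formula as in the diff: uncond + scale * (cond - uncond).
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


# Toy stand-in for the transformer so the sketch runs on its own.
def toy_transformer(latents, timestep, text_embeds):
    return 0.1 * latents + text_embeds.mean()


latents = torch.randn(1, 16, 32, 32)
timestep = torch.tensor(500.0)
cond_embeds, uncond_embeds = torch.randn(1, 512, 64), torch.randn(1, 512, 64)

noise_pred = cfg_two_pass(toy_transformer, latents, timestep, cond_embeds, uncond_embeds, guidance_scale=4.0)
print(noise_pred.shape)  # torch.Size([1, 16, 32, 32])
```

The apparent trade-off: two sequential forward calls per step instead of one doubled-batch call, which lowers peak activation memory and lets the two branches use different attention masks, text IDs, and IP-adapter image embeddings, as the diff requires.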