From 7235805e752641bb30dc4cbbb881c3c24addfc29 Mon Sep 17 00:00:00 2001 From: Edna <88869424+Ednaordinary@users.noreply.github.com> Date: Thu, 12 Jun 2025 03:40:52 -0600 Subject: [PATCH] Revert cond + uncond batching --- .../pipelines/chroma/pipeline_chroma.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 0274c3e5d0..d20ae43b36 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -694,9 +694,6 @@ class ChromaPipeline( max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) - - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -773,13 +770,11 @@ class ChromaPipeline( if image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + timestep = t.expand(latents.shape[0]).to(latents.dtype) noise_pred = self.transformer( - hidden_states=latent_model_input, + hidden_states=latents, timestep=timestep / 1000, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, @@ -791,8 +786,16 @@ class ChromaPipeline( if self.do_classifier_free_guidance: if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype