From ab7942174ad9debd5f3a41b1df54c1868e863e75 Mon Sep 17 00:00:00 2001 From: Edna <88869424+Ednaordinary@users.noreply.github.com> Date: Wed, 11 Jun 2025 19:57:31 -0600 Subject: [PATCH] use dn6 attn mask + fix true_cfg_scale --- src/diffusers/pipelines/chroma/pipeline_chroma.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 32135d2c21..de7e5deb20 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -232,9 +232,14 @@ class ChromaPipeline( prompt_embeds = self.text_encoder( text_input_ids.to(device), output_hidden_states=False, - #attention_mask=(text_inputs.attention_mask.to(device),), + attention_mask=text_inputs.attention_mask.to(device), )[0] + max_len = min(text_inputs.attention_mask.sum() + 1, max_sequence_length) + prompt_embeds = prompt_embeds[ + :, :max_len + ] + dtype = self.text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -554,7 +559,7 @@ class ChromaPipeline( instead. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. @@ -794,7 +799,7 @@ class ChromaPipeline( joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] - noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype