mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Author: Dhruv Nair
Date: 2025-06-17 09:48:50 +02:00
Parent: 414de99853
Commit: acc1a49250


@@ -217,6 +217,8 @@ class ChromaImg2ImgPipeline(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
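Note: the comment above is the key constraint in this hunk. Because latents are packed into 2x2 patches, the effective image-to-token downsampling factor is the VAE scale factor times the patch size. A minimal sketch of the arithmetic (illustrative only, assuming the usual Flux-style VAE with four down/up blocks):

    # Illustrative arithmetic, not part of the diff. Assumes a Flux-style VAE
    # whose config.block_out_channels has length 4.
    vae_scale_factor = 2 ** (4 - 1)  # 8: the VAE downsamples height/width by 8
    patch_size = 2                   # latents are packed into 2x2 patches
    # The image processor therefore enforces divisibility by 8 * 2 = 16:
    effective_factor = vae_scale_factor * patch_size

    height = width = 1024
    latent_h = height // vae_scale_factor                         # 128
    latent_w = width // vae_scale_factor                          # 128
    tokens = (latent_h // patch_size) * (latent_w // patch_size)  # 64 * 64 = 4096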
@@ -852,10 +854,6 @@ class ChromaImg2ImgPipeline(
             lora_scale=lora_scale,
         )
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            text_ids = torch.cat([negative_text_ids, text_ids], dim=0)
-
         # 4. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
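Note: with the concatenation removed, prompt_embeds and text_ids stay at batch size B; the unconditional branch is handled by a second forward pass later in the loop. The default sigma schedule above is a plain linear ramp; a quick worked example (assuming num_inference_steps=4):

    import numpy as np

    # Worked example of the default schedule above (num_inference_steps=4).
    num_inference_steps = 4
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    # -> array([1.  , 0.75, 0.5 , 0.25]): from pure noise (sigma=1.0) down to
    # 1/num_inference_steps. image_seq_len uses the same //2 patch packing as
    # the __init__ hunk and, in the Flux-family pipelines, feeds the
    # resolution-dependent timestep shift.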
@@ -913,7 +911,6 @@ class ChromaImg2ImgPipeline(
                 dtype=latents.dtype,
                 attention_mask=negative_prompt_attention_mask,
             )
-            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)

         # 6. Prepare image embeddings
         if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
@@ -928,6 +925,9 @@ class ChromaImg2ImgPipeline(
             ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
             ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
         image_embeds = None
+        negative_image_embeds = None
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
@@ -944,11 +944,6 @@ class ChromaImg2ImgPipeline(
                 device,
                 batch_size * num_images_per_prompt,
             )
-        if self.do_classifier_free_guidance and image_embeds is not None and negative_image_embeds is not None:
-            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
-
-        if image_embeds is not None:
-            self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
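Note: self._joint_attention_kwargs is now mutated inside the loop, once per guidance branch, which is why the earlier hunk initializes it up front. A simplified sketch of the per-branch pattern the loop below implements (`model` is a hypothetical stand-in, not the real transformer signature):

    # Simplified sketch of sequential CFG with per-branch IP-adapter embeds.
    # `model` is a hypothetical stand-in for the transformer call.
    for t in timesteps:
        if image_embeds is not None:
            joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
        cond = model(latents, prompt_embeds, attention_mask, joint_attention_kwargs)

        if do_classifier_free_guidance:
            if negative_image_embeds is not None:
                joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
            uncond = model(latents, negative_prompt_embeds, negative_attention_mask,
                           joint_attention_kwargs)
            cond = uncond + guidance_scale * (cond - uncond)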
@@ -957,13 +952,14 @@ class ChromaImg2ImgPipeline(
                     continue

                 self._current_timestep = t
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0])
+                timestep = t.expand(latents.shape[0])
+
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

                 noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
+                    hidden_states=latents,
                     timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
@@ -974,8 +970,20 @@ class ChromaImg2ImgPipeline(
                 )[0]

                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=negative_text_ids,
+                        img_ids=latent_image_ids,
+                        attention_mask=negative_attention_mask,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
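Note: the trade-off in this rewrite is two transformer calls per step instead of one call on a doubled batch. What it appears to buy, given the separate attention_mask / negative_attention_mask in the hunks above, is that each guidance branch keeps its own sequence length and padding mask, which batched CFG cannot express once the two prompts stop sharing a shape. A toy illustration of that constraint (hypothetical sizes, not from the diff):

    import torch

    # Hypothetical sizes, purely illustrative.
    B, D = 1, 4096
    prompt_embeds = torch.randn(B, 77, D)           # positive prompt, 77 tokens
    negative_prompt_embeds = torch.randn(B, 32, D)  # negative prompt, 32 tokens

    # Batched CFG needs one concatenated tensor, forcing a shared length:
    # torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)  # RuntimeError: 32 != 77
    # Sequential CFG instead runs each (B, seq_len, D) tensor through its own
    # forward pass, each with its own attention mask.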