mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
update
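This commit updates ChromaImg2ImgPipeline: `__init__` gains a `latent_channels` attribute, and classifier-free guidance switches from one batched forward pass (negative and positive inputs concatenated along the batch dimension) to two sequential transformer calls, with IP-adapter embeds swapped per branch inside the denoising loop.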
@@ -217,6 +217,8 @@ class ChromaImg2ImgPipeline(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
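Note on the lines above: `latent_channels` feeds img2img latent preparation, and the `vae_scale_factor * 2` passed to `VaeImageProcessor` enforces the divisibility described in the comment. A minimal sketch of the bookkeeping, assuming a typical four-level VAE config (illustrative values, not pipeline code):

```python
# Illustrative values only: a typical 4-level VAE config.
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 8: the VAE downsamples 8x

# Chroma/Flux packs latents into 2x2 patches, so input sizes are rounded to
# multiples of vae_scale_factor * 2 = 16 by the image processor.
height = width = 1024
latent_h, latent_w = height // vae_scale_factor, width // vae_scale_factor  # 128, 128
num_packed_tokens = (latent_h // 2) * (latent_w // 2)  # 64 * 64 = 4096
print(vae_scale_factor, latent_h, latent_w, num_packed_tokens)
```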
@@ -852,10 +854,6 @@ class ChromaImg2ImgPipeline(
             lora_scale=lora_scale,
         )

-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            text_ids = torch.cat([negative_text_ids, text_ids], dim=0)
-
         # 4. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
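The default `sigmas` built above is a linear ramp from 1.0 down to `1 / num_inference_steps`, and `image_seq_len` is the packed-token count used for the timestep shift. A small standalone sketch (toy values):

```python
import numpy as np

num_inference_steps = 4
# Same expression as above: linear ramp from 1.0 down to 1/num_steps.
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
print(sigmas)  # [1.   0.75 0.5  0.25]

# Packed sequence length for a 1024x1024 request with vae_scale_factor = 8.
height = width = 1024
vae_scale_factor = 8
image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
print(image_seq_len)  # 4096
```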
@@ -913,7 +911,6 @@ class ChromaImg2ImgPipeline(
                 dtype=latents.dtype,
                 attention_mask=negative_prompt_attention_mask,
             )
-            attention_mask = torch.cat([negative_attention_mask, attention_mask], dim=0)

         # 6. Prepare image embeddings
         if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
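The dropped `torch.cat` reflects the same restructuring: the positive and negative attention masks used to be stacked along the batch dimension for one batched CFG pass; now each of the two sequential passes receives its own mask (see the uncond call further down, which takes `negative_attention_mask`). Toy illustration of what the removed line did:

```python
import torch

# Toy shapes: one mask per CFG branch, batch size 1, sequence length 6.
attention_mask = torch.ones(1, 6)           # positive-prompt mask
negative_attention_mask = torch.ones(1, 6)  # negative-prompt mask

# The removed line stacked them for a single batched forward pass:
batched = torch.cat([negative_attention_mask, attention_mask], dim=0)
print(batched.shape)  # torch.Size([2, 6]) -- one row per branch
```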
@@ -928,6 +925,9 @@ class ChromaImg2ImgPipeline(
             ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
             ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
         image_embeds = None
         negative_image_embeds = None
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
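For context: when IP-adapter inputs are provided for only one CFG branch, the pipeline fabricates a black placeholder image for the other branch so both forward passes receive embeddings. A sketch of that placeholder construction (`num_ip_adapters` assumed to be 1 here; the pipeline reads it from `encoder_hid_proj`):

```python
import numpy as np

width, height = 64, 64
num_ip_adapters = 1  # assumption for this sketch

placeholder = np.zeros((width, height, 3), dtype=np.uint8)  # black RGB image
ip_adapter_image = [placeholder] * num_ip_adapters  # one copy per IP adapter
print(len(ip_adapter_image), ip_adapter_image[0].shape)  # 1 (64, 64, 3)
```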
@@ -944,11 +944,6 @@ class ChromaImg2ImgPipeline(
                 device,
                 batch_size * num_images_per_prompt,
             )
-            if self.do_classifier_free_guidance and image_embeds is not None and negative_image_embeds is not None:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
-
-        if image_embeds is not None:
-            self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
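The pre-loop handling disappears because the embeds are no longer concatenated for a batched pass; instead each branch's embeds are written into `joint_attention_kwargs` right before its own transformer call inside the loop (see the next two hunks). Toy sketch of the difference:

```python
import torch

image_embeds = torch.randn(1, 4, 8)           # toy conditional embeds
negative_image_embeds = torch.randn(1, 4, 8)  # toy unconditional embeds

# Old, batched CFG: one stacked tensor for a single forward pass.
batched = torch.cat([negative_image_embeds, image_embeds], dim=0)
print(batched.shape)  # torch.Size([2, 4, 8])

# New, sequential CFG: the kwargs dict is rewritten per branch inside the loop.
joint_attention_kwargs = {}
joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds           # cond pass
joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds  # uncond pass
```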
@@ -957,13 +952,14 @@ class ChromaImg2ImgPipeline(
                     continue

                 self._current_timestep = t
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0])
+                timestep = t.expand(latents.shape[0])

+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
                 noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
+                    hidden_states=latents,
                     timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
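With the batched input gone, the timestep is broadcast over `latents` alone, and `/ 1000` rescales the scheduler's 0-1000 timestep to a unit range. A standalone sketch (assumed packed-latent shape):

```python
import torch

latents = torch.randn(2, 4096, 64)  # assumed (batch, tokens, channels) packed shape
t = torch.tensor(999.0)             # a scheduler timestep in [0, 1000)

timestep = t.expand(latents.shape[0])  # broadcast, no copy: shape (2,)
print(timestep / 1000)                 # tensor([0.9990, 0.9990])
```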
@@ -974,8 +970,20 @@ class ChromaImg2ImgPipeline(
                 )[0]

                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=negative_text_ids,
+                        img_ids=latent_image_ids,
+                        attention_mask=negative_attention_mask,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
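The guidance combine itself is unchanged; only the source of the two predictions moved from one `.chunk(2)` of a batched output to two sequential calls. Toy sketch of the arithmetic with stand-in tensors:

```python
import torch

guidance_scale = 4.0
noise_pred = torch.randn(1, 4096, 64)         # stand-in conditional prediction
noise_pred_uncond = torch.randn(1, 4096, 64)  # stand-in unconditional prediction

# Identical linear combination before and after this commit:
guided = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4096, 64])
```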