1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update # Copied from statements

This commit is contained in:
Edna
2025-06-12 03:27:35 -06:00
committed by GitHub
parent fe5af79a19
commit 6a0db55af8

View File

@@ -304,6 +304,7 @@ class ChromaPipeline(
return prompt_embeds, text_ids
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
@@ -315,6 +316,7 @@ class ChromaPipeline(
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
@@ -395,6 +397,7 @@ class ChromaPipeline(
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latent_image_ids
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
@@ -409,6 +412,7 @@ class ChromaPipeline(
return latent_image_ids.to(device=device, dtype=dtype)
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
@@ -417,6 +421,7 @@ class ChromaPipeline(
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
@@ -462,6 +467,8 @@ class ChromaPipeline(
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
def prepare_latents(
self,
batch_size,