diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index b47a67dc77..0274c3e5d0 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -304,6 +304,7 @@ class ChromaPipeline( return prompt_embeds, text_ids + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -315,6 +316,7 @@ class ChromaPipeline( image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt ): @@ -395,6 +397,7 @@ class ChromaPipeline( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latent_image_ids @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height, width, 3) @@ -409,6 +412,7 @@ class ChromaPipeline( return latent_image_ids.to(device=device, dtype=dtype) + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) @@ -417,6 +421,7 @@ class ChromaPipeline( return latents + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents @staticmethod def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape @@ -462,6 +467,8 @@ class ChromaPipeline( """ self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, batch_size,