From cd6ca9df2987c000b28e13b19bd4eec3ef3c914b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 21 Nov 2024 13:02:31 +0530 Subject: [PATCH] Fix prepare latent image ids and vae sample generators for flux (#9981) * fix * update expected slice --- src/diffusers/pipelines/flux/pipeline_flux.py | 2 +- .../flux/pipeline_flux_controlnet.py | 20 ++++++++++++++++--- ...pipeline_flux_controlnet_image_to_image.py | 4 ++-- .../pipeline_flux_controlnet_inpainting.py | 4 ++-- .../controlnet_flux/test_controlnet_flux.py | 2 +- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 12996f3f3e..e0add1e60c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -513,7 +513,7 @@ class FluxPipeline( shape = (batch_size, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents.to(device=device, dtype=dtype), latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 904173852e..654bc41af4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -97,6 +97,20 @@ def calculate_shift( return mu +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -512,7 +526,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF shape = (batch_size, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents.to(device=device, dtype=dtype), latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: @@ -772,7 +786,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True if self.controlnet.input_hint_block is None: # vae encode - control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor # pack @@ -810,7 +824,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF if self.controlnet.nets[0].input_hint_block is None: # vae encode - control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor # pack diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 5d65df0b76..6ab34d8a9c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -801,7 +801,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ) height, width = control_image.shape[-2:] - control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor height_control_image, width_control_image = control_image.shape[2:] @@ -832,7 +832,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ) height, width = control_image_.shape[-2:] - control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor height_control_image, width_control_image = control_image_.shape[2:] diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 5d5c8f7376..d81cffaca3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -942,7 +942,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True if self.controlnet.input_hint_block is None: # vae encode - control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor # pack @@ -979,7 +979,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From if self.controlnet.nets[0].input_hint_block is None: # vae encode - control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor # pack diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index ee3984dcd3..8202424e7f 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -170,7 +170,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin): assert image.shape == (1, 32, 32, 3) expected_slice = np.array( - [0.7348633, 0.41333008, 0.6621094, 0.5444336, 0.47607422, 0.5859375, 0.44677734, 0.4506836, 0.40454102] + [0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156] ) assert (