diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index dbeb0bce9e..6d04c21ee5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -30,7 +30,7 @@ from ...utils import ( ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_timesteps +from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps from .pipeline_output import FluxPipelineOutput @@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """ class FluxControlPipeline( DiffusionPipeline, - FluxControlMixin, + FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, @@ -235,6 +235,41 @@ class FluxControlPipeline( return latents, latent_image_ids + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index f0269f5eca..068ab7132f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -26,7 +26,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps +from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput @@ -80,7 +80,7 @@ EXAMPLE_DOC_STRING = """ """ -class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin): r""" The Flux pipeline for image inpainting. diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index d475c54f73..1bf2f343c3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -35,7 +35,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps +from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput @@ -108,7 +108,7 @@ EXAMPLE_DOC_STRING = """ class FluxControlInpaintPipeline( DiffusionPipeline, - FluxControlMixin, + FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 9209a51e7f..b26f3dcd12 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -34,7 +34,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps +from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput @@ -80,7 +80,7 @@ EXAMPLE_DOC_STRING = """ class FluxControlNetPipeline( - DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin + DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin ): r""" The Flux pipeline for text-to-image generation. diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 0cd4a051ae..f246c8be10 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -18,7 +18,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps +from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput @@ -74,7 +74,7 @@ EXAMPLE_DOC_STRING = """ """ -class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin): r""" The Flux controlnet pipeline for image-to-image generation. diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 9fb7f278f3..b59b312239 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -19,7 +19,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps +from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps from .pipeline_output import FluxPipelineOutput @@ -76,7 +76,7 @@ EXAMPLE_DOC_STRING = """ """ -class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin): r""" The Flux controlnet pipeline for inpainting. diff --git a/src/diffusers/pipelines/flux/pipeline_flux_utils.py b/src/diffusers/pipelines/flux/pipeline_flux_utils.py index b4cbf58a22..94ef7205ec 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_utils.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_utils.py @@ -394,40 +394,3 @@ class FluxMixin: prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds - - -class FluxControlMixin(FluxMixin): - # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if isinstance(image, torch.Tensor): - pass - else: - image = self.image_processor.preprocess(image, height=height, width=width) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image