1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

remove fluxcontrolmixin

This commit is contained in:
sayakpaul
2025-12-08 12:19:50 +05:30
parent cf3053b565
commit be586607de
7 changed files with 47 additions and 49 deletions

View File

@@ -30,7 +30,7 @@ from ...utils import (
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_timesteps
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput
@@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """
class FluxControlPipeline(
DiffusionPipeline,
FluxControlMixin,
FluxMixin,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
@@ -235,6 +235,41 @@ class FluxControlPipeline(
return latents, latent_image_ids
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(

View File

@@ -26,7 +26,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput
@@ -80,7 +80,7 @@ EXAMPLE_DOC_STRING = """
"""
class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
r"""
The Flux pipeline for image inpainting.

View File

@@ -35,7 +35,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput
@@ -108,7 +108,7 @@ EXAMPLE_DOC_STRING = """
class FluxControlInpaintPipeline(
DiffusionPipeline,
FluxControlMixin,
FluxMixin,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,

View File

@@ -34,7 +34,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput
@@ -80,7 +80,7 @@ EXAMPLE_DOC_STRING = """
class FluxControlNetPipeline(
DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin
DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin
):
r"""
The Flux pipeline for text-to-image generation.

View File

@@ -18,7 +18,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput
@@ -74,7 +74,7 @@ EXAMPLE_DOC_STRING = """
"""
class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
r"""
The Flux controlnet pipeline for image-to-image generation.

View File

@@ -19,7 +19,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxControlMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput
@@ -76,7 +76,7 @@ EXAMPLE_DOC_STRING = """
"""
class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxControlMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxMixin, FluxLoraLoaderMixin, FromSingleFileMixin):
r"""
The Flux controlnet pipeline for inpainting.

View File

@@ -394,40 +394,3 @@ class FluxMixin:
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
class FluxControlMixin(FluxMixin):
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image