1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add Flux Control to AutoPipeline (#10292)

This commit is contained in:
hlky
2024-12-19 08:28:56 +00:00
committed by GitHub
parent f781b8c30c
commit 4450d26b63

View File

@@ -35,9 +35,12 @@ from .controlnet import (
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
FluxControlPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
@@ -125,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("flux", FluxPipeline),
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline),
@@ -150,6 +154,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
]
)
@@ -168,6 +173,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
]
)
@@ -401,16 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
if "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
else:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
@@ -694,8 +704,14 @@ class AutoPipelineForImage2Image(ConfigMixin):
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
if "Img2Img" in orig_class_name:
to_replace = "Img2ImgPipeline"
elif "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -707,6 +723,9 @@ class AutoPipelineForImage2Image(ConfigMixin):
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
if to_replace == "ControlPipeline":
orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -994,8 +1013,14 @@ class AutoPipelineForInpainting(ConfigMixin):
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
# - `*ControlPipeline` (for Flux tools specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
if "Inpaint" in orig_class_name:
to_replace = "InpaintPipeline"
elif "ControlPipeline" in orig_class_name:
to_replace = "ControlPipeline"
else:
to_replace = "Pipeline"
if "controlnet" in kwargs:
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -1006,6 +1031,8 @@ class AutoPipelineForInpainting(ConfigMixin):
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
if to_replace == "ControlPipeline":
orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}