mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add Flux Control to AutoPipeline (#10292)
This commit is contained in:
@@ -35,9 +35,12 @@ from .controlnet import (
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
from .flux import (
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
FluxControlPipeline,
|
||||
FluxImg2ImgPipeline,
|
||||
FluxInpaintPipeline,
|
||||
FluxPipeline,
|
||||
@@ -125,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
|
||||
("auraflow", AuraFlowPipeline),
|
||||
("flux", FluxPipeline),
|
||||
("flux-control", FluxControlPipeline),
|
||||
("flux-controlnet", FluxControlNetPipeline),
|
||||
("lumina", LuminaText2ImgPipeline),
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
@@ -150,6 +154,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("lcm", LatentConsistencyModelImg2ImgPipeline),
|
||||
("flux", FluxImg2ImgPipeline),
|
||||
("flux-controlnet", FluxControlNetImg2ImgPipeline),
|
||||
("flux-control", FluxControlImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -168,6 +173,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
||||
("flux", FluxInpaintPipeline),
|
||||
("flux-controlnet", FluxControlNetInpaintPipeline),
|
||||
("flux-control", FluxControlInpaintPipeline),
|
||||
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
|
||||
]
|
||||
)
|
||||
@@ -401,16 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
orig_class_name = config["_class_name"]
|
||||
if "ControlPipeline" in orig_class_name:
|
||||
to_replace = "ControlPipeline"
|
||||
else:
|
||||
to_replace = "Pipeline"
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
|
||||
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
|
||||
else:
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
|
||||
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
|
||||
if "enable_pag" in kwargs:
|
||||
enable_pag = kwargs.pop("enable_pag")
|
||||
if enable_pag:
|
||||
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
|
||||
orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
|
||||
|
||||
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
@@ -694,8 +704,14 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
|
||||
# the `orig_class_name` can be:
|
||||
# `- *Pipeline` (for regular text-to-image checkpoint)
|
||||
# - `*ControlPipeline` (for Flux tools specific checkpoint)
|
||||
# `- *Img2ImgPipeline` (for refiner checkpoint)
|
||||
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
|
||||
if "Img2Img" in orig_class_name:
|
||||
to_replace = "Img2ImgPipeline"
|
||||
elif "ControlPipeline" in orig_class_name:
|
||||
to_replace = "ControlPipeline"
|
||||
else:
|
||||
to_replace = "Pipeline"
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
@@ -707,6 +723,9 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
if enable_pag:
|
||||
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
|
||||
|
||||
if to_replace == "ControlPipeline":
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
|
||||
|
||||
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
@@ -994,8 +1013,14 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
|
||||
# The `orig_class_name`` can be:
|
||||
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
|
||||
# - `*ControlPipeline` (for Flux tools specific checkpoint)
|
||||
# - or *Pipeline (for regular text-to-image checkpoint)
|
||||
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
|
||||
if "Inpaint" in orig_class_name:
|
||||
to_replace = "InpaintPipeline"
|
||||
elif "ControlPipeline" in orig_class_name:
|
||||
to_replace = "ControlPipeline"
|
||||
else:
|
||||
to_replace = "Pipeline"
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
@@ -1006,6 +1031,8 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
enable_pag = kwargs.pop("enable_pag")
|
||||
if enable_pag:
|
||||
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
|
||||
if to_replace == "ControlPipeline":
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
|
||||
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
|
||||
Reference in New Issue
Block a user