1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add ControlNetUnion to AutoPipeline from_pretrained (#10219)

This commit is contained in:
hlky
2024-12-16 20:25:08 +00:00
committed by GitHub
parent 2f023d7b84
commit 5ed761a6f2

View File

@@ -18,6 +18,7 @@ from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
@@ -28,6 +29,9 @@ from .controlnet import (
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLControlNetUnionImg2ImgPipeline,
StableDiffusionXLControlNetUnionInpaintPipeline,
StableDiffusionXLControlNetUnionPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import (
@@ -108,6 +112,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("kandinsky3", Kandinsky3Pipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
("cascade", StableCascadeCombinedPipeline),
("lcm", LatentConsistencyModelPipeline),
@@ -139,6 +144,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
@@ -158,6 +164,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
@@ -396,7 +403,10 @@ class AutoPipelineForText2Image(ConfigMixin):
orig_class_name = config["_class_name"]
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
else:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
@@ -688,7 +698,10 @@ class AutoPipelineForImage2Image(ConfigMixin):
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
else:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
@@ -985,7 +998,10 @@ class AutoPipelineForInpainting(ConfigMixin):
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
else:
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag: