mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add ControlNetUnion to AutoPipeline from_pretrained (#10219)
This commit is contained in:
@@ -18,6 +18,7 @@ from collections import OrderedDict
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..models.controlnets import ControlNetUnionModel
|
||||
from ..utils import is_sentencepiece_available
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .cogview3 import CogView3PlusPipeline
|
||||
@@ -28,6 +29,9 @@ from .controlnet import (
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLControlNetUnionImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetUnionInpaintPipeline,
|
||||
StableDiffusionXLControlNetUnionPipeline,
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
from .flux import (
|
||||
@@ -108,6 +112,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("kandinsky3", Kandinsky3Pipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
|
||||
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
|
||||
("wuerstchen", WuerstchenCombinedPipeline),
|
||||
("cascade", StableCascadeCombinedPipeline),
|
||||
("lcm", LatentConsistencyModelPipeline),
|
||||
@@ -139,6 +144,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
|
||||
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
|
||||
("lcm", LatentConsistencyModelImg2ImgPipeline),
|
||||
@@ -158,6 +164,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
|
||||
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
|
||||
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
||||
("flux", FluxInpaintPipeline),
|
||||
("flux-controlnet", FluxControlNetInpaintPipeline),
|
||||
@@ -396,7 +403,10 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
|
||||
else:
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
|
||||
if "enable_pag" in kwargs:
|
||||
enable_pag = kwargs.pop("enable_pag")
|
||||
if enable_pag:
|
||||
@@ -688,7 +698,10 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
|
||||
else:
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
|
||||
if "enable_pag" in kwargs:
|
||||
enable_pag = kwargs.pop("enable_pag")
|
||||
if enable_pag:
|
||||
@@ -985,7 +998,10 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
|
||||
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
|
||||
else:
|
||||
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
|
||||
if "enable_pag" in kwargs:
|
||||
enable_pag = kwargs.pop("enable_pag")
|
||||
if enable_pag:
|
||||
|
||||
Reference in New Issue
Block a user