1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2023-12-30 09:52:49 +00:00
parent 820313b8f4
commit efc6380615

View File

@@ -27,10 +27,7 @@ from ..utils import (
logging,
)
from .single_file_utils import (
create_controlnet_model,
create_paint_by_example_components,
create_scheduler,
create_stable_unclip_components,
create_text_encoders_and_tokenizers,
create_unet_model,
create_vae_model,
@@ -164,27 +161,17 @@ def build_component(
def build_additional_components(
pipeline_class_name,
original_config,
checkpoint,
checkpoint_path_or_dict,
**kwargs,
):
components = {}
load_safety_checker = kwargs.get("load_safety_checker", False)
local_files_only = kwargs.get("local_files_only", False)
if pipeline_class_name == ["StableUnCLIPPipeline", "StableUnCLIPImg2ImgPipeline"]:
stable_unclip_components = create_stable_unclip_components(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
)
components.update(stable_unclip_components)
if pipeline_class_name == "PaintByExamplePipeline":
paint_by_example_components = create_paint_by_example_components(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
)
return components.update(paint_by_example_components)
if pipeline_class_name in ["StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline"]:
if pipeline_class_name in [
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
]:
model_type = infer_model_type(pipeline_class_name, original_config)
is_refiner = model_type == "SDXL-Refiner"
components.update(
@@ -360,9 +347,7 @@ class FromSingleFileMixin:
additional_components = set(component_names - pipeline_components.keys())
if additional_components:
components = build_additional_components(
pipeline_name, original_config, checkpoint, pretrained_model_link_or_path, **kwargs
)
components = build_additional_components(pipeline_name, original_config, **kwargs)
if components:
pipeline_components.update(components)