1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2023-12-30 11:28:21 +00:00
parent afa62e6fa8
commit e033f9f608
2 changed files with 21 additions and 26 deletions

View File

@@ -153,6 +153,9 @@ def build_component(
if component_name in pipeline_components:
return {}
load_safety_checker = kwargs.get("load_safety_checker", False)
local_files_only = kwargs.get("local_files_only", False)
if component_name == "unet":
unet_components = create_unet_model(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
@@ -177,29 +180,7 @@ def build_component(
)
return text_encoder_components
return
def build_additional_components(
pipeline_class_name,
original_config,
**kwargs,
):
components = {}
load_safety_checker = kwargs.get("load_safety_checker", False)
local_files_only = kwargs.get("local_files_only", False)
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(pipeline_class_name, original_config)
is_refiner = model_type == "SDXL-Refiner"
components.update(
{
"requires_aesthetics_score": is_refiner,
"force_zeros_for_empty_prompt": False if is_refiner else True,
}
)
if pipeline_class_name in SAFETY_CHECKER_PIPELINES:
if component_name == "safety_checker":
if load_safety_checker:
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
@@ -211,10 +192,24 @@ def build_additional_components(
safety_checker = None
feature_extractor = None
return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
return
def build_additional_components(
pipeline_class_name,
original_config,
**kwargs,
):
components = {}
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(pipeline_class_name, original_config)
is_refiner = model_type == "SDXL-Refiner"
components.update(
{
"safety_checker": safety_checker,
"feature_extractor": feature_extractor,
"requires_aesthetics_score": is_refiner,
"force_zeros_for_empty_prompt": False if is_refiner else True,
}
)

View File

@@ -178,7 +178,7 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True)
return checkpoint
def infer_model_type(pipeline_class_name, original_config, model_type=None):
def infer_model_type(pipeline_class_name, original_config, model_type=None, **kwargs):
if model_type is not None:
return model_type