From e033f9f6084c4bedd6530fc20ac534b15d74d086 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 30 Dec 2023 11:28:21 +0000 Subject: [PATCH] update --- src/diffusers/loaders/single_file.py | 45 ++++++++++------------ src/diffusers/loaders/single_file_utils.py | 2 +- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 82d29ce73f..fd5a955316 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -153,6 +153,9 @@ def build_component( if component_name in pipeline_components: return {} + load_safety_checker = kwargs.get("load_safety_checker", False) + local_files_only = kwargs.get("local_files_only", False) + if component_name == "unet": unet_components = create_unet_model( pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs @@ -177,29 +180,7 @@ def build_component( ) return text_encoder_components - return - - -def build_additional_components( - pipeline_class_name, - original_config, - **kwargs, -): - components = {} - load_safety_checker = kwargs.get("load_safety_checker", False) - local_files_only = kwargs.get("local_files_only", False) - - if pipeline_class_name in REFINER_PIPELINES: - model_type = infer_model_type(pipeline_class_name, original_config) - is_refiner = model_type == "SDXL-Refiner" - components.update( - { - "requires_aesthetics_score": is_refiner, - "force_zeros_for_empty_prompt": False if is_refiner else True, - } - ) - - if pipeline_class_name in SAFETY_CHECKER_PIPELINES: + if component_name == "safety_checker": if load_safety_checker: safety_checker = StableDiffusionSafetyChecker.from_pretrained( "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only @@ -211,10 +192,24 @@ def build_additional_components( safety_checker = None feature_extractor = None + return {"safety_checker": safety_checker, "feature_extractor": feature_extractor} + + return + + +def build_additional_components( + pipeline_class_name, + original_config, + **kwargs, +): + components = {} + if pipeline_class_name in REFINER_PIPELINES: + model_type = infer_model_type(pipeline_class_name, original_config) + is_refiner = model_type == "SDXL-Refiner" components.update( { - "safety_checker": safety_checker, - "feature_extractor": feature_extractor, + "requires_aesthetics_score": is_refiner, + "force_zeros_for_empty_prompt": False if is_refiner else True, } ) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 9ee22ad43f..f6e94ebc25 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -178,7 +178,7 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True) return checkpoint -def infer_model_type(pipeline_class_name, original_config, model_type=None): +def infer_model_type(pipeline_class_name, original_config, model_type=None, **kwargs): if model_type is not None: return model_type