mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -153,6 +153,9 @@ def build_component(
|
||||
if component_name in pipeline_components:
|
||||
return {}
|
||||
|
||||
load_safety_checker = kwargs.get("load_safety_checker", False)
|
||||
local_files_only = kwargs.get("local_files_only", False)
|
||||
|
||||
if component_name == "unet":
|
||||
unet_components = create_unet_model(
|
||||
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
|
||||
@@ -177,29 +180,7 @@ def build_component(
|
||||
)
|
||||
return text_encoder_components
|
||||
|
||||
return
|
||||
|
||||
|
||||
def build_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
**kwargs,
|
||||
):
|
||||
components = {}
|
||||
load_safety_checker = kwargs.get("load_safety_checker", False)
|
||||
local_files_only = kwargs.get("local_files_only", False)
|
||||
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(pipeline_class_name, original_config)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
"requires_aesthetics_score": is_refiner,
|
||||
"force_zeros_for_empty_prompt": False if is_refiner else True,
|
||||
}
|
||||
)
|
||||
|
||||
if pipeline_class_name in SAFETY_CHECKER_PIPELINES:
|
||||
if component_name == "safety_checker":
|
||||
if load_safety_checker:
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
@@ -211,10 +192,24 @@ def build_additional_components(
|
||||
safety_checker = None
|
||||
feature_extractor = None
|
||||
|
||||
return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
|
||||
|
||||
return
|
||||
|
||||
|
||||
def build_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
**kwargs,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(pipeline_class_name, original_config)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
"safety_checker": safety_checker,
|
||||
"feature_extractor": feature_extractor,
|
||||
"requires_aesthetics_score": is_refiner,
|
||||
"force_zeros_for_empty_prompt": False if is_refiner else True,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True)
|
||||
return checkpoint
|
||||
|
||||
|
||||
def infer_model_type(pipeline_class_name, original_config, model_type=None):
|
||||
def infer_model_type(pipeline_class_name, original_config, model_type=None, **kwargs):
|
||||
if model_type is not None:
|
||||
return model_type
|
||||
|
||||
|
||||
Reference in New Issue
Block a user