mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
This commit is contained in:
@@ -55,31 +55,6 @@ REFINER_PIPELINES = [
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
]
|
||||
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"ControlNetModel": "create_controlnet_model",
|
||||
"AutoencoderKL": "create_vae_model",
|
||||
"UNet2DConditionModel": "create_unet_model",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def extract_pipeline_component_names(pipeline_class):
|
||||
components = inspect.signature(pipeline_class).parameters.keys()
|
||||
return components
|
||||
|
||||
|
||||
def check_valid_url(pretrained_model_link_or_path):
|
||||
# check if url prefix is valid
|
||||
# remove huggingface url prefix from model path
|
||||
has_valid_url_prefix = False
|
||||
for prefix in VALID_URL_PREFIXES:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
has_valid_url_prefix = True
|
||||
|
||||
return has_valid_url_prefix, pretrained_model_link_or_path
|
||||
|
||||
|
||||
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
|
||||
pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
|
||||
@@ -174,8 +149,9 @@ def set_additional_components(
|
||||
**kwargs,
|
||||
):
|
||||
components = {}
|
||||
model_type = kwargs.get("model_type", None)
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(pipeline_class_name, original_config)
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
@@ -206,8 +182,7 @@ class FromSingleFileMixin:
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@@ -291,6 +266,7 @@ class FromSingleFileMixin:
|
||||
)
|
||||
checkpoint = load_state_dict(checkpoint_path)
|
||||
|
||||
# some checkpoints contain the model state dict under a "state_dict" key
|
||||
while "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
@@ -340,11 +316,9 @@ class FromSingleFileMixin:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
|
||||
additional_components = set(optional_kwargs - init_kwargs.keys())
|
||||
additional_components = set_additional_components(class_name, original_config, **kwargs)
|
||||
if additional_components:
|
||||
components = set_additional_components(class_name, original_config, **kwargs)
|
||||
if components:
|
||||
init_kwargs.update(components)
|
||||
init_kwargs.update(additional_components)
|
||||
|
||||
init_kwargs.update(passed_pipe_kwargs)
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
@@ -792,7 +792,6 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
if (
|
||||
expected_add_embed_dim > passed_add_embed_dim
|
||||
and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
|
||||
|
||||
Reference in New Issue
Block a user