1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
Dhruv Nair
2024-01-19 06:13:20 +00:00
parent ffde1235fc
commit fd2ec36369
2 changed files with 6 additions and 33 deletions

View File

@@ -55,31 +55,6 @@ REFINER_PIPELINES = [
"StableDiffusionXLControlNetImg2ImgPipeline",
]
LOADABLE_CLASSES = {
"diffusers": {
"ControlNetModel": "create_controlnet_model",
"AutoencoderKL": "create_vae_model",
"UNet2DConditionModel": "create_unet_model",
}
}
def extract_pipeline_component_names(pipeline_class):
components = inspect.signature(pipeline_class).parameters.keys()
return components
def check_valid_url(pretrained_model_link_or_path):
# check if url prefix is valid
# remove huggingface url prefix from model path
has_valid_url_prefix = False
for prefix in VALID_URL_PREFIXES:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
has_valid_url_prefix = True
return has_valid_url_prefix, pretrained_model_link_or_path
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
@@ -174,8 +149,9 @@ def set_additional_components(
**kwargs,
):
components = {}
model_type = kwargs.get("model_type", None)
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(pipeline_class_name, original_config)
model_type = infer_model_type(original_config, model_type=model_type)
is_refiner = model_type == "SDXL-Refiner"
components.update(
{
@@ -206,8 +182,7 @@ class FromSingleFileMixin:
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -291,6 +266,7 @@ class FromSingleFileMixin:
)
checkpoint = load_state_dict(checkpoint_path)
# some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
@@ -340,11 +316,9 @@ class FromSingleFileMixin:
continue
init_kwargs.update(components)
additional_components = set(optional_kwargs - init_kwargs.keys())
additional_components = set_additional_components(class_name, original_config, **kwargs)
if additional_components:
components = set_additional_components(class_name, original_config, **kwargs)
if components:
init_kwargs.update(components)
init_kwargs.update(additional_components)
init_kwargs.update(passed_pipe_kwargs)
pipe = pipeline_class(**init_kwargs)

View File

@@ -792,7 +792,6 @@ class StableDiffusionXLImg2ImgPipeline(
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
if (
expected_add_embed_dim > passed_add_embed_dim
and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim