From f69511ecc618330212e7148265e1c0323d2fa5cf Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 18 Jun 2024 21:09:30 +0530 Subject: [PATCH] [Single File Loading] Handle unexpected keys in CLIP models when `accelerate` isn't installed. (#8462) * update * update * update * update * update --------- Co-authored-by: YiYi Xu Co-authored-by: Sayak Paul --- src/diffusers/loaders/single_file_model.py | 18 +++++++++------- src/diffusers/loaders/single_file_utils.py | 25 +++++++++------------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index f576ecf262..f537a3f449 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -276,16 +276,18 @@ class FromOriginalModelMixin: if is_accelerate_available(): unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - if model._keys_to_ignore_on_load_unexpected is not None: - for pat in model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) else: - model.load_state_dict(diffusers_format_checkpoint) + _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) + + if model._keys_to_ignore_on_load_unexpected is not None: + for pat in model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) if torch_dtype is not None: model.to(torch_dtype) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 98fef894ee..e0a660020a 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -1268,8 +1268,6 @@ def convert_open_clip_checkpoint( else: text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") - keys = list(checkpoint.keys()) keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE @@ -1318,9 +1316,6 @@ def convert_open_clip_checkpoint( else: text_model_dict[diffusers_key] = checkpoint.get(key) - if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): - text_model_dict.pop("text_model.embeddings.position_ids", None) - return text_model_dict @@ -1414,17 +1409,17 @@ def create_diffusers_clip_model_from_ldm( if is_accelerate_available(): unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - if model._keys_to_ignore_on_load_unexpected is not None: - for pat in model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) - else: - model.load_state_dict(diffusers_format_checkpoint) + _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) + + if model._keys_to_ignore_on_load_unexpected is not None: + for pat in model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) if torch_dtype is not None: model.to(torch_dtype)