diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index e686c2de16..246765f760 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -25,22 +25,6 @@ import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from torch import nn
 
-from .models.attention_processor import (
-    LORA_ATTENTION_PROCESSORS,
-    AttnAddedKVProcessor,
-    AttnAddedKVProcessor2_0,
-    AttnProcessor,
-    AttnProcessor2_0,
-    CustomDiffusionAttnProcessor,
-    CustomDiffusionXFormersAttnProcessor,
-    LoRAAttnAddedKVProcessor,
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRALinearLayer,
-    LoRAXFormersAttnProcessor,
-    SlicedAttnAddedKVProcessor,
-    XFormersAttnProcessor,
-)
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -83,6 +67,8 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensor
 class PatchedLoraProjection(nn.Module):
     def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
         super().__init__()
+        from .models.attention_processor import LoRALinearLayer
+
         self.regular_linear_layer = regular_linear_layer
 
         device = self.regular_linear_layer.weight.device
@@ -231,6 +217,17 @@ class UNet2DConditionLoadersMixin:
                 information.
 
         """
+        from .models.attention_processor import (
+            AttnAddedKVProcessor,
+            AttnAddedKVProcessor2_0,
+            CustomDiffusionAttnProcessor,
+            LoRAAttnAddedKVProcessor,
+            LoRAAttnProcessor,
+            LoRAAttnProcessor2_0,
+            LoRAXFormersAttnProcessor,
+            SlicedAttnAddedKVProcessor,
+            XFormersAttnProcessor,
+        )
 
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
@@ -423,6 +420,11 @@ class UNet2DConditionLoadersMixin:
                 `DIFFUSERS_SAVE_MODE`.
 
         """
+        from .models.attention_processor import (
+            CustomDiffusionAttnProcessor,
+            CustomDiffusionXFormersAttnProcessor,
+        )
+
         weight_name = weight_name or deprecate(
             "weights_name",
             "0.20.0",
@@ -1317,6 +1319,17 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
+        from .models.attention_processor import (
+            LORA_ATTENTION_PROCESSORS,
+            AttnProcessor,
+            AttnProcessor2_0,
+            LoRAAttnAddedKVProcessor,
+            LoRAAttnProcessor,
+            LoRAAttnProcessor2_0,
+            LoRAXFormersAttnProcessor,
+            XFormersAttnProcessor,
+        )
+
         unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
 
         if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index fd44fdb4eb..fdbe1dfaef 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -799,6 +799,9 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
         for param_name, param in text_model_dict.items():
             set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
     else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
         text_model.load_state_dict(text_model_dict)
 
     return text_model
@@ -960,6 +963,9 @@ def convert_open_clip_checkpoint(
         for param_name, param in text_model_dict.items():
             set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
     else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
         text_model.load_state_dict(text_model_dict)
 
     return text_model
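Illustrative note (not part of the patch): the guarded pop added to convert_from_ckpt.py works around checkpoints that still serialize an "embeddings.position_ids" entry while newer text-encoder classes may no longer expose that buffer, which makes a strict load_state_dict fail on the unexpected key. The sketch below reproduces the pattern with a toy torch module; TextEmbeddings, load_checkpoint, and the state-dict names are hypothetical stand-ins, not diffusers or transformers APIs.

# Minimal, self-contained sketch of the stale-key handling shown in the diff.
import torch
from torch import nn


class TextEmbeddings(nn.Module):
    def __init__(self, vocab_size=49408, hidden_size=8, max_positions=77):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(max_positions, hidden_size)
        # Deliberately no "position_ids" buffer registered, mimicking a model
        # class that dropped (or made non-persistent) that buffer.


def load_checkpoint(model: nn.Module, state_dict: dict) -> None:
    # Mirror of the check in the diff: remove the stale key only when the
    # target model does not expose a matching attribute/buffer.
    if not hasattr(model, "position_ids"):
        state_dict.pop("position_ids", None)
    model.load_state_dict(state_dict)


embeddings = TextEmbeddings()
old_style_state_dict = dict(embeddings.state_dict())
old_style_state_dict["position_ids"] = torch.arange(77).unsqueeze(0)  # stale key
load_checkpoint(embeddings, old_style_state_dict)  # succeeds once the key is dropped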