1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[from_single_file] Fix circular import (#4259)

* up

* finish

* fix final
This commit is contained in:
Patrick von Platen
2023-07-25 14:30:39 +02:00
committed by GitHub
parent 5ef6b8fa53
commit ebfe343149
2 changed files with 35 additions and 16 deletions

View File

@@ -25,22 +25,6 @@ import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from torch import nn
from .models.attention_processor import (
LORA_ATTENTION_PROCESSORS,
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
AttnProcessor,
AttnProcessor2_0,
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRALinearLayer,
LoRAXFormersAttnProcessor,
SlicedAttnAddedKVProcessor,
XFormersAttnProcessor,
)
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
@@ -83,6 +67,8 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensor
class PatchedLoraProjection(nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
super().__init__()
from .models.attention_processor import LoRALinearLayer
self.regular_linear_layer = regular_linear_layer
device = self.regular_linear_layer.weight.device
@@ -231,6 +217,17 @@ class UNet2DConditionLoadersMixin:
information.
"""
from .models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
CustomDiffusionAttnProcessor,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
SlicedAttnAddedKVProcessor,
XFormersAttnProcessor,
)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
@@ -423,6 +420,11 @@ class UNet2DConditionLoadersMixin:
`DIFFUSERS_SAVE_MODE`.
"""
from .models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
)
weight_name = weight_name or deprecate(
"weights_name",
"0.20.0",
@@ -1317,6 +1319,17 @@ class LoraLoaderMixin:
>>> ...
```
"""
from .models.attention_processor import (
LORA_ATTENTION_PROCESSORS,
AttnProcessor,
AttnProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):

View File

@@ -799,6 +799,9 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
for param_name, param in text_model_dict.items():
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
else:
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
text_model_dict.pop("text_model.embeddings.position_ids", None)
text_model.load_state_dict(text_model_dict)
return text_model
@@ -960,6 +963,9 @@ def convert_open_clip_checkpoint(
for param_name, param in text_model_dict.items():
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
else:
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
text_model_dict.pop("text_model.embeddings.position_ids", None)
text_model.load_state_dict(text_model_dict)
return text_model