Mirror of https://github.com/huggingface/diffusers.git
[from_single_file] Fix circular import (#4259)
* up
* finish
* fix final
committed by GitHub
parent 5ef6b8fa53
commit ebfe343149
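
The change applies the standard remedy for an import cycle: the module-level import of the attention-processor names is removed, and each function or method that needs them re-imports them locally, so neither module has to be fully initialized before the other can load. Below is a minimal sketch of that pattern with hypothetical modules pkg/loaders.py and pkg/processors.py (stand-ins, not the actual diffusers files):

    # pkg/processors.py (hypothetical)
    # from .loaders import SomeMixin       # processors needs loaders at import time ...

    # pkg/loaders.py (hypothetical)
    # from .processors import AttnProc     # ... and loaders needed processors at import
    #                                      # time, so one side always saw a partially
    #                                      # initialized module and raised ImportError.

    def load_attn_procs(path):
        # Fix: defer the import to call time. Both modules have finished executing
        # by the time this runs, so the circular dependency is never hit.
        from .processors import AttnProc

        return AttnProc(path)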
@@ -25,22 +25,6 @@ import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from torch import nn
 
-from .models.attention_processor import (
-    LORA_ATTENTION_PROCESSORS,
-    AttnAddedKVProcessor,
-    AttnAddedKVProcessor2_0,
-    AttnProcessor,
-    AttnProcessor2_0,
-    CustomDiffusionAttnProcessor,
-    CustomDiffusionXFormersAttnProcessor,
-    LoRAAttnAddedKVProcessor,
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRALinearLayer,
-    LoRAXFormersAttnProcessor,
-    SlicedAttnAddedKVProcessor,
-    XFormersAttnProcessor,
-)
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -83,6 +67,8 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensor
 class PatchedLoraProjection(nn.Module):
     def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
         super().__init__()
+        from .models.attention_processor import LoRALinearLayer
+
         self.regular_linear_layer = regular_linear_layer
 
         device = self.regular_linear_layer.weight.device
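
The same deferral shows up inside PatchedLoraProjection.__init__. An import statement in a function body re-executes on every call, but after the first execution Python only fetches the already-initialized module from sys.modules, so later instantiations pay a dictionary lookup rather than a full import. A small runnable check of that behaviour (the standard-library json module stands in for the deferred diffusers import):

    import sys

    def make_layer():
        # Re-runs on every call, but only the first call actually executes the module;
        # afterwards the name is resolved from the sys.modules cache.
        import json
        return json

    first = make_layer()
    assert "json" in sys.modules      # cached after the first call
    assert make_layer() is first      # subsequent calls get the same module object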
@@ -231,6 +217,17 @@ class UNet2DConditionLoadersMixin:
                 information.
 
         """
+        from .models.attention_processor import (
+            AttnAddedKVProcessor,
+            AttnAddedKVProcessor2_0,
+            CustomDiffusionAttnProcessor,
+            LoRAAttnAddedKVProcessor,
+            LoRAAttnProcessor,
+            LoRAAttnProcessor2_0,
+            LoRAXFormersAttnProcessor,
+            SlicedAttnAddedKVProcessor,
+            XFormersAttnProcessor,
+        )
 
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
@@ -423,6 +420,11 @@ class UNet2DConditionLoadersMixin:
                 `DIFFUSERS_SAVE_MODE`.
 
         """
+        from .models.attention_processor import (
+            CustomDiffusionAttnProcessor,
+            CustomDiffusionXFormersAttnProcessor,
+        )
+
         weight_name = weight_name or deprecate(
             "weights_name",
             "0.20.0",
@@ -1317,6 +1319,17 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
+        from .models.attention_processor import (
+            LORA_ATTENTION_PROCESSORS,
+            AttnProcessor,
+            AttnProcessor2_0,
+            LoRAAttnAddedKVProcessor,
+            LoRAAttnProcessor,
+            LoRAAttnProcessor2_0,
+            LoRAXFormersAttnProcessor,
+            XFormersAttnProcessor,
+        )
+
         unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
 
         if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
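
The LoraLoaderMixin hunk pulls the same names in locally and then gates a LoRA-specific path on a set-subset test: collect the class of every attention processor attached to the UNet and check that all of them belong to the known LoRA processor classes. A minimal sketch of that check with stand-in classes (hypothetical names, not the real diffusers processors):

    class LoRAAttnProc: ...
    class LoRAXFormersAttnProc: ...
    class PlainAttnProc: ...

    LORA_PROCS = {LoRAAttnProc, LoRAXFormersAttnProc}   # stand-in for LORA_ATTENTION_PROCESSORS

    attn_processors = {"down_blocks.0.attn1": LoRAAttnProc(), "up_blocks.1.attn2": LoRAAttnProc()}

    # Mirrors: {type(processor) for _, processor in self.unet.attn_processors.items()}
    classes = {type(p) for p in attn_processors.values()}

    if classes.issubset(LORA_PROCS):
        print("every processor is a LoRA variant -> take the LoRA-only branch")
    else:
        print("mixed processor types -> take the general branch")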
@@ -799,6 +799,9 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
         for param_name, param in text_model_dict.items():
             set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
     else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
         text_model.load_state_dict(text_model_dict)
 
     return text_model
@@ -960,6 +963,9 @@ def convert_open_clip_checkpoint(
         for param_name, param in text_model_dict.items():
             set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
     else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
         text_model.load_state_dict(text_model_dict)
 
     return text_model
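
The last two hunks (convert_ldm_clip_checkpoint and convert_open_clip_checkpoint) handle checkpoints that still carry a text_model.embeddings.position_ids entry while the freshly instantiated text encoder no longer registers that buffer (newer transformers releases stopped registering it); load_state_dict is strict by default, so the stale key has to be popped first. A small reproduction of the failure and the pop-based remedy, with a plain nn.Linear standing in for the text encoder:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)

    state_dict = model.state_dict()
    state_dict["position_ids"] = torch.arange(4)   # stale key the module no longer owns

    try:
        model.load_state_dict(state_dict)          # strict=True by default
    except RuntimeError as err:
        print("rejected:", err)                    # "Unexpected key(s) in state_dict: ..."

    # The remedy from the diff: drop the key when the model has no matching attribute.
    if not hasattr(model, "position_ids"):
        state_dict.pop("position_ids", None)

    model.load_state_dict(state_dict)              # loads cleanly now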