Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[LoRA] support LyCORIS (#5102)
* better condition.
* debugging (a long run of repeated "debugging" / "how about now?" work-in-progress commits, squashed)
* support for lycoris.
* style
* add: lycoris test
* fix from_pretrained call.
* fix assertion values.
src/diffusers/loaders.py

@@ -1878,7 +1878,7 @@ class LoraLoaderMixin:
                 diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")

                 # SDXL specificity.
-                if "emb" in diffusers_name:
+                if "emb" in diffusers_name and "time" not in diffusers_name:
                     pattern = r"\.\d+(?=\D*$)"
                     diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
                 if ".in." in diffusers_name:
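A plausible reading of the tightened condition: the line just above it already rewrites "emb.layers" to "time_emb_proj", and without the "time" guard the SDXL-specific re.sub would then strip the sub-block index out of that already-correct key. A minimal standalone sketch (the key string below is illustrative, not taken from a real checkpoint):

    import re

    # The SDXL branch removes the last numeric path segment from "emb" keys.
    pattern = r"\.\d+(?=\D*$)"

    name = "down_blocks.1.resnets.0.time_emb_proj"
    # Without the new guard, the time-embedding key loses its sub-block index
    # and no longer matches any UNet module name:
    print(re.sub(pattern, "", name, count=1))  # -> down_blocks.1.resnets.time_emb_proj

    # With `and "time" not in diffusers_name`, such keys are left intact.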
@@ -1890,6 +1890,13 @@ class LoraLoaderMixin:
                 if "skip" in diffusers_name:
                     diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")

+                # LyCORIS specificity.
+                if "time" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+                if "conv.shortcut" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
                 # General coverage.
                 if "transformer_blocks" in diffusers_name:
                     if "attn1" in diffusers_name or "attn2" in diffusers_name:
                         diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
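For context, these renames translate LyCORIS-style dotted fragments (left over from the underscore-to-dot normalization earlier in the conversion method) into the attribute names diffusers actually uses. A self-contained sketch of just this step, with illustrative key strings:

    def normalize_lycoris_fragments(diffusers_name: str) -> str:
        # Mirrors the two new replacements in the hunk above.
        if "time" in diffusers_name:
            diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
        if "conv.shortcut" in diffusers_name:
            diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
        return diffusers_name

    print(normalize_lycoris_fragments("down.blocks.0.resnets.0.time.emb.proj"))
    # -> down.blocks.0.resnets.0.time_emb_proj
    print(normalize_lycoris_fragments("down.blocks.0.resnets.0.conv.shortcut"))
    # -> down.blocks.0.resnets.0.conv_shortcut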
src/diffusers/models/embeddings.py

@@ -19,6 +19,7 @@ import torch
 from torch import nn

 from .activations import get_activation
+from .lora import LoRACompatibleLinear


 def get_timestep_embedding(
@@ -166,7 +167,7 @@ class TimestepEmbedding(nn.Module):
     ):
         super().__init__()

-        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+        self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)

         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -179,7 +180,7 @@ class TimestepEmbedding(nn.Module):
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+        self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)

         if post_act_fn is None:
             self.post_act = None
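Swapping nn.Linear for LoRACompatibleLinear is what lets the loader attach LyCORIS weights to the time embedding at all. A minimal sketch of the idea, assuming the simplified interface below (the real class lives in diffusers.models.lora and carries extra scaling logic):

    import torch
    from torch import nn

    class LoRACompatibleLinearSketch(nn.Linear):
        # Behaves exactly like nn.Linear until a low-rank adapter is attached.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.lora_layer = None

        def set_lora_layer(self, lora_layer: nn.Module) -> None:
            # Called by the LoRA loader once the checkpoint has been parsed.
            self.lora_layer = lora_layer

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            out = super().forward(hidden_states)
            if self.lora_layer is not None:
                # Low-rank residual: up_proj(down_proj(x)), scaled inside the adapter.
                out = out + self.lora_layer(hidden_states)
            return out

Because the subclass is a drop-in replacement for nn.Linear, checkpoints saved before this change keep loading unchanged.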
@@ -1876,6 +1876,25 @@ class LoraIntegrationTests(unittest.TestCase):

         self.assertTrue(np.allclose(images, expected, atol=1e-3))

+    def test_lycoris(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
+        ).to(torch_device)
+        lora_model_id = "hf-internal-testing/edgLycorisMugler-light"
+        lora_filename = "edgLycorisMugler-light.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))
+
     def test_a1111_with_model_cpu_offload(self):
         generator = torch.Generator().manual_seed(0)