Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Lora] fix lora fuse unfuse (#5003)
* fix lora fuse unfuse
* add same changes to loaders.py
* add test

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
Committed by GitHub · parent 324aef6d14 · commit b47f5115da
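The one-line guard change below is the entire fix, applied identically to PatchedLoraProjection, LoRACompatibleConv, and LoRACompatibleLinear. _unfuse_lora clears the cached LoRA factors by assigning None rather than deleting the attributes, so a plain hasattr check keeps passing after the first unfuse and the method re-runs on None values. A standalone sketch of the failure mode (the Layer class here is illustrative, not the diffusers source):

class Layer:
    def _fuse_lora(self):
        # diffusers caches the LoRA factors so they can be subtracted later
        self.w_up, self.w_down = "up", "down"

    def _unfuse_lora(self):
        # old guard: stays True after unfusing, because the attributes still exist
        print("hasattr guard:", hasattr(self, "w_up") and hasattr(self, "w_down"))
        # new guard: becomes False once the factors have been reset to None
        print("getattr guard:", getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None)
        self.w_up = self.w_down = None  # cleared, but the attributes remain

layer = Layer()
layer._fuse_lora()
layer._unfuse_lora()  # both guards True: unfusing is valid here
layer._unfuse_lora()  # hasattr guard is still True; getattr guard is correctly False

Checking getattr(..., None) is not None treats "attribute missing" and "attribute cleared" the same way, which is exactly the semantics the unfuse guard needs.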
@@ -121,7 +121,7 @@ class PatchedLoraProjection(nn.Module):
         self.lora_scale = lora_scale

     def _unfuse_lora(self):
-        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
             return

         fused_weight = self.regular_linear_layer.weight.data
@@ -139,7 +139,7 @@ class LoRACompatibleConv(nn.Conv2d):
         self._lora_scale = lora_scale

     def _unfuse_lora(self):
-        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
             return

         fused_weight = self.weight.data
@@ -204,7 +204,7 @@ class LoRACompatibleLinear(nn.Linear):
         self._lora_scale = lora_scale

     def _unfuse_lora(self):
-        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
             return

         fused_weight = self.weight.data
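For context on what the guard protects: _fuse_lora adds the scaled LoRA product into the layer weight in place, and _unfuse_lora subtracts it again, so the subtraction must run exactly once per fuse. A minimal sketch of that round trip on a bare matrix (illustrative names, not the diffusers source):

import torch

torch.manual_seed(0)
weight = torch.randn(8, 8)                           # base layer weight
w_up, w_down = torch.randn(8, 4), torch.randn(4, 8)  # cached LoRA factors
lora_scale = 0.7

fused = weight + lora_scale * (w_up @ w_down)        # roughly what _fuse_lora stores
restored = fused - lora_scale * (w_up @ w_down)      # roughly what _unfuse_lora computes
assert torch.allclose(restored, weight, atol=1e-5)

# Subtracting a second time is not a no-op: it drifts away from the base weights.
double_unfused = restored - lora_scale * (w_up @ w_down)
assert not torch.allclose(double_unfused, weight, atol=1e-5)

The new guard makes a second _unfuse_lora call a no-op instead of drifting the weights (or, in diffusers, operating on the cleared None factors).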
@@ -43,7 +43,7 @@ from diffusers.models.attention_processor import (
     LoRAAttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, slow, torch_device


 def create_unet_lora_layers(unet: nn.Module):
@@ -1497,3 +1497,41 @@ class LoraIntegrationTests(unittest.TestCase):
         expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])

         self.assertTrue(np.allclose(images, expected, atol=1e-3))
+
+    @nightly
+    def test_sequential_fuse_unfuse(self):
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+
+        # 1. round
+        pipe.load_lora_weights("Pclanglais/TintinIA")
+        pipe.fuse_lora()
+
+        generator = torch.Generator().manual_seed(0)
+        images = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+        image_slice = images[0, -3:, -3:, -1].flatten()
+
+        pipe.unfuse_lora()
+
+        # 2. round
+        pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style")
+        pipe.fuse_lora()
+        pipe.unfuse_lora()
+
+        # 3. round
+        pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")
+        pipe.fuse_lora()
+        pipe.unfuse_lora()
+
+        # 4. back to 1st round
+        pipe.load_lora_weights("Pclanglais/TintinIA")
+        pipe.fuse_lora()
+
+        generator = torch.Generator().manual_seed(0)
+        images_2 = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+        image_slice_2 = images_2[0, -3:, -3:, -1].flatten()
+
+        self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3))
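The new nightly test above exercises the user-facing pattern this commit hardens: fusing and unfusing different LoRAs sequentially on one pipeline and expecting reproducible outputs at the end. A hedged sketch of that flow, with the model and LoRA repo ids copied from the test:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

pipe.load_lora_weights("Pclanglais/TintinIA")
pipe.fuse_lora()    # bake the LoRA into the base weights
# ... run inference ...
pipe.unfuse_lora()  # must restore the exact base weights, which this fix guarantees

pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")
pipe.fuse_lora()    # repeated fuse/unfuse cycles no longer touch stale factors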