mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] fix cross_attention_kwargs problems and tighten tests (#7388)
* debugging * let's see the numbers * let's see the numbers * let's see the numbers * restrict tolerance. * increase inference steps. * shallow copy of cross_attentionkwargs * remove print
This commit is contained in:
@@ -1178,6 +1178,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
|
||||
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
|
||||
if cross_attention_kwargs is not None:
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
@@ -158,7 +158,7 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 2,
|
||||
"num_inference_steps": 5,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
}
|
||||
@@ -589,7 +589,7 @@ class PeftLoraLoaderMixinTests:
|
||||
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
|
||||
).images
|
||||
self.assertTrue(
|
||||
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
|
||||
not np.allclose(output_lora, output_lora_scale, atol=1e-4, rtol=1e-4),
|
||||
"Lora + scale should change the output",
|
||||
)
|
||||
|
||||
@@ -1300,6 +1300,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
pipe.load_lora_weights(lora_id)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
self.assertTrue(
|
||||
self.check_if_lora_correctly_set(pipe.unet),
|
||||
"Lora not correctly set in UNet",
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
self.check_if_lora_correctly_set(pipe.text_encoder),
|
||||
"Lora not correctly set in text encoder 2",
|
||||
|
||||
Reference in New Issue
Block a user