1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[SD3 LoRA] Fix list index out of range (#8584)

* fix

* add check

* key present is checked before

* test case draft

* aply suggestions

* changed testing repo, back to old class

* forgot docstring

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Álvaro Somoza
2024-06-21 11:47:34 -04:00
committed by GitHub
parent 8eb17315c8
commit e7b9a0762b
2 changed files with 28 additions and 1 deletions

View File

@@ -30,6 +30,7 @@ from ..utils import (
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
@@ -1543,6 +1544,11 @@ class SD3LoraLoaderMixin:
}
if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."

View File

@@ -27,7 +27,7 @@ from diffusers import (
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, torch_device
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
if is_peft_available():
@@ -287,3 +287,24 @@ class SD3LoRATests(unittest.TestCase):
self.assertTrue(
np.allclose(ouput_fused, output_unfused_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
@require_torch_gpu
def test_sd3_lora(self):
"""
Test loading the loras that are saved with the diffusers and peft formats.
Related PR: https://github.com/huggingface/diffusers/pull/8584
"""
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
lora_model_id = "hf-internal-testing/tiny-sd3-loras"
lora_filename = "lora_diffusers_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.unload_lora_weights()
lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)