mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[SD3 LoRA] Fix list index out of range (#8584)
* fix * add check * key present is checked before * test case draft * aply suggestions * changed testing repo, back to old class * forgot docstring --------- Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
@@ -1543,6 +1544,11 @@ class SD3LoraLoaderMixin:
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
# check with first key if is not in peft format
|
||||
first_key = next(iter(state_dict.keys()))
|
||||
if "lora_A" not in first_key:
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
|
||||
@@ -27,7 +27,7 @@ from diffusers import (
|
||||
SD3Transformer2DModel,
|
||||
StableDiffusion3Pipeline,
|
||||
)
|
||||
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, torch_device
|
||||
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
@@ -287,3 +287,24 @@ class SD3LoRATests(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
np.allclose(ouput_fused, output_unfused_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
|
||||
)
|
||||
|
||||
@require_torch_gpu
def test_sd3_lora(self):
    """
    Smoke-test that LoRA checkpoints serialized in either the diffusers
    layout or the PEFT layout can both be loaded into the SD3 pipeline.

    Related PR: https://github.com/huggingface/diffusers/pull/8584
    """
    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    repo_id = "hf-internal-testing/tiny-sd3-loras"

    # Diffusers-format checkpoint: load it, then unload so the second
    # load starts from a clean pipeline.
    pipe.load_lora_weights(repo_id, weight_name="lora_diffusers_format.safetensors")
    pipe.unload_lora_weights()

    # PEFT-format checkpoint (exercises the non-"lora_A" key path that
    # triggers conversion on load).
    pipe.load_lora_weights(repo_id, weight_name="lora_peft_format.safetensors")
|
||||
|
||||
Reference in New Issue
Block a user