mirror of https://github.com/huggingface/diffusers.git
feat: support DoRA LoRA from community (#7371)
* feat: support DoRA LoRAs from the community
* safe-guard DoRA operations behind the required `peft` version
* pop `use_dora` when it is False
* make DoRA LoRAs trained with Kohya work
* fix: Kohya conversion utils
* add a fast test for DoRA compatibility
* add a nightly test
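
In practice the change means a DoRA checkpoint loads through the same entry point as a plain LoRA: DoRA keeps the usual low-rank A/B update and adds a learned magnitude vector, stored as `dora_scale` in Kohya files and as `lora_magnitude_vector` in `peft`. A minimal usage sketch, mirroring the nightly test added below (requires `peft` >= 0.9.0):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")  # DoRA checkpoint used by the nightly test
pipe.enable_model_cpu_offload()

image = pipe("photo of ohwx dog", num_inference_steps=25, generator=torch.manual_seed(0)).images[0]
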
@@ -36,6 +36,7 @@ from ..utils import (
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
@@ -113,7 +114,7 @@ class LoraLoaderMixin:
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

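
The check is relaxed because the magnitude tensors a DoRA checkpoint adds can live under keys that end in `dora_scale` and carry no `lora` substring. An illustrative sketch of what the relaxed predicate accepts (key names are made up):

# Illustrative key names only; real checkpoints have many more entries.
keys = [
    "unet.down_blocks.0.attentions.0.proj_in.lora_A.weight",
    "unet.down_blocks.0.attentions.0.proj_in.lora_B.weight",
    "unet.down_blocks.0.attentions.0.proj_in.dora_scale",  # DoRA magnitude tensor
]
assert all("lora" in k or "dora_scale" in k for k in keys)  # passes with the relaxed check
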
@@ -451,6 +452,15 @@ class LoraLoaderMixin:
                     rank[key] = val.shape[1]

             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name
@@ -572,6 +582,15 @@ class LoraLoaderMixin:
                }

            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+           if "use_dora" in lora_config_kwargs:
+               if lora_config_kwargs["use_dora"]:
+                   if is_peft_version("<", "0.9.0"):
+                       raise ValueError(
+                           "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                       )
+               else:
+                   if is_peft_version("<", "0.9.0"):
+                       lora_config_kwargs.pop("use_dora")
            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
@@ -654,6 +673,13 @@ class LoraLoaderMixin:
                    rank[key] = val.shape[1]

            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+           if "use_dora" in lora_config_kwargs:
+               if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                   raise ValueError(
+                       "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                   )
+               else:
+                   lora_config_kwargs.pop("use_dora")
            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
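
The peft-version guard above appears three times, once per loading path; note that the third occurrence pops `use_dora` whenever it does not raise, while the first two pop it only on older `peft` versions. Purely for illustration, here is the first pattern factored into a hypothetical helper (no such helper exists in the PR):

from diffusers.utils import is_peft_version  # exported after this PR

def _maybe_handle_use_dora(lora_config_kwargs):
    # Hypothetical helper mirroring the guard above, not part of the PR.
    if "use_dora" in lora_config_kwargs:
        if lora_config_kwargs["use_dora"]:
            if is_peft_version("<", "0.9.0"):
                raise ValueError(
                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. "
                    "Please upgrade your installation of `peft`."
                )
        elif is_peft_version("<", "0.9.0"):
            # Older peft releases do not accept a `use_dora` kwarg, so drop it
            # when it is False instead of passing it to LoraConfig.
            lora_config_kwargs.pop("use_dora")
    return lora_config_kwargs
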
@@ -1243,7 +1269,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
            unet_config=self.unet.config,
            **kwargs,
        )
-       is_correct_format = all("lora" in key for key in state_dict.keys())
+       is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

@@ -14,7 +14,7 @@

 import re

-from ..utils import logging
+from ..utils import is_peft_version, logging


 logger = logging.get_logger(__name__)
@@ -128,6 +128,15 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
    te_state_dict = {}
    te2_state_dict = {}
    network_alphas = {}
+   is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+   is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+   is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+
+   if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+       if is_peft_version("<", "0.9.0"):
+           raise ValueError(
+               "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+           )

    # every down weight has a corresponding up weight and potentially an alpha weight
    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
@@ -198,6 +207,12 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
            unet_state_dict[diffusers_name] = state_dict.pop(key)
            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

+           if is_unet_dora_lora:
+               dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+               unet_state_dict[
+                   diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+               ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
        elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
            if lora_name.startswith(("lora_te_", "lora_te1_")):
                key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
@@ -229,6 +244,19 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                te2_state_dict[diffusers_name] = state_dict.pop(key)
                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

+           if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+               dora_scale_key_to_replace_te = (
+                   "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+               )
+               if lora_name.startswith(("lora_te_", "lora_te1_")):
+                   te_state_dict[
+                       diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                   ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+               elif lora_name.startswith("lora_te2_"):
+                   te2_state_dict[
+                       diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                   ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
        # Rename the alphas so that they can be mapped appropriately.
        if lora_name_alpha in state_dict:
            alpha = state_dict.pop(lora_name_alpha).item()
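
The conversion maps each Kohya `dora_scale` tensor onto the `lora_magnitude_vector` parameter name that `peft` expects for DoRA. A hedged sketch of just that rename step, with made-up key names (the real converter also rewrites block and attention names):

def rename_dora_scale(diffusers_name, key):
    # Minimal sketch of the rename only; names below are hypothetical.
    magnitude_name = diffusers_name.replace(".lora.down.", ".lora_magnitude_vector.")
    source_key = key.replace("lora_down.weight", "dora_scale")
    return magnitude_name, source_key

print(rename_dora_scale(
    "unet.down_blocks.0.attentions.0.proj_in.lora.down.weight",        # hypothetical diffusers-style name
    "lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight",   # hypothetical Kohya key
))
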
@@ -69,6 +69,7 @@ from .import_utils import (
    is_note_seq_available,
    is_onnx_available,
    is_peft_available,
+   is_peft_version,
    is_scipy_available,
    is_tensorboard_available,
    is_torch_available,
@@ -628,6 +628,20 @@ def is_accelerate_version(operation: str, version: str):
    return compare_versions(parse(_accelerate_version), operation, version)


+def is_peft_version(operation: str, version: str):
+   """
+   Args:
+   Compares the current PEFT version to a given reference with an operation.
+       operation (`str`):
+           A string representation of an operator, such as `">"` or `"<="`
+       version (`str`):
+           A version string
+   """
+   if not _peft_version:
+       return False
+   return compare_versions(parse(_peft_version), operation, version)
+
+
 def is_k_diffusion_version(operation: str, version: str):
    """
    Args:
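
A quick usage sketch of the new helper, assuming it is imported through `diffusers.utils` as the import change above makes possible:

from diffusers.utils import is_peft_version

# Returns False when peft is not installed at all; otherwise compares the
# installed version against the reference using the given operator.
if is_peft_version("<", "0.9.0"):
    raise ValueError("DoRA-enabled LoRAs need `peft` >= 0.9.0.")
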
@@ -171,6 +171,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True

    # layer names without the Diffusers specific
    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
+   use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)

    lora_config_kwargs = {
        "r": r,
@@ -178,6 +179,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
+       "use_dora": use_dora,
    }
    return lora_config_kwargs

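
`get_peft_kwargs` now reports whether a checkpoint is DoRA by scanning for `lora_magnitude_vector` entries, and the flag is forwarded straight into `LoraConfig(use_dora=...)`. An illustrative sketch of that detection on a made-up peft-format key set:

from peft import LoraConfig

# Hypothetical peft-format keys; a real state dict maps these to tensors.
peft_state_dict = {
    "down_blocks.0.attentions.0.proj_in.lora_A.weight": None,
    "down_blocks.0.attentions.0.proj_in.lora_B.weight": None,
    "down_blocks.0.attentions.0.proj_in.lora_magnitude_vector.weight": None,
}
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
config = LoraConfig(r=4, lora_alpha=4, target_modules=["proj_in"], use_dora=use_dora)
print(config.use_dora)  # True; constructing this requires peft >= 0.9.0
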
@@ -47,6 +47,7 @@ UNET_TO_DIFFUSERS = {
    ".to_v_lora.up": ".to_v.lora_B",
    ".lora.up": ".lora_B",
    ".lora.down": ".lora_A",
+   ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
 }

@@ -104,6 +105,10 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
    ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
    ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
    ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
+   ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
+   ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
+   ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
+   ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
 }

 PEFT_TO_KOHYA_SS = {
@@ -315,6 +320,9 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")
+       elif "lora_magnitude_vector" in kohya_key:
+           kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
+
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
        kohya_ss_state_dict[kohya_key] = weight
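
The reverse direction matters when re-exporting peft-trained weights for Kohya tooling: the `lora_magnitude_vector` parameter has to come back out under the `dora_scale` name. A minimal sketch of just that rename (the key below is hypothetical; the real helper also rewrites prefixes and separators as the hunk above shows):

# Hypothetical key; convert_state_dict_to_kohya additionally rewrites the
# "unet"/"text_encoder" prefixes and converts dots to underscores.
peft_key = "base_model.model.proj_in.lora_magnitude_vector"
kohya_key = peft_key.replace("lora_magnitude_vector", "dora_scale")
assert kohya_key.endswith("dora_scale")
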
@@ -630,3 +630,21 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
        expected_slice_scale = np.array([0.5456, 0.5466, 0.5487, 0.5458, 0.5469, 0.5454, 0.5446, 0.5479, 0.5487])
        max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
        assert max_diff < 1e-3
+
+   @nightly
+   def test_integration_logits_for_dora_lora(self):
+       pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+       pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
+       pipeline.enable_model_cpu_offload()
+
+       images = pipeline(
+           "photo of ohwx dog",
+           num_inference_steps=10,
+           generator=torch.manual_seed(0),
+           output_type="np",
+       ).images
+
+       predicted_slice = images[0, -3:, -3:, -1].flatten()
+       expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
+       max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
+       assert max_diff < 1e-3
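
The assertion compares the 3x3 image slice to the reference through a cosine-distance metric rather than element-wise closeness. For readers reproducing the check locally, a hedged reimplementation of what `numpy_cosine_similarity_distance` is assumed to compute:

import numpy as np

def cosine_similarity_distance(a, b):
    # Assumed to mirror diffusers' test helper: 1 - cosine similarity of the
    # two flattened slices; smaller means the images agree more closely.
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)
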
@@ -72,7 +72,7 @@ class PeftLoraLoaderMixinTests:
    unet_kwargs = None
    vae_kwargs = None

-   def get_dummy_components(self, scheduler_cls=None):
+   def get_dummy_components(self, scheduler_cls=None, use_dora=False):
        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
        rank = 4

@@ -96,10 +96,15 @@ class PeftLoraLoaderMixinTests:
            lora_alpha=rank,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
            init_lora_weights=False,
+           use_dora=use_dora,
        )

        unet_lora_config = LoraConfig(
-           r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+           r=rank,
+           lora_alpha=rank,
+           target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+           init_lora_weights=False,
+           use_dora=use_dora,
        )

        if self.has_two_text_encoders:
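
With `get_dummy_components` now threading `use_dora` through both `LoraConfig` objects, the same setup can be reproduced outside the test harness. A minimal sketch that attaches a DoRA adapter to a small UNet (the UNet config below is made up to keep it fast; the per-pipeline dummy components in the tests differ):

from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Tiny, made-up UNet configuration purely for illustration.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    init_lora_weights=False,
    use_dora=True,  # needs peft >= 0.9.0
)
unet.add_adapter(unet_lora_config)
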
@@ -1074,6 +1079,37 @@ class PeftLoraLoaderMixinTests:
                "Fused lora should not change the output",
            )

+   @require_peft_version_greater(peft_version="0.9.0")
+   def test_simple_inference_with_dora(self):
+       for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+           components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
+           pipe = self.pipeline_class(**components)
+           pipe = pipe.to(torch_device)
+           pipe.set_progress_bar_config(disable=None)
+           _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+           output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+           self.assertTrue(output_no_dora_lora.shape == (1, 64, 64, 3))
+
+           pipe.text_encoder.add_adapter(text_lora_config)
+           pipe.unet.add_adapter(unet_lora_config)
+
+           self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+           self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+           if self.has_two_text_encoders:
+               pipe.text_encoder_2.add_adapter(text_lora_config)
+               self.assertTrue(
+                   check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+               )
+
+           output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+           self.assertFalse(
+               np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
+               "DoRA lora should change the output",
+           )
+
    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
        """