From 699dfb084ca12be9c1cd34ad7b6d0fffb9f0ebff Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 26 Mar 2024 09:37:33 +0530
Subject: [PATCH] feat: support DoRA LoRA from community (#7371)

* feat: support dora loras from community

* safe-guard dora operations under peft version.

* pop use_dora when False

* make dora lora from kohya work.

* fix: kohya conversion utils.

* add a fast test for DoRA compatibility.

* add a nightly test.
---
 src/diffusers/loaders/lora.py            | 30 ++++++++++++++++++++-
 .../loaders/lora_conversion_utils.py     | 30 ++++++++++++++++++++-
 src/diffusers/utils/__init__.py          |  1 +
 src/diffusers/utils/import_utils.py      | 14 ++++++++++
 src/diffusers/utils/peft_utils.py        |  2 ++
 src/diffusers/utils/state_dict_utils.py  |  8 ++++++
 tests/lora/test_lora_layers_sdxl.py      | 18 ++++++++++++
 tests/lora/utils.py                      | 40 +++++++++++++++++++++++-
 8 files changed, 138 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 973fcbcf12..aa53563fd3 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -36,6 +36,7 @@ from ..utils import (
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
@@ -113,7 +114,7 @@ class LoraLoaderMixin:
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -451,6 +452,15 @@ class LoraLoaderMixin:
                 rank[key] = val.shape[1]
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
@@ -572,6 +582,15 @@ class LoraLoaderMixin:
             }
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
@@ -654,6 +673,13 @@ class LoraLoaderMixin:
                 rank[key] = val.shape[1]
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+                else:
+                    lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
@@ -1243,7 +1269,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
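Note: with the loader changes above, a community-trained DoRA LoRA goes through the standard `load_lora_weights` entry point; the "dora_scale" keys are what mark a checkpoint as DoRA. A minimal usage sketch (the checkpoint id is the one exercised by the nightly test added below; peft>=0.9.0 is assumed):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    # DoRA checkpoints are recognized by their "dora_scale" keys; on peft < 0.9.0
    # the loader raises instead of silently ignoring the magnitude vectors.
    pipe.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
    image = pipe("photo of ohwx dog", num_inference_steps=25).images[0]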
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 6f8fb80fc2..11e3311a64 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -14,7 +14,7 @@
 
 import re
 
-from ..utils import logging
+from ..utils import is_peft_version, logging
 
 
 logger = logging.get_logger(__name__)
@@ -128,6 +128,15 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
+    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+
+    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError(
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )
 
     # every down weight has a corresponding up weight and potentially an alpha weight
     lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
@@ -198,6 +207,12 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             unet_state_dict[diffusers_name] = state_dict.pop(key)
             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
+            if is_unet_dora_lora:
+                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+                unet_state_dict[
+                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
             if lora_name.startswith(("lora_te_", "lora_te1_")):
                 key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
@@ -229,6 +244,19 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                 te2_state_dict[diffusers_name] = state_dict.pop(key)
                 te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
+            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+                dora_scale_key_to_replace_te = (
+                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+                )
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                elif lora_name.startswith("lora_te2_"):
+                    te2_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
         # Rename the alphas so that they can be mapped appropriately.
         if lora_name_alpha in state_dict:
             alpha = state_dict.pop(lora_name_alpha).item()
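Note: the conversion above pairs every Kohya "dora_scale" tensor with its "lora_down"/"lora_up" neighbors and stores it under peft's `lora_magnitude_vector` name. A standalone restatement of the key rewrite, with an illustrative key that is not taken from a real checkpoint:

    def dora_scale_target(diffusers_name: str) -> str:
        # Mirrors the substitution in _convert_kohya_lora_to_diffusers: the
        # ".lora.down." (or "_lora.down.") segment of the down-projection key
        # is swapped for ".lora_magnitude_vector.".
        marker = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
        return diffusers_name.replace(marker, ".lora_magnitude_vector.")

    print(dora_scale_target("down_blocks.0.attentions.0.proj_in.lora.down.weight"))
    # -> down_blocks.0.attentions.0.proj_in.lora_magnitude_vector.weight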
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 4e2f07f2ba..cdc9203661 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -69,6 +69,7 @@ from .import_utils import (
     is_note_seq_available,
     is_onnx_available,
     is_peft_available,
+    is_peft_version,
     is_scipy_available,
     is_tensorboard_available,
     is_torch_available,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index a3ee31c91c..f5f57d8a5c 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -628,6 +628,20 @@ def is_accelerate_version(operation: str, version: str):
     return compare_versions(parse(_accelerate_version), operation, version)
 
 
+def is_peft_version(operation: str, version: str):
+    """
+    Compares the current PEFT version to a given reference with an operation.
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _peft_version:
+        return False
+    return compare_versions(parse(_peft_version), operation, version)
+
+
 def is_k_diffusion_version(operation: str, version: str):
     """
     Args:
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index fb37545eb4..718e0b46d8 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -171,6 +171,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
+    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
 
     lora_config_kwargs = {
         "r": r,
@@ -178,6 +179,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         "rank_pattern": rank_pattern,
         "alpha_pattern": alpha_pattern,
         "target_modules": target_modules,
+        "use_dora": use_dora,
     }
     return lora_config_kwargs
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 00b96495ca..35fc4210a9 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -47,6 +47,7 @@ UNET_TO_DIFFUSERS = {
     ".to_v_lora.up": ".to_v.lora_B",
     ".lora.up": ".lora_B",
     ".lora.down": ".lora_A",
+    ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
 }
 
 
@@ -104,6 +105,10 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
     ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
     ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
     ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
+    ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
+    ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
+    ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
+    ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
 }
 
 PEFT_TO_KOHYA_SS = {
@@ -315,6 +320,9 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
             kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
         elif "unet" in kohya_key:
             kohya_key = kohya_key.replace("unet", "lora_unet")
+        elif "lora_magnitude_vector" in kohya_key:
+            kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
+            kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
         kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
         kohya_ss_state_dict[kohya_key] = weight
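Note: `is_peft_version` plus the `use_dora` flag inferred in `get_peft_kwargs` combine into the gating used in lora.py above. A condensed, standalone restatement (the helper name `gate_use_dora` is hypothetical; `is_peft_version` returns False when peft is not installed):

    from diffusers.utils import is_peft_version

    def gate_use_dora(lora_config_kwargs: dict) -> dict:
        # A DoRA checkpoint hard-fails on peft < 0.9.0, while a plain
        # use_dora=False is dropped instead, so that older peft never sees
        # the LoraConfig kwarg it does not know about.
        if "use_dora" in lora_config_kwargs:
            if lora_config_kwargs["use_dora"]:
                if is_peft_version("<", "0.9.0"):
                    raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
            elif is_peft_version("<", "0.9.0"):
                lora_config_kwargs.pop("use_dora")
        return lora_config_kwargs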
diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py
index 036bcd3a90..d603994b66 100644
--- a/tests/lora/test_lora_layers_sdxl.py
+++ b/tests/lora/test_lora_layers_sdxl.py
@@ -630,3 +630,21 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
         expected_slice_scale = np.array([0.5456, 0.5466, 0.5487, 0.5458, 0.5469, 0.5454, 0.5446, 0.5479, 0.5487])
         max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
         assert max_diff < 1e-3
+
+    @nightly
+    def test_integration_logits_for_dora_lora(self):
+        pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
+        pipeline.enable_model_cpu_offload()
+
+        images = pipeline(
+            "photo of ohwx dog",
+            num_inference_steps=10,
+            generator=torch.manual_seed(0),
+            output_type="np",
+        ).images
+
+        predicted_slice = images[0, -3:, -3:, -1].flatten()
+        expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
+        max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
+        assert max_diff < 1e-3
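Note: the integration test above is marked `@nightly`, so it is skipped in ordinary CI runs; under diffusers' usual testing conventions it should run with something like `RUN_NIGHTLY=1 python -m pytest tests/lora/test_lora_layers_sdxl.py -k dora` (the environment-variable name follows the project's testing utilities and is not spelled out in this patch).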
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 6be602a77b..e6799be5bc 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -72,7 +72,7 @@ class PeftLoraLoaderMixinTests:
     unet_kwargs = None
     vae_kwargs = None
 
-    def get_dummy_components(self, scheduler_cls=None):
+    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
         rank = 4
 
@@ -96,10 +96,15 @@ class PeftLoraLoaderMixinTests:
            lora_alpha=rank,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
            init_lora_weights=False,
+           use_dora=use_dora,
        )
 
        unet_lora_config = LoraConfig(
-           r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+           r=rank,
+           lora_alpha=rank,
+           target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+           init_lora_weights=False,
+           use_dora=use_dora,
        )
 
        if self.has_two_text_encoders:
@@ -1074,6 +1079,37 @@ class PeftLoraLoaderMixinTests:
                 "Fused lora should not change the output",
             )
 
+    @require_peft_version_greater(peft_version="0.9.0")
+    def test_simple_inference_with_dora(self):
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(output_no_dora_lora.shape == (1, 64, 64, 3))
+
+            pipe.text_encoder.add_adapter(text_lora_config)
+            pipe.unet.add_adapter(unet_lora_config)
+
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertFalse(
+                np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
+                "DoRA lora should change the output",
+            )
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
         """
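Note: the fast test above drives peft's DoRA path end to end. Attaching a DoRA adapter by hand follows the same shape as `get_dummy_components` (a minimal sketch, assuming peft>=0.9.0 and a diffusers UNet whose attention projections carry the names below):

    from peft import LoraConfig

    unet_lora_config = LoraConfig(
        r=4,
        lora_alpha=4,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        init_lora_weights=False,
        use_dora=True,  # the peft>=0.9.0 feature this patch plumbs through
    )
    # unet.add_adapter(unet_lora_config), as in PeftLoraLoaderMixinTests above.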