diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 0280fc23f7..7a970c5c51 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -25,7 +25,6 @@ from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, - convert_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, @@ -53,9 +52,12 @@ _SET_ADAPTER_SCALE_FN_MAPPING = { "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, "SanaTransformer2DModel": lambda model_cls, weights: weights, + "AuraFlowTransformer2DModel": lambda model_cls, weights: weights, "Lumina2Transformer2DModel": lambda model_cls, weights: weights, "WanTransformer3DModel": lambda model_cls, weights: weights, "CogView4Transformer2DModel": lambda model_cls, weights: weights, + "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, + "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, } @@ -79,33 +81,6 @@ def _maybe_raise_error_for_ambiguity(config): ) -def _maybe_adjust_config_for_control_lora(config): - """ - """ - - target_modules_before = config["target_modules"] - target_modules = [] - modules_to_save = [] - - for module in target_modules_before: - if module.endswith("weight"): - base_name = ".".join(module.split(".")[:-1]) - modules_to_save.append(base_name) - elif module.endswith("bias"): - base_name = ".".join(module.split(".")[:-1]) - if ".".join([base_name, "weight"]) in target_modules_before: - modules_to_save.append(base_name) - else: - target_modules.append(base_name) - else: - target_modules.append(module) - - config["target_modules"] = list(set(target_modules)) - config["modules_to_save"] = list(set(modules_to_save)) - - return config - - class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. For @@ -256,7 +231,7 @@ class PeftAdapterMixin: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") if prefix is not None: - state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: @@ -269,13 +244,6 @@ class PeftAdapterMixin: "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." ) - # Control LoRA from SAI is different from BFL Control LoRA - # https://huggingface.co/stabilityai/control-lora/ - is_control_lora = "lora_controlnet" in state_dict - if is_control_lora: - del state_dict["lora_controlnet"] - state_dict = convert_control_lora_state_dict_to_peft(state_dict) - # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) if "lora_A" not in first_key: @@ -294,12 +262,12 @@ class PeftAdapterMixin: if network_alphas is not None and len(network_alphas) >= 1: alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + network_alphas = { + k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys + } lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) _maybe_raise_error_for_ambiguity(lora_config_kwargs) - if is_control_lora: - lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: @@ -365,7 +333,7 @@ class PeftAdapterMixin: new_sd[k] = v return new_sd - # To handle scenarios where we cannot successfully set state dict. If it's unsucessful, + # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful, # we should also delete the `peft_config` associated to the `adapter_name`. try: if hotswap: @@ -379,7 +347,7 @@ class PeftAdapterMixin: config=lora_config, ) except Exception as e: - logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}") + logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}") raise # the hotswap function raises if there are incompatible keys, so if we reach this point we can set # it to None @@ -414,7 +382,7 @@ class PeftAdapterMixin: module.delete_adapter(adapter_name) self.peft_config.pop(adapter_name) - logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}") + logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}") raise warn_msg = "" @@ -747,7 +715,7 @@ class PeftAdapterMixin: if self.lora_scale != 1.0: module.scale_layer(self.lora_scale) - # For BC with prevous PEFT versions, we need to check the signature + # For BC with previous PEFT versions, we need to check the signature # of the `merge` method to see if it supports the `adapter_names` argument. supported_merge_kwargs = list(inspect.signature(module.merge).parameters) if "adapter_names" in supported_merge_kwargs: