diff --git a/.github/workflows/pr_test_peft_backend.yml b/.github/workflows/pr_test_peft_backend.yml index b4915a3bf4..2e2f2201e7 100644 --- a/.github/workflows/pr_test_peft_backend.yml +++ b/.github/workflows/pr_test_peft_backend.yml @@ -111,3 +111,21 @@ jobs: -s -v \ --make-reports=tests_${{ matrix.config.report }} \ tests/lora/ + python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ + -s -v \ + --make-reports=tests_models_lora_${{ matrix.config.report }} \ + tests/models/ -k "lora" + + + - name: Failure short reports + if: ${{ failure() }} + run: | + cat reports/tests_${{ matrix.config.report }}_failures_short.txt + cat reports/tests_models_lora_${{ matrix.config.report }}_failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v2 + with: + name: pr_${{ matrix.config.report }}_test_reports + path: reports \ No newline at end of file diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index d6c11e0507..ec69b56a9c 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -189,12 +189,17 @@ jobs: -s -v -k "not Flax and not Onnx and not PEFTLoRALoading" \ --make-reports=tests_peft_cuda \ tests/lora/ + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ + -s -v -k "lora and not Flax and not Onnx and not PEFTLoRALoading" \ + --make-reports=tests_peft_cuda_models_lora \ + tests/models/ - name: Failure short reports if: ${{ failure() }} run: | cat reports/tests_peft_cuda_stats.txt cat reports/tests_peft_cuda_failures_short.txt + cat reports/tests_peft_cuda_models_lora_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index e089d202ee..1b313932ea 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -22,17 +22,14 @@ import torch from huggingface_hub import model_info from huggingface_hub.constants import HF_HUB_OFFLINE from huggingface_hub.utils import validate_hf_hub_args -from packaging import version from torch import nn -from .. import __version__ -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict +from ..models.modeling_utils import load_state_dict from ..utils import ( USE_PEFT_BACKEND, _get_model_file, convert_state_dict_to_diffusers, convert_state_dict_to_peft, - convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, get_peft_kwargs, @@ -119,13 +116,10 @@ class LoraLoaderMixin: if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - self.load_lora_into_unet( state_dict, network_alphas=network_alphas, unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, - low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, _pipeline=self, ) @@ -136,7 +130,6 @@ class LoraLoaderMixin: if not hasattr(self, "text_encoder") else self.text_encoder, lora_scale=self.lora_scale, - low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, _pipeline=self, ) @@ -193,16 +186,8 @@ class LoraLoaderMixin: allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. 
This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. - mirror (`str`, *optional*): - Mirror source to resolve accessibility issues if you're downloading a model in China. We do not - guarantee the timeliness or safety of the source, and you should refer to the mirror site for more - information. - + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -383,9 +368,7 @@ class LoraLoaderMixin: return (is_model_cpu_offload, is_sequential_cpu_offload) @classmethod - def load_lora_into_unet( - cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None - ): + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -395,14 +378,11 @@ class LoraLoaderMixin: into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. network_alphas (`Dict[str, float]`): - See `LoRALinearLayer` for more details. + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). unet (`UNet2DConditionModel`): The UNet model to load the LoRA layers into. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. @@ -410,94 +390,18 @@ class LoraLoaderMixin: if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) - if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): + if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder: # Load the layers corresponding to UNet. 
logger.info(f"Loading {cls.unet_name}.") - - unet_keys = [k for k in keys if k.startswith(cls.unet_name)] - state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)] - network_alphas = { - k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys - } - - else: - # Otherwise, we're dealing with the old format. This means the `state_dict` should only - # contain the module names of the `unet` as its keys WITHOUT any prefix. - if not USE_PEFT_BACKEND: - warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`." - logger.warning(warn_message) - - if len(state_dict.keys()) > 0: - if adapter_name in getattr(unet, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." - ) - - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if network_alphas is not None: - # The alphas state dict have the same structure as Unet, thus we convert it to peft format using - # `convert_unet_state_dict_to_peft` method. - network_alphas = convert_unet_state_dict_to_peft(network_alphas) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(unet) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name) - - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - # Offload back. 
- if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - unet.load_attn_procs( - state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline - ) + unet.load_attn_procs( + state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + ) @classmethod def load_lora_into_text_encoder( @@ -507,7 +411,6 @@ class LoraLoaderMixin: text_encoder, prefix=None, lora_scale=1.0, - low_cpu_mem_usage=None, adapter_name=None, _pipeline=None, ): @@ -527,11 +430,6 @@ class LoraLoaderMixin: lora_scale (`float`): How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. @@ -541,8 +439,6 @@ class LoraLoaderMixin: from peft import LoraConfig - low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. @@ -625,9 +521,7 @@ class LoraLoaderMixin: # Unsafe code /> @classmethod - def load_lora_into_transformer( - cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None - ): + def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -640,19 +534,12 @@ class LoraLoaderMixin: See `LoRALinearLayer` for more details. unet (`UNet2DConditionModel`): The UNet model to load the LoRA layers into. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT - keys = list(state_dict.keys()) transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] @@ -846,22 +733,11 @@ class LoraLoaderMixin: >>> ... 
``` """ - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - if not USE_PEFT_BACKEND: - if version.parse(__version__) > version.parse("0.23"): - logger.warning( - "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights," - "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT." - ) + raise ValueError("PEFT backend is required for this method.") - for _, module in unet.named_modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) - else: - recurse_remove_peft_layers(unet) - if hasattr(unet, "peft_config"): - del unet.peft_config + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.unload_lora() # Safe to call the following regardless of LoRA. self._remove_text_encoder_monkey_patch() diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index cf67da1cae..b02ff5a589 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -33,34 +33,32 @@ from ..models.embeddings import ( IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict +from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict from ..utils import ( USE_PEFT_BACKEND, _get_model_file, + convert_unet_state_dict_to_peft, delete_adapter_layers, + get_adapter_name, + get_peft_kwargs, is_accelerate_available, + is_peft_version, is_torch_version, logging, set_adapter_layers, set_weights_and_activate_adapters, ) +from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .unet_loader_utils import _maybe_expand_lora_scales from .utils import AttnProcsLayers if is_accelerate_available(): - from accelerate import init_empty_weights from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module logger = logging.get_logger(__name__) -TEXT_ENCODER_NAME = "text_encoder" -UNET_NAME = "unet" - -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" -LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" - CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" @@ -79,7 +77,8 @@ class UNet2DConditionLoadersMixin: Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py) - and be a `torch.nn.Module` class. + and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install + `peft`: `pip install -U peft`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): @@ -110,20 +109,20 @@ class UNet2DConditionLoadersMixin: token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. 
If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - mirror (`str`, *optional*): - Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not - guarantee the timeliness or safety of the source, and you should refer to the mirror site for more - information. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + adapter_name (`str`, *optional*, defaults to None): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. Example: @@ -139,9 +138,6 @@ class UNet2DConditionLoadersMixin: ) ``` """ - from ..models.attention_processor import CustomDiffusionAttnProcessor - from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer - cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", None) @@ -152,15 +148,9 @@ class UNet2DConditionLoadersMixin: subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
- # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning - network_alphas = kwargs.pop("network_alphas", None) - + adapter_name = kwargs.pop("adapter_name", None) _pipeline = kwargs.pop("_pipeline", None) - - is_network_alphas_none = network_alphas is None - + network_alphas = kwargs.pop("network_alphas", None) allow_pickle = False if use_safetensors is None: @@ -216,198 +206,196 @@ class UNet2DConditionLoadersMixin: else: state_dict = pretrained_model_name_or_path_or_dict - # fill attn processors - lora_layers_list = [] - - is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) + is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) + is_model_cpu_offload = False + is_sequential_cpu_offload = False - if is_lora: - # correct keys - state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas) - - if network_alphas is not None: - network_alphas_keys = list(network_alphas.keys()) - used_network_alphas_keys = set() - - lora_grouped_dict = defaultdict(dict) - mapped_network_alphas = {} - - all_keys = list(state_dict.keys()) - for key in all_keys: - value = state_dict.pop(key) - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - lora_grouped_dict[attn_processor_key][sub_key] = value - - # Create another `mapped_network_alphas` dictionary so that we can properly map them. - if network_alphas is not None: - for k in network_alphas_keys: - if k.replace(".alpha", "") in key: - mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)}) - used_network_alphas_keys.add(k) - - if not is_network_alphas_none: - if len(set(network_alphas_keys) - used_network_alphas_keys) > 0: - raise ValueError( - f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}" - ) - - if len(state_dict) > 0: - raise ValueError( - f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" - ) - - for key, value_dict in lora_grouped_dict.items(): - attn_processor = self - for sub_key in key.split("."): - attn_processor = getattr(attn_processor, sub_key) - - # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers - # or add_{k,v,q,out_proj}_proj_lora layers. 
- rank = value_dict["lora.down.weight"].shape[0] - - if isinstance(attn_processor, LoRACompatibleConv): - in_features = attn_processor.in_channels - out_features = attn_processor.out_channels - kernel_size = attn_processor.kernel_size - - ctx = init_empty_weights if low_cpu_mem_usage else nullcontext - with ctx(): - lora = LoRAConv2dLayer( - in_features=in_features, - out_features=out_features, - rank=rank, - kernel_size=kernel_size, - stride=attn_processor.stride, - padding=attn_processor.padding, - network_alpha=mapped_network_alphas.get(key), - ) - elif isinstance(attn_processor, LoRACompatibleLinear): - ctx = init_empty_weights if low_cpu_mem_usage else nullcontext - with ctx(): - lora = LoRALinearLayer( - attn_processor.in_features, - attn_processor.out_features, - rank, - mapped_network_alphas.get(key), - ) - else: - raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") - - value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} - lora_layers_list.append((attn_processor, lora)) - - if low_cpu_mem_usage: - device = next(iter(value_dict.values())).device - dtype = next(iter(value_dict.values())).dtype - load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype) - else: - lora.load_state_dict(value_dict) - - elif is_custom_diffusion: - attn_processors = {} - custom_diffusion_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): - if len(value) == 0: - custom_diffusion_grouped_dict[key] = {} - else: - if "to_out" in key: - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - else: - attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) - custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value - - for key, value_dict in custom_diffusion_grouped_dict.items(): - if len(value_dict) == 0: - attn_processors[key] = CustomDiffusionAttnProcessor( - train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None - ) - else: - cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] - hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0] - train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False - attn_processors[key] = CustomDiffusionAttnProcessor( - train_kv=True, - train_q_out=train_q_out, - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - ) - attn_processors[key].load_state_dict(value_dict) - elif USE_PEFT_BACKEND: - # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict` - # on the Unet - pass + if is_custom_diffusion: + attn_processors = self._process_custom_diffusion(state_dict=state_dict) + elif is_lora: + is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora( + state_dict=state_dict, + unet_identifier_key=self.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + ) else: raise ValueError( - f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training." + f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training." 
) # + + def _process_custom_diffusion(self, state_dict): + from ..models.attention_processor import CustomDiffusionAttnProcessor + + attn_processors = {} + custom_diffusion_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + if len(value) == 0: + custom_diffusion_grouped_dict[key] = {} + else: + if "to_out" in key: + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + else: + attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) + custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in custom_diffusion_grouped_dict.items(): + if len(value_dict) == 0: + attn_processors[key] = CustomDiffusionAttnProcessor( + train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None + ) + else: + cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] + hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0] + train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False + attn_processors[key] = CustomDiffusionAttnProcessor( + train_kv=True, + train_q_out=train_q_out, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ) + attn_processors[key].load_state_dict(value_dict) + + return attn_processors + + def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): + # This method does the following things: + # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy + # format. For legacy format no filtering is applied. + # 2. Converts the `state_dict` to the `peft` compatible format. + # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the + # `LoraConfig` specs. + # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it. + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + unet_keys = [k for k in keys if k.startswith(unet_identifier_key)] + unet_state_dict = { + k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys + } + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)] + network_alphas = { + k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict + + if len(state_dict_to_be_used) > 0: + if adapter_name in getattr(self, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." + ) + + state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used) + + if network_alphas is not None: + # The alphas state dict have the same structure as Unet, thus we convert it to peft format using + # `convert_unet_state_dict_to_peft` method. 
+ network_alphas = convert_unet_state_dict_to_peft(network_alphas) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(self) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + return is_model_cpu_offload, is_sequential_cpu_offload + + @classmethod + # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ is_model_cpu_offload = False is_sequential_cpu_offload = False - # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet` - if not USE_PEFT_BACKEND: - if _pipeline is not None: - for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: is_sequential_cpu_offload = ( - isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + isinstance(component._hf_hook, AlignDevicesHook) or hasattr(component._hf_hook, "hooks") and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) ) - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
+ ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - # only custom diffusion needs to set attn processors - if is_custom_diffusion: - self.set_attn_processor(attn_processors) - - # set lora layers - for target_module, lora_layer in lora_layers_list: - target_module.set_lora_layer(lora_layer) - - self.to(dtype=self.dtype, device=self.device) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas): - is_new_lora_format = all( - key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() - ) - if is_new_lora_format: - # Strip the `"unet"` prefix. - is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys()) - if is_text_encoder_present: - warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)." - logger.warning(warn_message) - unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)] - state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} - - # change processor format to 'pure' LoRACompatibleLinear format - if any("processor" in k.split(".") for k in state_dict.keys()): - - def format_to_lora_compatible(key): - if "processor" not in key.split("."): - return key - return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora") - - state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()} - - if network_alphas is not None: - network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()} - return state_dict, network_alphas + return (is_model_cpu_offload, is_sequential_cpu_offload) def save_attn_procs( self, @@ -460,6 +448,23 @@ class UNet2DConditionLoadersMixin: logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return + is_custom_diffusion = any( + isinstance( + x, + (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor), + ) + for (_, x) in self.attn_processors.items() + ) + if is_custom_diffusion: + state_dict = self._get_custom_diffusion_state_dict() + else: + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.") + + from peft.utils import get_peft_model_state_dict + + state_dict = get_peft_model_state_dict(self) + if save_function is None: if safe_serialization: @@ -471,36 +476,6 @@ class UNet2DConditionLoadersMixin: os.makedirs(save_directory, exist_ok=True) - is_custom_diffusion = any( - isinstance( - x, - (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor), - ) - for (_, x) in self.attn_processors.items() - ) - if is_custom_diffusion: - model_to_save = AttnProcsLayers( - { - y: x - for (y, x) in self.attn_processors.items() - if isinstance( - x, - ( - CustomDiffusionAttnProcessor, - CustomDiffusionAttnProcessor2_0, - CustomDiffusionXFormersAttnProcessor, - ), - ) - } - ) - state_dict = model_to_save.state_dict() - for name, attn in self.attn_processors.items(): - if 
len(attn.state_dict()) == 0: - state_dict[name] = {} - else: - model_to_save = AttnProcsLayers(self.attn_processors) - state_dict = model_to_save.state_dict() - if weight_name is None: if safe_serialization: weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE @@ -512,56 +487,84 @@ class UNet2DConditionLoadersMixin: save_function(state_dict, save_path) logger.info(f"Model weights saved in {save_path}") + def _get_custom_diffusion_state_dict(self): + from ..models.attention_processor import ( + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + ) + + model_to_save = AttnProcsLayers( + { + y: x + for (y, x) in self.attn_processors.items() + if isinstance( + x, + ( + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + ), + ) + } + ) + state_dict = model_to_save.state_dict() + for name, attn in self.attn_processors.items(): + if len(attn.state_dict()) == 0: + state_dict[name] = {} + + return state_dict + def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for `fuse_lora()`.") + self.lora_scale = lora_scale self._safe_fusing = safe_fusing self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names)) def _fuse_lora_apply(self, module, adapter_names=None): - if not USE_PEFT_BACKEND: - if hasattr(module, "_fuse_lora"): - module._fuse_lora(self.lora_scale, self._safe_fusing) + from peft.tuners.tuners_utils import BaseTunerLayer - if adapter_names is not None: + merge_kwargs = {"safe_merge": self._safe_fusing} + + if isinstance(module, BaseTunerLayer): + if self.lora_scale != 1.0: + module.scale_layer(self.lora_scale) + + # For BC with prevous PEFT versions, we need to check the signature + # of the `merge` method to see if it supports the `adapter_names` argument. + supported_merge_kwargs = list(inspect.signature(module.merge).parameters) + if "adapter_names" in supported_merge_kwargs: + merge_kwargs["adapter_names"] = adapter_names + elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: raise ValueError( - "The `adapter_names` argument is not supported in your environment. Please switch" - " to PEFT backend to use this argument by installing latest PEFT and transformers." - " `pip install -U peft transformers`" + "The `adapter_names` argument is not supported with your PEFT version. Please upgrade" + " to the latest version of PEFT. `pip install -U peft`" ) - else: - from peft.tuners.tuners_utils import BaseTunerLayer - merge_kwargs = {"safe_merge": self._safe_fusing} - - if isinstance(module, BaseTunerLayer): - if self.lora_scale != 1.0: - module.scale_layer(self.lora_scale) - - # For BC with prevous PEFT versions, we need to check the signature - # of the `merge` method to see if it supports the `adapter_names` argument. - supported_merge_kwargs = list(inspect.signature(module.merge).parameters) - if "adapter_names" in supported_merge_kwargs: - merge_kwargs["adapter_names"] = adapter_names - elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: - raise ValueError( - "The `adapter_names` argument is not supported with your PEFT version. Please upgrade" - " to the latest version of PEFT. 
`pip install -U peft`" - ) - - module.merge(**merge_kwargs) + module.merge(**merge_kwargs) def unfuse_lora(self): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for `unfuse_lora()`.") self.apply(self._unfuse_lora_apply) def _unfuse_lora_apply(self, module): - if not USE_PEFT_BACKEND: - if hasattr(module, "_unfuse_lora"): - module._unfuse_lora() - else: - from peft.tuners.tuners_utils import BaseTunerLayer + from peft.tuners.tuners_utils import BaseTunerLayer - if isinstance(module, BaseTunerLayer): - module.unmerge() + if isinstance(module, BaseTunerLayer): + module.unmerge() + + def unload_lora(self): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for `unload_lora()`.") + + from ..utils import recurse_remove_peft_layers + + recurse_remove_peft_layers(self) + if hasattr(self, "peft_config"): + del self.peft_config def set_adapters( self, diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 084b7b64f9..a82c97d150 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -903,17 +903,6 @@ class UNet2DConditionModel( if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def unload_lora(self): - """Unloads LoRA weights.""" - deprecate( - "unload_lora", - "0.28.0", - "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", - ) - for module in self.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) - def get_time_embed( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] ) -> Optional[torch.Tensor]: diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 331c8fba44..650c832fef 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -22,7 +22,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin -from ...utils import BaseOutput, deprecate, logging +from ...utils import BaseOutput, logging from ..activations import get_activation from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -546,18 +546,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora - def unload_lora(self): - """Unloads LoRA weights.""" - deprecate( - "unload_lora", - "0.28.0", - "Calling `unload_lora()` is deprecated and will be removed in a future version. 
Please install `peft` and then call `disable_adapters().", - ) - for module in self.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) - def forward( self, sample: torch.Tensor, diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index ad33df964d..f6dcebf299 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -37,7 +37,9 @@ from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, floats_tensor, + is_peft_available, load_hf_numpy, + require_peft_backend, require_torch_accelerator, require_torch_accelerator_with_fp16, require_torch_accelerator_with_training, @@ -51,11 +53,38 @@ from diffusers.utils.testing_utils import ( from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +if is_peft_available(): + from peft import LoraConfig + from peft.tuners.tuners_utils import BaseTunerLayer + + logger = logging.get_logger(__name__) enable_full_determinism() +def get_unet_lora_config(): + rank = 4 + unet_lora_config = LoraConfig( + r=rank, + lora_alpha=rank, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + return unet_lora_config + + +def check_if_lora_correctly_set(model) -> bool: + """ + Checks if the LoRA layers are correctly set with peft + """ + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + return True + return False + + def create_ip_adapter_state_dict(model): # "ip_adapter" (cross-attention weights) ip_cross_attn_state_dict = {} @@ -1005,6 +1034,65 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) + @require_peft_backend + def test_lora(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without LoRA + with torch.no_grad(): + non_lora_sample = model(**inputs_dict).sample + + unet_lora_config = get_unet_lora_config() + model.add_adapter(unet_lora_config) + + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + # forward pass with LoRA + with torch.no_grad(): + lora_sample = model(**inputs_dict).sample + + assert not torch.allclose( + non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4 + ), "LoRA injected UNet should produce different results." + + @require_peft_backend + def test_lora_serialization(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without LoRA + with torch.no_grad(): + non_lora_sample = model(**inputs_dict).sample + + unet_lora_config = get_unet_lora_config() + model.add_adapter(unet_lora_config) + + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + + # forward pass with LoRA + with torch.no_grad(): + lora_sample_1 = model(**inputs_dict).sample + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_attn_procs(tmpdirname) + model.unload_lora() + model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." 
+ + with torch.no_grad(): + lora_sample_2 = model(**inputs_dict).sample + + assert not torch.allclose( + non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4 + ), "LoRA injected UNet should produce different results." + assert torch.allclose( + lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4 + ), "Loading from a saved checkpoint should produce identical results." + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase):
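A minimal usage sketch of the PEFT-backed LoRA round trip this refactor enables on a bare `UNet2DConditionModel`, mirroring the new `test_lora_serialization` test above. The checkpoint id and save directory are illustrative assumptions, and `peft` must be installed (`pip install -U peft`):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Illustrative checkpoint; any UNet2DConditionModel works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Inject a LoRA adapter directly into the UNet through the PEFT backend.
unet.add_adapter(
    LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
)

# `save_attn_procs` now serializes the adapter via `get_peft_model_state_dict`
# (writes pytorch_lora_weights.safetensors by default).
unet.save_attn_procs("unet-lora")

# `unload_lora()` (added in this patch) strips the injected PEFT layers and `peft_config`.
unet.unload_lora()

# `load_attn_procs` routes LoRA checkpoints through the new `_process_lora` helper.
unet.load_attn_procs("unet-lora", weight_name="pytorch_lora_weights.safetensors")
```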