From ed507680e35b7628bc11235255cfc58ad1101626 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 18 Sep 2023 22:46:28 +0100
Subject: [PATCH] [LoRA] don't break offloading for incompatible lora ckpts.
 (#5085)

* don't break offloading for incompatible lora ckpts.

* debugging

* better condition.

* fix

* fix

* fix

* fix

---------

Co-authored-by: Patrick von Platen
---
 src/diffusers/loaders.py | 157 ++++++++++++++++++++++++---------------
 1 file changed, 97 insertions(+), 60 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index e1edc77423..ec77718d16 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import re
-import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from io import BytesIO
@@ -33,7 +32,6 @@ from .utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
-    is_accelerate_version,
     is_omegaconf_available,
     is_transformers_available,
     logging,
@@ -308,6 +306,9 @@ class UNet2DConditionLoadersMixin:
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
+
+        _pipeline = kwargs.pop("_pipeline", None)
+
         is_network_alphas_none = network_alphas is None
 
         allow_pickle = False
@@ -461,6 +462,7 @@ class UNet2DConditionLoadersMixin:
                     load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
                 else:
                     lora.load_state_dict(value_dict)
+
         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
@@ -490,19 +492,44 @@ class UNet2DConditionLoadersMixin:
                         cross_attention_dim=cross_attention_dim,
                     )
                     attn_processors[key].load_state_dict(value_dict)
-
-            self.set_attn_processor(attn_processors)
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )
 
+        #
+
     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -1072,26 +1099,21 @@ class LoraLoaderMixin:
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        # Remove any existing hooks.
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        recurive = False
-        for _, component in self.components.items():
-            if isinstance(component, nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    recurive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recurive)
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
 
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
 
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -1099,14 +1121,9 @@ class LoraLoaderMixin:
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
 
-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def lora_state_dict(
         cls,
@@ -1403,7 +1420,7 @@ class LoraLoaderMixin:
         return new_state_dict
 
     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -1445,13 +1462,22 @@ class LoraLoaderMixin:
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
            # contain the module names of the `unet` as its keys WITHOUT any prefix.
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            warnings.warn(warn_message)
+            logger.warn(warn_message)
 
-        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet.load_attn_procs(
+            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+        )
 
     @classmethod
     def load_lora_into_text_encoder(
-        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        low_cpu_mem_usage=None,
+        _pipeline=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1561,11 +1587,15 @@ class LoraLoaderMixin:
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
-            # set correct dtype & device
-            text_encoder_lora_state_dict = {
-                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                for k, v in text_encoder_lora_state_dict.items()
-            }
+            is_pipeline_offloaded = _pipeline is not None and any(
+                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+            )
+            if is_pipeline_offloaded and low_cpu_mem_usage:
+                low_cpu_mem_usage = True
+                logger.info(
+                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+                )
+
             if low_cpu_mem_usage:
                 device = next(iter(text_encoder_lora_state_dict.values())).device
                 dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1581,8 +1611,33 @@ class LoraLoaderMixin:
                         f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                     )
 
+            #
+
     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2652,31 +2707,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
 
-        # Remove any existing hooks.
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-        else:
-            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
-
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        for _, component in self.components.items():
-            if isinstance(component, torch.nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
         )
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -2685,6 +2726,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )
 
         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -2695,14 +2737,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
            )
 
-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def save_lora_weights(
         self,