
[LoRA] don't break offloading for incompatible lora ckpts. (#5085)

* don't break offloading for incompatible lora ckpts.

* debugging

* better condition.

* fix

* fix

* fix

* fix

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Sayak Paul
Date: 2023-09-18 22:46:28 +01:00
Committed by: Patrick von Platen
Parent: 92f6693b37
Commit: a2f0db52e3
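
For context, a minimal usage sketch of the behaviour this patch protects (the model ID, the dummy state dict, and the example LoRA repo below are illustrative assumptions, not part of the commit): the checkpoint is now validated before any accelerate hooks are removed, and `enable_model_cpu_offload()` / `enable_sequential_cpu_offload()` is re-applied internally once the LoRA layers are in place.

# Illustrative sketch only (not part of the patch); the model ID, the dummy
# state dict, and the example LoRA repo are placeholder assumptions.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # installs accelerate CpuOffload hooks

# An incompatible state dict is now rejected *before* any hooks are touched,
# so the pipeline stays offloaded after the error.
try:
    pipe.load_lora_weights({"foo.weight": torch.zeros(4)})
except ValueError:
    pass

# A compatible LoRA: hooks are removed, the layers are injected, and model
# CPU offloading is re-enabled once loading finishes.
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")

Before this change the hooks were stripped at the start of `load_lora_weights()`, so a rejected or malformed checkpoint left offloading permanently disabled; the diff below moves the hook handling to after the format check.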


@@ -13,7 +13,6 @@
# limitations under the License.
import os
import re
import warnings
from collections import defaultdict
from contextlib import nullcontext
from io import BytesIO
@@ -307,6 +306,9 @@ class UNet2DConditionLoadersMixin:
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
is_network_alphas_none = network_alphas is None
allow_pickle = False
@@ -460,6 +462,7 @@ class UNet2DConditionLoadersMixin:
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
else:
lora.load_state_dict(value_dict)
elif is_custom_diffusion:
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
@@ -489,19 +492,44 @@ class UNet2DConditionLoadersMixin:
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
self.set_attn_processor(attn_processors)
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# only custom diffusion needs to set attn processors
if is_custom_diffusion:
self.set_attn_processor(attn_processors)
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
self.to(dtype=self.dtype, device=self.device)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -1060,26 +1088,21 @@ class LoraLoaderMixin:
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recurive = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recurive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recurive)
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(
state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
state_dict,
network_alphas=network_alphas,
unet=self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
_pipeline=self,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -1087,14 +1110,9 @@ class LoraLoaderMixin:
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
_pipeline=self,
)
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod
def lora_state_dict(
cls,
@@ -1391,7 +1409,7 @@ class LoraLoaderMixin:
return new_state_dict
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1433,13 +1451,22 @@ class LoraLoaderMixin:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)
logger.warn(warn_message)
unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
)
@classmethod
def load_lora_into_text_encoder(
cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
_pipeline=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1549,11 +1576,15 @@ class LoraLoaderMixin:
low_cpu_mem_usage=low_cpu_mem_usage,
)
# set correct dtype & device
text_encoder_lora_state_dict = {
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
for k, v in text_encoder_lora_state_dict.items()
}
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1569,8 +1600,33 @@ class LoraLoaderMixin:
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2639,31 +2695,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
# Remove any existing hooks.
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
else:
raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
@@ -2672,6 +2714,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
_pipeline=self,
)
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -2682,14 +2725,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
_pipeline=self,
)
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod
def save_lora_weights(
self,