diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index e1edc77423..ec77718d16 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -13,7 +13,6 @@
# limitations under the License.
import os
import re
-import warnings
from collections import defaultdict
from contextlib import nullcontext
from io import BytesIO
@@ -33,7 +32,6 @@ from .utils import (
_get_model_file,
deprecate,
is_accelerate_available,
- is_accelerate_version,
is_omegaconf_available,
is_transformers_available,
logging,
@@ -308,6 +306,9 @@ class UNet2DConditionLoadersMixin:
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
+
+ _pipeline = kwargs.pop("_pipeline", None)
+
is_network_alphas_none = network_alphas is None
allow_pickle = False
@@ -461,6 +462,7 @@ class UNet2DConditionLoadersMixin:
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
else:
lora.load_state_dict(value_dict)
+
elif is_custom_diffusion:
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
@@ -490,19 +492,44 @@ class UNet2DConditionLoadersMixin:
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
-
- self.set_attn_processor(attn_processors)
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
)
+
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
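For context on what the loader now has to do internally: the hook bookkeeping that this PR deletes from `load_lora_weights()` (see the removed block in the next hunk) has to happen around the attention-processor assignment instead. The sketch below is a hedged reconstruction of that pattern, not the literal added code; the accelerate utilities (`CpuOffload`, `AlignDevicesHook`, `remove_hook_from_module`) are real, while the helper name and its exact call site are assumptions.

```python
# Hedged sketch, reconstructed from the hook-removal code this PR deletes
# from load_lora_weights() below. The helper name is hypothetical; only the
# accelerate utilities are real.
import torch
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def _detach_offload_hooks(pipeline):
    """Remove accelerate hooks and report which offload mode was active."""
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    if pipeline is not None:
        for component in pipeline.components.values():
            if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
                is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return is_model_cpu_offload, is_sequential_cpu_offload


# After the LoRA layers are attached, offloading would be re-enabled via
# pipeline.enable_model_cpu_offload() or pipeline.enable_sequential_cpu_offload().
```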
@@ -1072,26 +1099,21 @@ class LoraLoaderMixin:
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
- # Remove any existing hooks.
- is_model_cpu_offload = False
- is_sequential_cpu_offload = False
- recurive = False
- for _, component in self.components.items():
- if isinstance(component, nn.Module):
- if hasattr(component, "_hf_hook"):
- is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
- is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
- logger.info(
- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- )
- recurive = is_sequential_cpu_offload
- remove_hook_from_module(component, recurse=recurive)
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
- state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(
- state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+ state_dict,
+ network_alphas=network_alphas,
+ unet=self.unet,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ _pipeline=self,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -1099,14 +1121,9 @@ class LoraLoaderMixin:
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
+ _pipeline=self,
)
- # Offload back.
- if is_model_cpu_offload:
- self.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- self.enable_sequential_cpu_offload()
-
@classmethod
def lora_state_dict(
cls,
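From the user's side, the net effect of this hunk is that an offloaded pipeline no longer needs its hooks removed and re-installed manually before loading LoRA weights; `load_lora_weights()` forwards `_pipeline=self` so the loaders can account for the hooks themselves. A minimal usage sketch (model id and LoRA filename are illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # installs accelerate hooks on the components

# With this refactor, no hook juggling is needed at the call site; the
# pipeline handle is threaded through internally via the private `_pipeline` kwarg.
pipe.load_lora_weights("path/to/lora_weights.safetensors")
```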
@@ -1403,7 +1420,7 @@ class LoraLoaderMixin:
return new_state_dict
@classmethod
- def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1445,13 +1462,22 @@ class LoraLoaderMixin:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
- warnings.warn(warn_message)
+        logger.warning(warn_message)
- unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+ unet.load_attn_procs(
+ state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+ )
@classmethod
def load_lora_into_text_encoder(
- cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ low_cpu_mem_usage=None,
+ _pipeline=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
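The warning emitted above already spells out the migration path for old-format checkpoints; the sketch below just makes it runnable (the checkpoint path is illustrative):

```python
from safetensors.torch import load_file

# Old-format LoRA checkpoints key the UNet modules without the "unet." prefix.
old_state_dict = load_file("old_lora.safetensors")
new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
```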
@@ -1561,11 +1587,15 @@ class LoraLoaderMixin:
low_cpu_mem_usage=low_cpu_mem_usage,
)
- # set correct dtype & device
- text_encoder_lora_state_dict = {
- k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
- for k, v in text_encoder_lora_state_dict.items()
- }
+ is_pipeline_offloaded = _pipeline is not None and any(
+ isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+ )
+        if is_pipeline_offloaded and not low_cpu_mem_usage:
+ low_cpu_mem_usage = True
+ logger.info(
+ f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+ )
+
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1581,8 +1611,33 @@ class LoraLoaderMixin:
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
+
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
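The `is_pipeline_offloaded` check introduced in this hunk is worth calling out: an offloaded pipeline is detected purely by the presence of an accelerate `_hf_hook` attribute on any `nn.Module` component, and in that case `low_cpu_mem_usage` loading is forced, presumably so the new LoRA parameters follow the hook-managed device placement rather than being loaded eagerly. A standalone restatement of that check (the helper name is hypothetical):

```python
import torch


def _is_pipeline_offloaded(pipeline) -> bool:
    """Mirror of the inline check above: any component carrying an accelerate hook."""
    if pipeline is None:
        return False
    return any(
        isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
        for c in pipeline.components.values()
    )
```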
@@ -2652,31 +2707,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
- # Remove any existing hooks.
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
- from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
- else:
- raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
-
- is_model_cpu_offload = False
- is_sequential_cpu_offload = False
- for _, component in self.components.items():
- if isinstance(component, torch.nn.Module):
- if hasattr(component, "_hf_hook"):
- is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
- is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
- logger.info(
- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- )
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
)
- self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
@@ -2685,6 +2726,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
+ _pipeline=self,
)
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -2695,14 +2737,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
+ _pipeline=self,
)
- # Offload back.
- if is_model_cpu_offload:
- self.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- self.enable_sequential_cpu_offload()
-
@classmethod
def save_lora_weights(
self,
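The SDXL mixin gets the same treatment, including both text encoders, so the same call pattern applies there (model id and filename are illustrative):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()

# LoRA weights are loaded into the UNet and both text encoders; the pipeline
# handle (`_pipeline=self`) is forwarded internally for all three.
pipe.load_lora_weights("path/to/sdxl_lora.safetensors")
```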