Mirror of https://github.com/huggingface/diffusers.git
[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)
* updates
* updates
* updates
* updates
* notebooks revert
* fix-copies.
* seeing
* fix
* revert
* fixes
* fixes
* fixes
* remove print
* fix
* conflicts ii.
* updates
* fixes
* better filtering of prefix.

---------

Co-authored-by: hlky <hlky@hlky.ac>
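For orientation before the diff: the point of the change is that a LoRA checkpoint containing no keys for a given component (UNet, transformer, or a text encoder) no longer fails silently; the loaders now emit an info-level log naming the component class and the prefix that matched nothing. A minimal, hedged sketch of how that surfaces to a user; the model and LoRA repo ids below are placeholders and are not taken from this commit:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import logging

logging.set_verbosity_info()  # surface the new "No LoRA keys associated to ..." info messages

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# A UNet-only LoRA: the text encoder loaders now log that no keys matched
# prefix="text_encoder" / "text_encoder_2" instead of silently doing nothing.
pipe.load_lora_weights("some-user/unet-only-sdxl-lora")  # placeholder repo id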
@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
     # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
     # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
     # their prefixes.
-    keys = list(state_dict.keys())
     prefix = text_encoder_name if prefix is None else prefix
 
-    # Safe prefix to check with.
-    if any(text_encoder_name in key for key in keys):
-        # Load the layers corresponding to text encoder and make necessary adjustments.
-        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-        text_encoder_lora_state_dict = {
-            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-        }
+    # Load the layers corresponding to text encoder and make necessary adjustments.
+    if prefix is not None:
+        state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
-        if len(text_encoder_lora_state_dict) > 0:
-            logger.info(f"Loading {prefix}.")
-            rank = {}
-            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+    if len(state_dict) > 0:
+        logger.info(f"Loading {prefix}.")
+        rank = {}
+        state_dict = convert_state_dict_to_diffusers(state_dict)
 
-            # convert state dict
-            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+        # convert state dict
+        state_dict = convert_state_dict_to_peft(state_dict)
 
-            for name, _ in text_encoder_attn_modules(text_encoder):
-                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                    rank_key = f"{name}.{module}.lora_B.weight"
-                    if rank_key not in text_encoder_lora_state_dict:
-                        continue
-                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+        for name, _ in text_encoder_attn_modules(text_encoder):
+            for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                rank_key = f"{name}.{module}.lora_B.weight"
+                if rank_key not in state_dict:
+                    continue
+                rank[rank_key] = state_dict[rank_key].shape[1]
 
-            for name, _ in text_encoder_mlp_modules(text_encoder):
-                for module in ("fc1", "fc2"):
-                    rank_key = f"{name}.{module}.lora_B.weight"
-                    if rank_key not in text_encoder_lora_state_dict:
-                        continue
-                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+        for name, _ in text_encoder_mlp_modules(text_encoder):
+            for module in ("fc1", "fc2"):
+                rank_key = f"{name}.{module}.lora_B.weight"
+                if rank_key not in state_dict:
+                    continue
+                rank[rank_key] = state_dict[rank_key].shape[1]
 
-            if network_alphas is not None:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+        if network_alphas is not None:
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
 
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"]:
+                if is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<", "0.9.0"):
+                    lora_config_kwargs.pop("use_dora")
 
-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
+        if "lora_bias" in lora_config_kwargs:
+            if lora_config_kwargs["lora_bias"]:
+                if is_peft_version("<=", "0.13.2"):
+                    raise ValueError(
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<=", "0.13.2"):
+                    lora_config_kwargs.pop("lora_bias")
 
-            lora_config = LoraConfig(**lora_config_kwargs)
+        lora_config = LoraConfig(**lora_config_kwargs)
 
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(text_encoder)
+        # adapter_name
+        if adapter_name is None:
+            adapter_name = get_adapter_name(text_encoder)
 
-            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
 
-            # inject LoRA layers and load the state dict
-            # in transformers we automatically check whether the adapter name is already in use or not
-            text_encoder.load_adapter(
-                adapter_name=adapter_name,
-                adapter_state_dict=text_encoder_lora_state_dict,
-                peft_config=lora_config,
-                **peft_kwargs,
-            )
+        # inject LoRA layers and load the state dict
+        # in transformers we automatically check whether the adapter name is already in use or not
+        text_encoder.load_adapter(
+            adapter_name=adapter_name,
+            adapter_state_dict=state_dict,
+            peft_config=lora_config,
+            **peft_kwargs,
+        )
 
-            # scale LoRA layers with `lora_scale`
-            scale_lora_layers(text_encoder, weight=lora_scale)
+        # scale LoRA layers with `lora_scale`
+        scale_lora_layers(text_encoder, weight=lora_scale)
 
-            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+        text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
 
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />
 
+    if prefix is not None and not state_dict:
+        logger.info(
+            f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
+        )
+
 
 def _func_optionally_disable_offloading(_pipeline):
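The new filtering above boils down to one prefix strip plus an emptiness check. A standalone sketch with made-up keys and placeholder values:

state_dict = {
    "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": "unet tensor here",
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": "te tensor here",
}
prefix = "text_encoder"
filtered = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
assert list(filtered) == ["text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight"]
# If nothing starts with the prefix, `filtered` is empty and the function now logs the
# "No LoRA keys associated to ..." message instead of returning silently.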
@@ -298,19 +298,15 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
-        keys = list(state_dict.keys())
-        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
-        if not only_text_encoder:
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.unet_name}.")
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def load_lora_into_text_encoder(
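With the `only_text_encoder` short-circuit removed, `unet.load_lora_adapter` is always called with `prefix=cls.unet_name`; the empty case is detected inside `load_lora_adapter` instead (see the `PeftAdapterMixin` hunks below). A small hedged sketch of the check that now happens there, using made-up keys:

te_only_state_dict = {
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": "tensor here",
}
unet_keys = {k for k in te_only_state_dict if k.startswith("unet.")}
assert not unet_keys  # nothing to load into the UNet -> info log instead of a silent pass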
@@ -559,31 +555,26 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
-
-        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
-        if len(text_encoder_2_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_2_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=self.text_encoder_2,
-                prefix="text_encoder_2",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix=self.text_encoder_name,
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder_2,
+            prefix=f"{self.text_encoder_name}_2",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     @validate_hf_hub_args
@@ -738,19 +729,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
-        keys = list(state_dict.keys())
-        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
-        if not only_text_encoder:
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.unet_name}.")
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1085,43 +1072,33 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
-        if len(transformer_state_dict) > 0:
-            self.load_lora_into_transformer(
-                state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
-
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=None,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
-
-        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
-        if len(text_encoder_2_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_2_state_dict,
-                network_alphas=None,
-                text_encoder=self.text_encoder_2,
-                prefix="text_encoder_2",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=None,
+            text_encoder=self.text_encoder,
+            prefix=self.text_encoder_name,
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=None,
+            text_encoder=self.text_encoder_2,
+            prefix=f"{self.text_encoder_name}_2",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def load_lora_into_transformer(
@@ -1541,18 +1518,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             raise ValueError("Invalid LoRA checkpoint.")
 
         transformer_lora_state_dict = {
-            k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
+            k: state_dict.get(k)
+            for k in list(state_dict.keys())
+            if k.startswith(f"{self.transformer_name}.") and "lora" in k
         }
         transformer_norm_state_dict = {
             k: state_dict.pop(k)
             for k in list(state_dict.keys())
-            if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
+            if k.startswith(f"{self.transformer_name}.")
+            and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
         }
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-        has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
-            transformer, transformer_lora_state_dict, transformer_norm_state_dict
-        )
+        has_param_with_expanded_shape = False
+        if len(transformer_lora_state_dict) > 0:
+            has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+                transformer, transformer_lora_state_dict, transformer_norm_state_dict
+            )
 
         if has_param_with_expanded_shape:
             logger.info(
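Two details in the hunk above are easy to miss: `state_dict.get(k)` (instead of `pop`) leaves the transformer keys in `state_dict`, so the full prefixed dict can later be handed to `load_lora_into_transformer`, and `k.startswith(f"{self.transformer_name}.")` is stricter than the old substring test. A sketch with hypothetical keys:

keys = [
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
    "text_encoder.transformer.layers.0.q_proj.lora_A.weight",  # hypothetical: contains "transformer." but is not the denoiser
]
substring_match = [k for k in keys if "transformer." in k]        # matches both keys
prefix_match = [k for k in keys if k.startswith("transformer.")]  # matches only the first key
assert len(substring_match) == 2 and len(prefix_match) == 1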
@@ -1560,19 +1542,21 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
-        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
-            transformer=transformer, lora_state_dict=transformer_lora_state_dict
-        )
-
         if len(transformer_lora_state_dict) > 0:
-            self.load_lora_into_transformer(
-                transformer_lora_state_dict,
-                network_alphas=network_alphas,
-                transformer=transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
+            transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+                transformer=transformer, lora_state_dict=transformer_lora_state_dict
             )
+            for k in transformer_lora_state_dict:
+                state_dict.update({k: transformer_lora_state_dict[k]})
+
+        self.load_lora_into_transformer(
+            state_dict,
+            network_alphas=network_alphas,
+            transformer=transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
         if len(transformer_norm_state_dict) > 0:
             transformer._transformer_norm_layers = self._load_norm_into_transformer(
@@ -1581,18 +1565,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 discard_original_layers=False,
             )
 
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix=self.text_encoder_name,
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def load_lora_into_transformer(
@@ -1625,17 +1607,14 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             )
 
         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def _load_norm_into_transformer(
@@ -2174,17 +2153,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
             )
 
         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -235,10 +235,7 @@ class PeftAdapterMixin:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            keys = list(state_dict.keys())
-            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
-            if len(model_keys) > 0:
-                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}):
@@ -355,6 +352,11 @@ class PeftAdapterMixin:
                 _pipeline.enable_sequential_cpu_offload()
             # Unsafe code />
 
+        if prefix is not None and not state_dict:
+            logger.info(
+                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {self.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
+            )
+
     def save_lora_adapter(
         self,
         save_directory,
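The `PeftAdapterMixin` change above is what the new test at the bottom of this diff exercises: a state dict whose keys match no prefix now produces an info message rather than a silent return. A hedged sketch of triggering it by hand; the pipeline object is assumed to already exist and nothing here is prescribed by the commit:

import torch
from diffusers.utils import logging

logging.set_verbosity_info()

no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
# pipe.load_lora_weights(no_op_state_dict)
# -> INFO: No LoRA keys associated to <denoiser class> found with the given prefix ...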
@@ -371,9 +371,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         self.assertTrue(
-            cap_logger.out.startswith(
-                "The provided state dict contains normalization layers in addition to LoRA layers"
-            )
+            "The provided state dict contains normalization layers in addition to LoRA layers"
+            in cap_logger.out
         )
         self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
 
@@ -392,7 +391,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
             pipe.load_lora_weights(norm_state_dict)
 
         self.assertTrue(
-            cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
+            "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
         )
 
     def test_lora_parameter_expanded_shapes(self):
@@ -1948,6 +1948,50 @@ class PeftLoraLoaderMixinTests:
         _, _, inputs = self.get_dummy_inputs()
         _ = pipe(**inputs)[0]
 
+    def test_logs_info_when_no_lora_keys_found(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, _ = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
+        logger = logging.get_logger("diffusers.loaders.peft")
+        logger.setLevel(logging.INFO)
+
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(no_op_state_dict)
+            out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
+        self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
+        self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
+
+        # test only for text encoder
+        for lora_module in self.pipeline_class._lora_loadable_modules:
+            if "text_encoder" in lora_module:
+                text_encoder = getattr(pipe, lora_module)
+                if lora_module == "text_encoder":
+                    prefix = "text_encoder"
+                elif lora_module == "text_encoder_2":
+                    prefix = "text_encoder_2"
+
+                logger = logging.get_logger("diffusers.loaders.lora_base")
+                logger.setLevel(logging.INFO)
+
+                with CaptureLogger(logger) as cap_logger:
+                    self.pipeline_class.load_lora_into_text_encoder(
+                        no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
+                    )
+
+                self.assertTrue(
+                    cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
+                )
+
     def test_set_adapters_match_attention_kwargs(self):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
         call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()