Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Peft] fix saving / loading when unet is not "unet" (#6046)
* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
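Summary of the change: every LoRA/IP-Adapter helper that previously hard-coded `self.unet` now resolves the UNet through the pipeline's `unet_name` attribute, so pipelines whose denoiser is registered under a different attribute keep working. A minimal sketch of the lookup pattern (the `MyPipeline` class and the "prior_unet" attribute name below are hypothetical, for illustration only):

    class MyPipeline:
        unet_name = "prior_unet"  # component is not registered as "unet"

        def __init__(self, prior_unet):
            self.prior_unet = prior_unet

        def _resolve_unet(self):
            # Fall back to `self.unet` when the pipeline has one; otherwise
            # look the component up under whatever name `unet_name` holds.
            return getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet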
@@ -149,9 +149,11 @@ class IPAdapterMixin:
         self.feature_extractor = CLIPImageProcessor()
 
         # load ip-adapter into unet
-        self.unet._load_ip_adapter_weights(state_dict)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet._load_ip_adapter_weights(state_dict)
 
     def set_ip_adapter_scale(self, scale):
-        for attn_processor in self.unet.attn_processors.values():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
             if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                 attn_processor.scale = scale
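With the lookup in place, `set_ip_adapter_scale` reaches the right module even when the pipeline exposes no `unet` attribute. A hedged usage sketch (the model IDs are the usual public checkpoints; any IP-Adapter-compatible pipeline should behave the same):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
    pipe.set_ip_adapter_scale(0.6)  # now forwarded to the resolved UNet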
@@ -912,10 +912,10 @@ class LoraLoaderMixin:
             )
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
         if transformer_lora_layers:
             state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
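`pack_weights` prefixes every LoRA parameter key with the component name before the combined state dict is saved, so switching to `cls.unet_name` / `cls.text_encoder_name` keeps the serialized keys in sync with whatever the pipeline actually calls its components. Roughly (a paraphrase of the helper, not copied verbatim):

    import torch

    def pack_weights(layers, prefix):
        # Accept either a module or an already-extracted state dict and
        # namespace every key under the component prefix.
        weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        return {f"{prefix}.{name}": param for name, param in weights.items()}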
@@ -975,6 +975,8 @@ class LoraLoaderMixin:
        >>> ...
        ```
        """
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+
        if not USE_PEFT_BACKEND:
            if version.parse(__version__) > version.parse("0.23"):
                logger.warn(
@@ -982,13 +984,13 @@ class LoraLoaderMixin:
                    "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                )
 
-           for _, module in self.unet.named_modules():
+           for _, module in unet.named_modules():
                if hasattr(module, "set_lora_layer"):
                    module.set_lora_layer(None)
        else:
-           recurse_remove_peft_layers(self.unet)
-           if hasattr(self.unet, "peft_config"):
-               del self.unet.peft_config
+           recurse_remove_peft_layers(unet)
+           if hasattr(unet, "peft_config"):
+               del unet.peft_config
 
        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()
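`unload_lora_weights` therefore strips LoRA from the resolved UNet on both code paths: the non-PEFT path detaches the monkey-patched LoRA layer from each module, while the PEFT path removes the tuner layers and drops `peft_config`. Hypothetical usage (the path below is a placeholder):

    pipe.load_lora_weights("path/to/lora")  # placeholder path
    pipe.unload_lora_weights()              # UNet found via unet_name either way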
@@ -1027,7 +1029,8 @@ class LoraLoaderMixin:
            )
 
        if fuse_unet:
-           self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+           unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+           unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
        if USE_PEFT_BACKEND:
            from peft.tuners.tuners_utils import BaseTunerLayer
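For reference, fusing folds the LoRA delta into the base weights so inference runs without adapter overhead. An illustrative single-layer sketch, not the actual diffusers/PEFT implementation:

    import torch

    def fuse_linear(W, lora_A, lora_B, lora_scale=1.0):
        # W' = W + lora_scale * (B @ A)
        return W + lora_scale * (lora_B @ lora_A)

    W = torch.randn(8, 8)
    A = torch.randn(4, 8)  # rank-4 down-projection
    B = torch.randn(8, 4)  # up-projection
    W_fused = fuse_linear(W, A, B, lora_scale=0.5)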
@@ -1080,13 +1083,14 @@ class LoraLoaderMixin:
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        if unfuse_unet:
            if not USE_PEFT_BACKEND:
-               self.unet.unfuse_lora()
+               unet.unfuse_lora()
            else:
                from peft.tuners.tuners_utils import BaseTunerLayer
 
-               for module in self.unet.modules():
+               for module in unet.modules():
                    if isinstance(module, BaseTunerLayer):
                        module.unmerge()
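Unfusing is the inverse: `module.unmerge()` subtracts the previously merged delta, restoring the base weights up to floating-point error. Continuing the single-layer sketch above:

    W_restored = W_fused - 0.5 * (B @ A)
    assert torch.allclose(W_restored, W, atol=1e-5)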
@@ -1202,8 +1206,9 @@ class LoraLoaderMixin:
        adapter_names: Union[List[str], str],
        adapter_weights: Optional[List[float]] = None,
    ):
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        # Handle the UNET
-       self.unet.set_adapters(adapter_names, adapter_weights)
+       unet.set_adapters(adapter_names, adapter_weights)
 
        # Handle the Text Encoder
        if hasattr(self, "text_encoder"):
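Usage is unchanged; only the UNet lookup differs. A sketch with hypothetical adapter names:

    pipe.set_adapters(["style", "character"], adapter_weights=[0.8, 0.5])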
@@ -1216,7 +1221,8 @@ class LoraLoaderMixin:
            raise ValueError("PEFT backend is required for this method.")
 
        # Disable unet adapters
-       self.unet.disable_lora()
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+       unet.disable_lora()
 
        # Disable text encoder adapters
        if hasattr(self, "text_encoder"):
@@ -1229,7 +1235,8 @@ class LoraLoaderMixin:
            raise ValueError("PEFT backend is required for this method.")
 
        # Enable unet adapters
-       self.unet.enable_lora()
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+       unet.enable_lora()
 
        # Enable text encoder adapters
        if hasattr(self, "text_encoder"):
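`disable_lora` / `enable_lora` mirror each other, so adapters can be toggled without reloading weights. A hypothetical round trip:

    pipe.disable_lora()   # base model output
    image = pipe("a prompt").images[0]
    pipe.enable_lora()    # adapters active again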
@@ -1251,7 +1258,8 @@ class LoraLoaderMixin:
            adapter_names = [adapter_names]
 
        # Delete unet adapters
-       self.unet.delete_adapters(adapter_names)
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+       unet.delete_adapters(adapter_names)
 
        for adapter_name in adapter_names:
            # Delete text encoder adapters
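`delete_adapters` accepts a single name or a list (note the normalization to a list above). Hypothetical calls:

    pipe.delete_adapters("style")            # one adapter
    pipe.delete_adapters(["style", "char"])  # or several at once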
@@ -1284,8 +1292,8 @@ class LoraLoaderMixin:
        from peft.tuners.tuners_utils import BaseTunerLayer
 
        active_adapters = []
-
-       for module in self.unet.modules():
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+       for module in unet.modules():
            if isinstance(module, BaseTunerLayer):
                active_adapters = module.active_adapters
                break
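The loop stops at the first `BaseTunerLayer` it finds, on the assumption that all tuner layers in the UNet share the same set of active adapters. Hypothetical query:

    active = pipe.get_active_adapters()  # e.g. ["style"]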
@@ -1309,8 +1317,9 @@ class LoraLoaderMixin:
        if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
            set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
 
-       if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
-           set_adapters["unet"] = list(self.unet.peft_config.keys())
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+       if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
+           set_adapters[self.unet_name] = list(self.unet.peft_config.keys())
 
        return set_adapters
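Assuming this hunk is the body of `get_list_adapters`, the returned mapping is now keyed by the pipeline's actual component name rather than a hard-coded "unet". Hypothetical result for an SDXL-style pipeline:

    pipe.get_list_adapters()
    # {"text_encoder": ["style"], "text_encoder_2": ["style"], "unet": ["style", "char"]}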
@@ -1331,7 +1340,8 @@ class LoraLoaderMixin:
        from peft.tuners.tuners_utils import BaseTunerLayer
 
        # Handle the UNET
-       for unet_module in self.unet.modules():
+       unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+       for unet_module in unet.modules():
            if isinstance(unet_module, BaseTunerLayer):
                for adapter_name in adapter_names:
                    unet_module.lora_A[adapter_name].to(device)
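This last hunk moves each adapter's `lora_A`/`lora_B` parameters between devices, which is useful for offloading idle adapters. A hypothetical call (API shape inferred from the `adapter_names` and `device` parameters visible in the diff):

    pipe.set_lora_device(["style"], device="cpu")   # park an idle adapter
    pipe.set_lora_device(["style"], device="cuda")  # bring it back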