From d7001400764acb8de5df343bbc4c54479c0e6ebe Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 3 Jan 2024 20:54:09 +0530 Subject: [PATCH] [LoRA deprecation] handle rest of the stuff related to deprecated lora stuff. (#6426) * handle rest of the stuff related to deprecated lora stuff. * fix: copies * don't modify the uNet in-place. * fix: temporal autoencoder. * manually remove lora layers. * don't copy unet. * alright * remove lora attn processors from unet3d * fix: unet3d. * styl * Empty-Commit --- .../controlnetxs/controlnetxs.py | 6 +-- src/diffusers/loaders/lora.py | 2 +- src/diffusers/models/attention_processor.py | 17 +------- .../models/autoencoders/autoencoder_kl.py | 10 ++--- .../autoencoder_kl_temporal_decoder.py | 10 ++--- .../autoencoders/consistency_decoder_vae.py | 10 ++--- src/diffusers/models/controlnet.py | 10 ++--- src/diffusers/models/prior_transformer.py | 10 ++--- src/diffusers/models/unet_2d_condition.py | 10 ++--- src/diffusers/models/unet_3d_condition.py | 10 ++--- src/diffusers/models/unet_motion_model.py | 10 ++--- src/diffusers/models/uvit_2d.py | 10 ++--- .../pipelines/audioldm2/modeling_audioldm2.py | 10 ++--- .../versatile_diffusion/modeling_text_unet.py | 10 ++--- .../wuerstchen/modeling_wuerstchen_prior.py | 10 ++--- tests/lora/test_lora_layers_old_backend.py | 39 +++++++++++++++---- 16 files changed, 83 insertions(+), 101 deletions(-) diff --git a/examples/research_projects/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py index c6419b44da..20c8d0fdf0 100644 --- a/examples/research_projects/controlnetxs/controlnetxs.py +++ b/examples/research_projects/controlnetxs/controlnetxs.py @@ -494,9 +494,7 @@ class ControlNetXSModel(ModelMixin, ConfigMixin): """ return self.control_model.attn_processors - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -509,7 +507,7 @@ class ControlNetXSModel(ModelMixin, ConfigMixin): processor. This is strongly recommended when setting trainable attention processors. """ - self.control_model.set_attn_processor(processor, _remove_lora) + self.control_model.set_attn_processor(processor) def set_default_attn_processor(self): """ diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index bbd01a9950..424e95f084 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -980,7 +980,7 @@ class LoraLoaderMixin: if not USE_PEFT_BACKEND: if version.parse(__version__) > version.parse("0.23"): - logger.warn( + logger.warning( "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights," "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT." ) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 23a3e2bb37..ac9563e186 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -373,29 +373,14 @@ class Attention(nn.Module): self.set_processor(processor) - def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: + def set_processor(self, processor: "AttnProcessor") -> None: r""" Set the attention processor to use. 
Args: processor (`AttnProcessor`): The attention processor to use. - _remove_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to remove LoRA layers from the model. """ - if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: - deprecate( - "set_processor to offload LoRA", - "0.26.0", - "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", - ) - # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete - # We need to remove all LoRA layers - # Don't forget to remove ALL `_remove_lora` from the codebase - for module in self.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) - # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index ae2d90c548..10a3ae58de 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -182,9 +182,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -208,9 +206,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -232,7 +230,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) @apply_forward_hook def encode( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 0b7f8d1f53..dbafb4571d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -267,9 +267,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -293,9 +291,9 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -314,7 +312,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) @apply_forward_hook def encode( diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py index d92423eafc..ca670fec4b 100644 --- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -212,9 +212,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -238,9 +236,9 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -262,7 +260,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) @apply_forward_hook def encode( diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 3139bb2a5c..1102f4f9d3 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -534,9 +534,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -560,9 +558,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -584,7 +582,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 6c5e406ad3..8ada0a7c08 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -192,9 +192,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -218,9 +216,9 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -242,7 +240,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def forward( self, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 623e4d88d5..4554016bdd 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -643,9 +643,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) return processors - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -669,9 +667,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -692,7 +690,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def set_attention_slice(self, slice_size): r""" diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 3c76b5aa84..2fd629b2f8 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -375,9 +375,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) fn_recursive_set_attention_slice(module, reversed_slice_size) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -401,9 +399,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -465,7 +463,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 0bbc573e7d..b5f0302b4a 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -549,9 +549,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention 
processor to use to compute attention. @@ -575,9 +573,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -641,7 +639,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): diff --git a/src/diffusers/models/uvit_2d.py b/src/diffusers/models/uvit_2d.py index 14dd8aee8e..a49c77a51b 100644 --- a/src/diffusers/models/uvit_2d.py +++ b/src/diffusers/models/uvit_2d.py @@ -237,9 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin): return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -263,9 +261,9 @@ class UVit2DModel(ModelMixin, ConfigMixin): def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -287,7 +285,7 @@ class UVit2DModel(ModelMixin, ConfigMixin): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) class UVit2DConvEmbed(nn.Module): diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index e855c2f0d6..d39b2c99dd 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -538,9 +538,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -564,9 +562,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -588,7 +586,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 7c9936a0bd..6f95112c3d 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -848,9 +848,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): return processors - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -874,9 +872,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -897,7 +895,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def set_attention_slice(self, slice_size): r""" diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index a7d9e32fb6..d4502639ce 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -91,9 +91,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -117,9 +115,9 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -141,7 +139,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py index 7d6d301694..09bb87c851 100644 --- a/tests/lora/test_lora_layers_old_backend.py +++ b/tests/lora/test_lora_layers_old_backend.py @@ -61,7 +61,8 @@ from diffusers.utils.testing_utils import ( ) -def text_encoder_attn_modules(text_encoder): +def text_encoder_attn_modules(text_encoder: nn.Module): + """Fetches the attention modules from `text_encoder`.""" attn_modules = [] if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): @@ -75,7 +76,8 @@ def text_encoder_attn_modules(text_encoder): return attn_modules -def text_encoder_lora_state_dict(text_encoder): +def text_encoder_lora_state_dict(text_encoder: nn.Module): + """Returns the LoRA state dict of 
the `text_encoder`. Assumes that `_modify_text_encoder()` was already called on it.""" state_dict = {} for name, module in text_encoder_attn_modules(text_encoder): @@ -95,6 +97,8 @@ def text_encoder_lora_state_dict(text_encoder): def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): + """Creates and returns the LoRA state dict for the UNet.""" + # So that we accidentally don't end up using the in-place modified UNet. unet_lora_parameters = [] for attn_processor_name, attn_processor in unet.attn_processors.items(): @@ -145,10 +149,17 @@ def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) - return unet_lora_parameters, unet_lora_state_dict(unet) + unet_lora_sd = unet_lora_state_dict(unet) + # Unload LoRA. + for module in unet.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + return unet_lora_parameters, unet_lora_sd def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): + """Creates and returns the LoRA state dict for the 3D UNet.""" for attn_processor_name in unet.attn_processors.keys(): has_cross_attention = attn_processor_name.endswith("attn2.processor") and not ( attn_processor_name.startswith("transformer_in") or "temp_attentions" in attn_processor_name.split(".") @@ -216,10 +227,18 @@ def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): attn_module.to_v.lora_layer.up.weight += 1 attn_module.to_out[0].lora_layer.up.weight += 1 - return unet_lora_state_dict(unet) + unet_lora_sd = unet_lora_state_dict(unet) + + # Unload LoRA. + for module in unet.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + return unet_lora_sd def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0): + """Randomizes the LoRA params if specified.""" if not isinstance(lora_attn_parameters, dict): with torch.no_grad(): for parameter in lora_attn_parameters: @@ -1441,6 +1460,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase): class UNet2DConditionLoRAModelTests(unittest.TestCase): model_class = UNet2DConditionModel main_input_name = "sample" + lora_rank = 4 @property def dummy_input(self): @@ -1489,7 +1509,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase): with torch.no_grad(): sample1 = model(**inputs_dict).sample - _, lora_params = create_unet_lora_layers(model) + _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) # make sure we can set a list of attention processors model.load_attn_procs(lora_params) @@ -1522,13 +1542,16 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase): with torch.no_grad(): old_sample = model(**inputs_dict).sample - _, lora_params = create_unet_lora_layers(model) + _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) model.load_attn_procs(lora_params) with torch.no_grad(): sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - model.set_default_attn_processor() + # Unload LoRA. 
+ for module in model.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) with torch.no_grad(): new_sample = model(**inputs_dict).sample @@ -1552,7 +1575,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase): torch.manual_seed(0) model = self.model_class(**init_dict) model.to(torch_device) - _, lora_params = create_unet_lora_layers(model) + _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) model.load_attn_procs(lora_params) # default
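
Editor's note (not part of the patch): after this change, `set_attn_processor` / `set_processor` no longer accept the private `_remove_lora` flag, so `set_default_attn_processor()` no longer strips LoRA layers as a side effect. The patch itself points to two unload paths: the warning in `loaders/lora.py` recommends `pipe.unload_lora_weights()`, and the updated old-backend test helpers clear LoRA layers manually via `set_lora_layer(None)`. The sketch below only restates those two patterns; `pipe` and `unet` are assumed to be an already-loaded pipeline and a UNet with old-backend LoRA layers attached.

```python
# Minimal sketch, assuming `pipe` is a loaded DiffusionPipeline and `unet` is a
# UNet2DConditionModel with LoRA layers attached via the old (non-PEFT) backend.

# Recommended path (from the warning in loaders/lora.py): let the pipeline
# remove all LoRA weights for you.
pipe.unload_lora_weights()

# Manual path, as used in the updated test helpers: clear the LoRA layer on
# every sub-module that exposes `set_lora_layer` (to_q/to_k/to_v/to_out, ...).
for module in unet.modules():
    if hasattr(module, "set_lora_layer"):
        module.set_lora_layer(None)
```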