[LoRA] allow loras to be loaded with low_cpu_mem_usage. (#9510)
* allow loras to be loaded with low_cpu_mem_usage.
* add flux support but note https://github.com/huggingface/diffusers/pull/9510#issuecomment-2378316687
* low_cpu_mem_usage.
* fix-copies
* fix-copies again
* tests
* _LOW_CPU_MEM_USAGE_DEFAULT_LORA
* _peft_version default.
* version checks.
* version check.
* version check.
* version check.
* require peft 0.13.1.
* explicitly specify low_cpu_mem_usage=False.
* docs.
* transformers version 4.45.2.
* update
* fix
* empty
* better name initialize_dummy_state_dict.
* doc todos.
* Apply suggestions from code review
* style
* fix-copies

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@@ -75,6 +75,12 @@ image
 
+<Tip>
+
+By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.
+
+</Tip>
+
 ## Merge adapters
 
 You can also merge different adapter checkpoints for inference to blend their styles together.
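The tip above documents the new default without showing a call site, so here is a minimal usage sketch. It assumes peft>=0.13.1 and transformers>4.45.2 are installed; the pipeline and LoRA checkpoint ids are hypothetical placeholders, not part of this commit.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# With up-to-date PEFT/Transformers this is already the default; passing the
# flag explicitly makes the behavior independent of the installed versions.
pipe.load_lora_weights(
    "some-user/some-sdxl-lora",  # hypothetical repo id
    weight_name="lora.safetensors",  # hypothetical file name
    adapter_name="style",
    low_cpu_mem_usage=True,
)
```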
@@ -25,8 +25,11 @@ from ..utils import (
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_peft_available,
     is_peft_version,
+    is_torch_version,
+    is_transformers_available,
+    is_transformers_version,
     logging,
     scale_lora_layers,
 )
@@ -39,6 +42,17 @@ from .lora_conversion_utils import (
 )
 
 
+_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
+if is_torch_version(">=", "1.9.0"):
+    if (
+        is_peft_available()
+        and is_peft_version(">=", "0.13.1")
+        and is_transformers_available()
+        and is_transformers_version(">", "4.45.2")
+    ):
+        _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
+
+
 if is_transformers_available():
     from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
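A quick sketch of what this import-time gate controls: the computed default only applies when a caller omits `low_cpu_mem_usage`, and an explicit keyword argument always wins. Importing the private constant is shown purely for inspection and is an assumption about how one might poke at it, not an API this commit exposes.

```python
from diffusers.loaders.lora_pipeline import _LOW_CPU_MEM_USAGE_DEFAULT_LORA

# True only if torch>=1.9.0, peft>=0.13.1, and transformers>4.45.2 were detected.
print(_LOW_CPU_MEM_USAGE_DEFAULT_LORA)

# Either behavior can still be forced per call, e.g.:
# pipe.load_lora_weights(..., low_cpu_mem_usage=False)  # eager random init
# pipe.load_lora_weights(..., low_cpu_mem_usage=True)   # requires peft>=0.13.1
```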
@@ -83,15 +97,24 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -109,6 +132,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             adapter_name=adapter_name,
             _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -119,6 +143,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
             _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
     @classmethod
@@ -237,7 +262,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         return state_dict, network_alphas
 
     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+    def load_lora_into_unet(
+        cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -255,10 +282,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
@@ -268,7 +301,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
             unet.load_attn_procs(
-                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
     @classmethod
@@ -281,6 +318,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -303,10 +341,25 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -377,6 +430,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                     adapter_name=adapter_name,
                     adapter_state_dict=text_encoder_lora_state_dict,
                     peft_config=lora_config,
+                    **peft_kwargs,
                 )
 
             # scale LoRA layers with `lora_scale`
@@ -547,12 +601,19 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
@@ -573,7 +634,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             raise ValueError("Invalid LoRA checkpoint.")
 
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
@@ -585,6 +651,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -597,6 +664,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
     @classmethod
@@ -717,7 +785,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+    def load_lora_into_unet(
+        cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -735,10 +805,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
@@ -748,7 +824,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
             unet.load_attn_procs(
-                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
     @classmethod
@@ -762,6 +842,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -784,10 +865,25 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -858,6 +954,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                     adapter_name=adapter_name,
                     adapter_state_dict=text_encoder_lora_state_dict,
                     peft_config=lora_config,
+                    **peft_kwargs,
                 )
 
             # scale LoRA layers with `lora_scale`
@@ -1126,15 +1223,22 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1151,6 +1255,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1163,6 +1268,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -1175,10 +1281,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
     @classmethod
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -1192,7 +1301,13 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
         keys = list(state_dict.keys())
@@ -1236,8 +1351,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         # otherwise loading LoRA weights will lead to an error
         is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
 
-        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+        peft_kwargs = {}
+        if is_peft_version(">=", "0.13.1"):
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
+        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
         if incompatible_keys is not None:
             # check only for unexpected keys
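The `peft_kwargs` plumbing above is what routes the flag into PEFT. As a rough, self-contained sketch of the underlying mechanism (a toy module and config, assuming peft>=0.13.1; this is not code from the commit):

```python
import torch
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict

model = torch.nn.ModuleDict({"proj": torch.nn.Linear(8, 8)})
config = LoraConfig(r=4, target_modules=["proj"])

# With low_cpu_mem_usage=True the injected LoRA params land on the "meta"
# device: shapes and dtypes only, no storage, so no random init is performed.
inject_adapter_in_model(config, model, low_cpu_mem_usage=True)
assert any(p.device.type == "meta" for p in model.parameters())

# Assigning checkpoint values materializes the meta tensors in one step.
checkpoint = {
    k: torch.randn(v.shape, dtype=v.dtype) for k, v in get_peft_model_state_dict(model).items()
}
set_peft_model_state_dict(model, checkpoint, low_cpu_mem_usage=True)
assert all(p.device.type != "meta" for p in model.parameters())
```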
@@ -1266,6 +1385,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1288,10 +1408,25 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1362,6 +1497,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                     adapter_name=adapter_name,
                     adapter_state_dict=text_encoder_lora_state_dict,
                     peft_config=lora_config,
+                    **peft_kwargs,
                 )
 
             # scale LoRA layers with `lora_scale`
@@ -1667,10 +1803,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1690,6 +1833,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1702,10 +1846,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
     @classmethod
-    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(
+        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -1723,7 +1870,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
         keys = list(state_dict.keys())
@@ -1772,8 +1925,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         # otherwise loading LoRA weights will lead to an error
         is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
 
-        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+        peft_kwargs = {}
+        if is_peft_version(">=", "0.13.1"):
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
+        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
         if incompatible_keys is not None:
             # check only for unexpected keys
@@ -1802,6 +1959,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1824,10 +1982,25 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1898,6 +2071,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                     adapter_name=adapter_name,
                     adapter_state_dict=text_encoder_lora_state_dict,
                     peft_config=lora_config,
+                    **peft_kwargs,
                 )
 
             # scale LoRA layers with `lora_scale`
@@ -2132,6 +2306,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
    ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2154,10 +2329,25 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -2228,6 +2418,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                     adapter_name=adapter_name,
                     adapter_state_dict=text_encoder_lora_state_dict,
                     peft_config=lora_config,
+                    **peft_kwargs,
                 )
 
             # scale LoRA layers with `lora_scale`
@@ -2416,15 +2607,22 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -2441,11 +2639,14 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -2459,7 +2660,13 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
         """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
         keys = list(state_dict.keys())
@@ -2503,8 +2710,12 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         # otherwise loading LoRA weights will lead to an error
         is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
 
-        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+        peft_kwargs = {}
+        if is_peft_version(">=", "0.13.1"):
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
+        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
         if incompatible_keys is not None:
             # check only for unexpected keys

@@ -115,6 +115,9 @@ class UNet2DConditionLoadersMixin:
                 `default_{i}` where i is the total number of adapters being loaded.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
 
         Example:
 
@@ -142,8 +145,14 @@ class UNet2DConditionLoadersMixin:
         adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
         network_alphas = kwargs.pop("network_alphas", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
         allow_pickle = False
 
+        if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

@@ -209,6 +218,7 @@ class UNet2DConditionLoadersMixin:
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
         else:
             raise ValueError(

@@ -268,7 +278,9 @@ class UNet2DConditionLoadersMixin:
 
         return attn_processors
 
-    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+    def _process_lora(
+        self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
+    ):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
         #    format. For legacy format no filtering is applied.

@@ -335,9 +347,12 @@ class UNet2DConditionLoadersMixin:
             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error
             is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            peft_kwargs = {}
+            if is_peft_version(">=", "0.13.1"):
+                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
 
             if incompatible_keys is not None:
                 # check only for unexpected keys

@@ -388,6 +388,24 @@ def require_peft_version_greater(peft_version):
     return decorator
 
 
+def require_transformers_version_greater(transformers_version):
+    """
+    Decorator marking a test that requires transformers with a specific version; such tests would also require some
+    specific versions of PEFT and transformers.
+    """
+
+    def decorator(test_case):
+        correct_transformers_version = is_transformers_available() and version.parse(
+            version.parse(importlib.metadata.version("transformers")).base_version
+        ) > version.parse(transformers_version)
+        return unittest.skipUnless(
+            correct_transformers_version,
+            f"test requires transformers with the version greater than {transformers_version}",
+        )(test_case)
+
+    return decorator
+
+
 def require_accelerate_version_greater(accelerate_version):
     def decorator(test_case):
         correct_accelerate_version = is_peft_available() and version.parse(

@@ -32,13 +32,14 @@ from diffusers.utils.testing_utils import (
     floats_tensor,
     require_peft_backend,
     require_peft_version_greater,
+    require_transformers_version_greater,
     skip_mps,
     torch_device,
 )
 
 
 if is_peft_available():
-    from peft import LoraConfig
+    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
     from peft.tuners.tuners_utils import BaseTunerLayer
     from peft.utils import get_peft_model_state_dict

@@ -65,6 +66,12 @@ def check_if_lora_correctly_set(model) -> bool:
     return False
 
 
+def initialize_dummy_state_dict(state_dict):
+    if not all(v.device.type == "meta" for _, v in state_dict.items()):
+        raise ValueError("`state_dict` has non-meta values.")
+    return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}
+
+
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     pipeline_class = None

@@ -272,6 +279,136 @@ class PeftLoraLoaderMixinTests:
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )
 
+    @require_peft_version_greater("0.13.1")
+    def test_low_cpu_mem_usage_with_injection(self):
+        """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
+                )
+                self.assertTrue(
+                    "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
+                    "The LoRA params should be on 'meta' device.",
+                )
+
+                te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
+                set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
+                self.assertTrue(
+                    "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
+                    "No param should be on 'meta' device.",
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+            self.assertTrue(
+                "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
+            )
+
+            denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
+            set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
+            self.assertTrue(
+                "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
+            )
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+                    self.assertTrue(
+                        "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
+                        "The LoRA params should be on 'meta' device.",
+                    )
+
+                    te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
+                    set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
+                    self.assertTrue(
+                        "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
+                        "No param should be on 'meta' device.",
+                    )
+
+            _, _, inputs = self.get_dummy_inputs()
+            output_lora = pipe(**inputs)[0]
+            self.assertTrue(output_lora.shape == self.output_shape)
+
+    @require_peft_version_greater("0.13.1")
+    @require_transformers_version_greater("4.45.1")
+    def test_low_cpu_mem_usage_with_loading(self):
+        """Tests if we can load LoRA state dict with low_cpu_mem_usage."""
+
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(output_no_lora.shape == self.output_shape)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+                )
+
+                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+                pipe.unload_lora_weights()
+                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
+
+                for module_name, module in modules_to_save.items():
+                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+                images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(
+                    np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+                    "Loading from saved checkpoints should give same results.",
+                )
+
+                # Now, check for `low_cpu_mem_usage`.
+                pipe.unload_lora_weights()
+                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
+
+                for module_name, module in modules_to_save.items():
+                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+                images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(
+                    np.allclose(
+                        images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
+                    ),
+                    "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
+                )
+
     def test_simple_inference_with_text_lora_and_scale(self):
         """
         Tests a simple inference with lora attached on the text encoder + scale argument
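For readers unfamiliar with the "meta" device that `initialize_dummy_state_dict` and the tests above revolve around, here is a short standalone illustration in plain PyTorch (not code from this commit):

```python
import torch

# A meta tensor records shape and dtype but allocates no storage, which is
# exactly what makes skipping the random LoRA init cheap.
t = torch.empty(2, 3, device="meta")
print(t.shape, t.dtype, t.device)  # torch.Size([2, 3]) torch.float32 meta

# There are no values to read back; a real tensor must be created (e.g. from
# checkpoint data) before the weights can actually be used.
real = torch.randn(t.shape, dtype=t.dtype)
print(real.device)  # cpu
```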