[IP Adapters] feat: allow low_cpu_mem_usage in ip adapter loading (#6946)

* feat: allow low_cpu_mem_usage in ip adapter loading
* reduce the number of device placements.
* documentation.
* throw low_cpu_mem_usage warning only once from the main entry point.
@@ -231,6 +231,9 @@ export_to_gif(frames, "gummy_bear.gif")
 </hfoption>
 </hfoptions>
 
+> [!TIP]
+> While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.
+
 ## Specific use cases
 
 IP-Adapter's image prompting and compatibility with other adapters and models makes it a versatile tool for a variety of use cases. This section covers some of the more popular applications of IP-Adapter, and we can't wait to see what you come up with!
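For context, a minimal sketch of the documented tip in use. The pipeline and adapter checkpoints below are illustrative (they are the ones the IP-Adapter docs use elsewhere), not mandated by this commit:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# low_cpu_mem_usage=True builds the adapter modules on the meta device and
# then streams in the pretrained weights, skipping the throwaway random init.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    low_cpu_mem_usage=True,
)
```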
@@ -19,8 +19,11 @@ import torch
 from huggingface_hub.utils import validate_hf_hub_args
 from safetensors import safe_open
 
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import (
     _get_model_file,
+    is_accelerate_available,
+    is_torch_version,
     is_transformers_available,
     logging,
 )
@@ -86,6 +89,11 @@ class IPAdapterMixin:
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
         """
 
         # handle the list inputs for multiple IP Adapters
@@ -116,6 +124,22 @@ class IPAdapterMixin:
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         user_agent = {
             "file_type": "attn_procs_weights",
@@ -165,6 +189,7 @@ class IPAdapterMixin:
                     image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                         pretrained_model_name_or_path_or_dict,
                         subfolder=Path(subfolder, "image_encoder").as_posix(),
+                        low_cpu_mem_usage=low_cpu_mem_usage,
                     ).to(self.device, dtype=self.dtype)
                     self.register_modules(image_encoder=image_encoder)
                 else:
@@ -175,9 +200,9 @@ class IPAdapterMixin:
             feature_extractor = CLIPImageProcessor()
             self.register_modules(feature_extractor=feature_extractor)
 
         # load ip-adapter into unet
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dicts)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
 
     def set_ip_adapter_scale(self, scale):
         """
@@ -37,6 +37,7 @@ from ..utils import (
     _get_model_file,
     delete_adapter_layers,
     is_accelerate_available,
+    is_torch_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
@@ -168,15 +169,6 @@ class UNet2DConditionLoadersMixin:
             "framework": "pytorch",
         }
 
-        if low_cpu_mem_usage and not is_accelerate_available():
-            low_cpu_mem_usage = False
-            logger.warning(
-                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
-                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
-                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
-                " install accelerate\n```\n."
-            )
-
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -694,9 +686,29 @@ class UNet2DConditionLoadersMixin:
         if hasattr(self, "peft_config"):
             self.peft_config.pop(adapter_name, None)
 
-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         updated_state_dict = {}
         image_projection = None
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
 
         if "proj.weight" in state_dict:
             # IP-Adapter
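The `init_context` switch introduced above is accelerate's standard meta-device pattern. A self-contained sketch of what it buys, using a hypothetical `TinyProjection` module as a stand-in for diffusers' `ImageProjection`:

```python
from contextlib import nullcontext

import torch.nn as nn
from accelerate import init_empty_weights

class TinyProjection(nn.Module):  # hypothetical stand-in for ImageProjection
    def __init__(self, dim_in=1024, dim_out=768):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

low_cpu_mem_usage = True
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

with init_context():
    model = TinyProjection()

# Under init_empty_weights the parameters live on the meta device: they carry
# shapes and dtypes but no storage, so no CPU RAM is spent on a random init
# that the checkpoint weights would immediately overwrite.
print(model.proj.weight.device)  # meta
```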
@@ -704,11 +716,12 @@ class UNet2DConditionLoadersMixin:
             clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
             cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
 
-            image_projection = ImageProjection(
-                cross_attention_dim=cross_attention_dim,
-                image_embed_dim=clip_embeddings_dim,
-                num_image_text_embeds=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = ImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=clip_embeddings_dim,
+                    num_image_text_embeds=num_image_text_embeds,
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("proj", "image_embeds")
@@ -719,9 +732,10 @@ class UNet2DConditionLoadersMixin:
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]
 
-            image_projection = IPAdapterFullImageProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
-            )
+            with init_context():
+                image_projection = IPAdapterFullImageProjection(
+                    cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("proj.0", "ff.net.0.proj")
@@ -737,13 +751,14 @@ class UNet2DConditionLoadersMixin:
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
 
-            image_projection = IPAdapterPlusImageProjection(
-                embed_dims=embed_dims,
-                output_dims=output_dims,
-                hidden_dims=hidden_dims,
-                heads=heads,
-                num_queries=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = IPAdapterPlusImageProjection(
+                    embed_dims=embed_dims,
+                    output_dims=output_dims,
+                    hidden_dims=hidden_dims,
+                    heads=heads,
+                    num_queries=num_image_text_embeds,
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("0.to", "2.to")
@@ -765,10 +780,14 @@ class UNet2DConditionLoadersMixin:
                 else:
                     updated_state_dict[diffusers_name] = value
 
-        image_projection.load_state_dict(updated_state_dict)
+        if not low_cpu_mem_usage:
+            image_projection.load_state_dict(updated_state_dict)
+        else:
+            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
         return image_projection
 
-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
        from ..models.attention_processor import (
             AttnProcessor,
             AttnProcessor2_0,
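`load_model_dict_into_meta` is a diffusers-internal helper. A rough hand-rolled equivalent using accelerate's public `set_module_tensor_to_device`, reusing the hypothetical `TinyProjection` from the sketch above, might look like:

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

class TinyProjection(nn.Module):  # hypothetical stand-in, as above
    def __init__(self, dim_in=1024, dim_out=768):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

with init_empty_weights():
    model = TinyProjection()  # parameters allocated on the meta device

# Stand-in for `updated_state_dict` in the diff above.
state_dict = {
    "proj.weight": torch.randn(768, 1024, dtype=torch.float16),
    "proj.bias": torch.zeros(768, dtype=torch.float16),
}

# Materialize each meta tensor directly on the target device/dtype in one
# step, instead of three (random init, load_state_dict, .to(device)).
for name, tensor in state_dict.items():
    set_module_tensor_to_device(model, name, device="cpu", value=tensor)
```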
@@ -776,9 +795,29 @@ class UNet2DConditionLoadersMixin:
             IPAdapterAttnProcessor2_0,
         )
 
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
         for name in self.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
             if name.startswith("mid_block"):
@@ -811,39 +850,49 @@ class UNet2DConditionLoadersMixin:
                     # IP-Adapter Plus
                     num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
 
-                attn_procs[name] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                    num_tokens=num_image_text_embeds,
-                ).to(dtype=self.dtype, device=self.device)
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                    )
 
                 value_dict = {}
                 for i, state_dict in enumerate(state_dicts):
                     value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
                     value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
 
-                attn_procs[name].load_state_dict(value_dict)
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
                 key_id += 2
 
         return attn_procs
 
-    def _load_ip_adapter_weights(self, state_dicts):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
 
-        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
         self.set_attn_processor(attn_procs)
 
         # convert IP-Adapter Image Projection layers to diffusers
         image_projection_layers = []
         for state_dict in state_dicts:
-            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
-            image_projection_layer.to(device=self.device, dtype=self.dtype)
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
             image_projection_layers.append(image_projection_layer)
 
         self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"
 
+        self.to(dtype=self.dtype, device=self.device)
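To observe the effect of the commit's two goals (meta init plus a single final device placement instead of per-module `.to()` calls), one could compare resident-memory growth with and without the flag. A rough sketch, assuming `psutil` is installed; absolute numbers will vary by machine and checkpoint:

```python
import os

import psutil
import torch
from diffusers import AutoPipelineForText2Image

def rss_mb():
    # Resident set size of the current process, in MiB.
    return psutil.Process(os.getpid()).memory_info().rss / 2**20

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

before = rss_mb()
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    low_cpu_mem_usage=True,  # rerun with False to compare
)
print(f"RSS grew by ~{rss_mb() - before:.0f} MiB during adapter loading")
```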