From d75ea3c7728a8726f8a478bc9bca624f675cb586 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 21 Feb 2025 12:16:30 +0000 Subject: [PATCH] `device_map` in `load_model_dict_into_meta` (#10851) * `device_map` in `load_model_dict_into_meta` * _LOW_CPU_MEM_USAGE_DEFAULT * fix is_peft_version is_bitsandbytes_version --- src/diffusers/loaders/transformer_flux.py | 15 ++++++++------- src/diffusers/loaders/transformer_sd3.py | 6 ++++-- src/diffusers/loaders/unet.py | 16 +++++++++------- src/diffusers/utils/import_utils.py | 4 ++-- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index 52a48e56e7..38a8a7ebe2 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -17,7 +17,7 @@ from ..models.embeddings import ( ImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_model_dict_into_meta +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( is_accelerate_available, is_torch_version, @@ -36,7 +36,7 @@ class FluxTransformer2DLoadersMixin: Load layers into a [`FluxTransformer2DModel`]. """ - def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False): + def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): if low_cpu_mem_usage: if is_accelerate_available(): from accelerate import init_empty_weights @@ -82,11 +82,12 @@ class FluxTransformer2DLoadersMixin: if not low_cpu_mem_usage: image_projection.load_state_dict(updated_state_dict, strict=True) else: - load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + device_map = {"": self.device} + load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) return image_projection - def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False): + def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): from ..models.attention_processor import ( FluxIPAdapterJointAttnProcessor2_0, ) @@ -151,15 +152,15 @@ class FluxTransformer2DLoadersMixin: if not low_cpu_mem_usage: attn_procs[name].load_state_dict(value_dict) else: - device = self.device + device_map = {"": self.device} dtype = self.dtype - load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype) key_id += 1 return attn_procs - def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): if not isinstance(state_dicts, list): state_dicts = [state_dicts] diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index c120589610..ece17e6728 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -75,8 +75,9 @@ class SD3Transformer2DLoadersMixin: if not low_cpu_mem_usage: attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) else: + device_map = {"": self.device} load_model_dict_into_meta( - attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype + attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype ) return attn_procs @@ -144,7 +145,8 @@ class SD3Transformer2DLoadersMixin: if not low_cpu_mem_usage: image_proj.load_state_dict(updated_state_dict, strict=True) else: - load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype) + device_map = {"": self.device} + load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype) return image_proj diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c68349c36d..1d8aba900c 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -30,7 +30,7 @@ from ..models.embeddings import ( IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -143,7 +143,7 @@ class UNet2DConditionLoadersMixin: adapter_name = kwargs.pop("adapter_name", None) _pipeline = kwargs.pop("_pipeline", None) network_alphas = kwargs.pop("network_alphas", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) allow_pickle = False if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): @@ -540,7 +540,7 @@ class UNet2DConditionLoadersMixin: return state_dict - def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False): + def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): if low_cpu_mem_usage: if is_accelerate_available(): from accelerate import init_empty_weights @@ -753,11 +753,12 @@ class UNet2DConditionLoadersMixin: if not low_cpu_mem_usage: image_projection.load_state_dict(updated_state_dict, strict=True) else: - load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + device_map = {"": self.device} + load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) return image_projection - def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False): + def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): from ..models.attention_processor import ( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, @@ -846,13 +847,14 @@ class UNet2DConditionLoadersMixin: else: device = next(iter(value_dict.values())).device dtype = next(iter(value_dict.values())).dtype - load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + device_map = {"": device} + load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype) key_id += 2 return attn_procs - def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): if not isinstance(state_dicts, list): state_dicts = [state_dicts] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 37535366ed..ae1b9cae6e 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -815,7 +815,7 @@ def is_peft_version(operation: str, version: str): version (`str`): A version string """ - if not _peft_version: + if not _peft_available: return False return compare_versions(parse(_peft_version), operation, version) @@ -829,7 +829,7 @@ def is_bitsandbytes_version(operation: str, version: str): version (`str`): A version string """ - if not _bitsandbytes_version: + if not _bitsandbytes_available: return False return compare_versions(parse(_bitsandbytes_version), operation, version)