
device_map in load_model_dict_into_meta (#10851)

* `device_map` in `load_model_dict_into_meta`

* _LOW_CPU_MEM_USAGE_DEFAULT

* fix is_peft_version is_bitsandbytes_version
Author: hlky
Date: 2025-02-21 12:16:30 +00:00
Committed by: GitHub
Parent: b27d4edbe1
Commit: d75ea3c772

4 changed files with 23 additions and 18 deletions
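Every hunk below makes the same mechanical substitution: the old `device=self.device` keyword argument to `load_model_dict_into_meta` becomes a one-entry `device_map = {"": self.device}`, where the empty-string key stands for the whole module. The other recurring change, replacing the hard-coded `False` default with `_LOW_CPU_MEM_USAGE_DEFAULT`, centralizes that default in a single constant imported from modeling_utils. A minimal sketch of what a single-entry device_map means, using a hypothetical `load_into` helper rather than the real `load_model_dict_into_meta`:

import torch

def load_into(module: torch.nn.Module, state_dict: dict, device_map: dict, dtype: torch.dtype):
    # Toy stand-in for load_model_dict_into_meta: with a single-entry map,
    # every tensor is cast to `dtype` and moved to the one target device.
    target = torch.device(device_map[""])  # "" means "the entire module"
    module.load_state_dict({k: v.to(device=target, dtype=dtype) for k, v in state_dict.items()})

proj = torch.nn.Linear(4, 4)
weights = {k: torch.randn_like(v) for k, v in proj.state_dict().items()}
load_into(proj, weights, device_map={"": "cpu"}, dtype=torch.float32)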

View File

@@ -17,7 +17,7 @@ from ..models.embeddings import (
     ImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     is_accelerate_available,
     is_torch_version,
@@ -36,7 +36,7 @@ class FluxTransformer2DLoadersMixin:
     Load layers into a [`FluxTransformer2DModel`].
     """

-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         if low_cpu_mem_usage:
             if is_accelerate_available():
                 from accelerate import init_empty_weights
@@ -82,11 +82,12 @@ class FluxTransformer2DLoadersMixin:
         if not low_cpu_mem_usage:
             image_projection.load_state_dict(updated_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+            device_map = {"": self.device}
+            load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)

         return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         from ..models.attention_processor import (
             FluxIPAdapterJointAttnProcessor2_0,
         )
@@ -151,15 +152,15 @@ class FluxTransformer2DLoadersMixin:
                 if not low_cpu_mem_usage:
                     attn_procs[name].load_state_dict(value_dict)
                 else:
-                    device = self.device
+                    device_map = {"": self.device}
                     dtype = self.dtype
-                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)

                 key_id += 1

         return attn_procs

-    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]

View File

@@ -75,8 +75,9 @@ class SD3Transformer2DLoadersMixin:
             if not low_cpu_mem_usage:
                 attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
             else:
+                device_map = {"": self.device}
                 load_model_dict_into_meta(
-                    attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
+                    attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
                 )

         return attn_procs
@@ -144,7 +145,8 @@
         if not low_cpu_mem_usage:
             image_proj.load_state_dict(updated_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
+            device_map = {"": self.device}
+            load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)

         return image_proj

View File

@@ -30,7 +30,7 @@ from ..models.embeddings import (
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -143,7 +143,7 @@ class UNet2DConditionLoadersMixin:
         adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
         network_alphas = kwargs.pop("network_alphas", None)
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         allow_pickle = False

         if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -540,7 +540,7 @@ class UNet2DConditionLoadersMixin:
         return state_dict

-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         if low_cpu_mem_usage:
             if is_accelerate_available():
                 from accelerate import init_empty_weights
@@ -753,11 +753,12 @@ class UNet2DConditionLoadersMixin:
         if not low_cpu_mem_usage:
             image_projection.load_state_dict(updated_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+            device_map = {"": self.device}
+            load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)

         return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         from ..models.attention_processor import (
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
@@ -846,13 +847,14 @@
                 else:
                     device = next(iter(value_dict.values())).device
                     dtype = next(iter(value_dict.values())).dtype
-                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+                    device_map = {"": device}
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)

                 key_id += 2

         return attn_procs

-    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
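One detail specific to the UNet2DConditionLoadersMixin hunks above: under low_cpu_mem_usage the target device and dtype are read off the incoming tensors themselves and only then wrapped into a device_map. A small standalone illustration of that idiom (toy tensors, not the real value_dict):

import torch

value_dict = {
    "to_k_ip.0.weight": torch.randn(8, 8, dtype=torch.float16),
    "to_v_ip.0.weight": torch.randn(8, 8, dtype=torch.float16),
}

# Same idiom as the hunk above: peek at the first tensor to pick the target
# device and dtype, then express the device as a single-entry device_map.
first = next(iter(value_dict.values()))
device_map = {"": first.device}
dtype = first.dtype
print(device_map, dtype)  # -> {'': device(type='cpu')} torch.float16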

View File

@@ -815,7 +815,7 @@ def is_peft_version(operation: str, version: str):
         version (`str`):
             A version string
     """
-    if not _peft_version:
+    if not _peft_available:
         return False
     return compare_versions(parse(_peft_version), operation, version)
@@ -829,7 +829,7 @@ def is_bitsandbytes_version(operation: str, version: str):
         version (`str`):
             A version string
     """
-    if not _bitsandbytes_version:
+    if not _bitsandbytes_available:
         return False
     return compare_versions(parse(_bitsandbytes_version), operation, version)
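The last two hunks are a small correctness fix rather than an API change: the guard should test the availability flag, not the version string, because the version variable only carries a meaningful value when the package was actually found. A compact sketch of that availability/version pattern, simplified from (not copied out of) the module that defines is_peft_version, assuming the packaging library is installed as the existing parse(...) call implies:

import importlib.metadata
import importlib.util

from packaging.version import parse

_peft_available = importlib.util.find_spec("peft") is not None
_peft_version = None
if _peft_available:
    try:
        _peft_version = importlib.metadata.version("peft")
    except importlib.metadata.PackageNotFoundError:
        _peft_available = False

def is_peft_version(operation: str, version: str) -> bool:
    # Guard on the availability flag first, so the version string is only
    # parsed when it is known to exist; compare_versions in diffusers does
    # the equivalent of the operator table below.
    if not _peft_available:
        return False
    ops = {"<": lambda a, b: a < b, "<=": lambda a, b: a <= b, "==": lambda a, b: a == b,
           "!=": lambda a, b: a != b, ">=": lambda a, b: a >= b, ">": lambda a, b: a > b}
    return ops[operation](parse(_peft_version), parse(version))

print(is_peft_version("<=", "0.13.0"))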