mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
device_map in load_model_dict_into_meta (#10851)
* `device_map` in `load_model_dict_into_meta` * _LOW_CPU_MEM_USAGE_DEFAULT * fix is_peft_version is_bitsandbytes_version
This commit is contained in:
@@ -17,7 +17,7 @@ from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
@@ -36,7 +36,7 @@ class FluxTransformer2DLoadersMixin:
|
||||
Load layers into a [`FluxTransformer2DModel`].
|
||||
"""
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
@@ -82,11 +82,12 @@ class FluxTransformer2DLoadersMixin:
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
from ..models.attention_processor import (
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
@@ -151,15 +152,15 @@ class FluxTransformer2DLoadersMixin:
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
else:
|
||||
device = self.device
|
||||
device_map = {"": self.device}
|
||||
dtype = self.dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
|
||||
|
||||
key_id += 1
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
|
||||
@@ -75,8 +75,9 @@ class SD3Transformer2DLoadersMixin:
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(
|
||||
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
|
||||
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
|
||||
)
|
||||
|
||||
return attn_procs
|
||||
@@ -144,7 +145,8 @@ class SD3Transformer2DLoadersMixin:
|
||||
if not low_cpu_mem_usage:
|
||||
image_proj.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_proj
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from ..models.embeddings import (
|
||||
IPAdapterPlusImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
@@ -143,7 +143,7 @@ class UNet2DConditionLoadersMixin:
|
||||
adapter_name = kwargs.pop("adapter_name", None)
|
||||
_pipeline = kwargs.pop("_pipeline", None)
|
||||
network_alphas = kwargs.pop("network_alphas", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
allow_pickle = False
|
||||
|
||||
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
|
||||
@@ -540,7 +540,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
return state_dict
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
@@ -753,11 +753,12 @@ class UNet2DConditionLoadersMixin:
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
from ..models.attention_processor import (
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
@@ -846,13 +847,14 @@ class UNet2DConditionLoadersMixin:
|
||||
else:
|
||||
device = next(iter(value_dict.values())).device
|
||||
dtype = next(iter(value_dict.values())).dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
|
||||
device_map = {"": device}
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
|
||||
|
||||
key_id += 2
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
|
||||
@@ -815,7 +815,7 @@ def is_peft_version(operation: str, version: str):
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _peft_version:
|
||||
if not _peft_available:
|
||||
return False
|
||||
return compare_versions(parse(_peft_version), operation, version)
|
||||
|
||||
@@ -829,7 +829,7 @@ def is_bitsandbytes_version(operation: str, version: str):
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _bitsandbytes_version:
|
||||
if not _bitsandbytes_available:
|
||||
return False
|
||||
return compare_versions(parse(_bitsandbytes_version), operation, version)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user