Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Remove device synchronization when loading weights (#11927)
* update

* make style
@@ -24,7 +24,7 @@ from typing_extensions import Self
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
@@ -431,10 +431,7 @@ class FromOriginalModelMixin:
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
             )
-            # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-            # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
             empty_device_cache()
-            device_synchronize()
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
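For context, the pattern being removed queued host-to-device copies with non_blocking=True and then issued a device-wide barrier before returning. A minimal sketch in plain PyTorch (not the diffusers internals) of why the barrier is redundant for GPU consumers of the weights:

import torch

def load_weights_non_blocking(module: torch.nn.Module, state_dict: dict) -> None:
    # Queue every host-to-device copy on the current CUDA stream without
    # blocking the Python thread; pinned source memory keeps the copies async.
    with torch.no_grad():
        for name, param in module.named_parameters():
            param.copy_(state_dict[name].pin_memory(), non_blocking=True)
    # The removed code called a device-wide synchronize here. Any kernel later
    # launched on the same stream is already ordered after these copies, so
    # GPU work cannot observe a partially copied tensor; the barrier only
    # delayed the return of control to the caller.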
@@ -46,7 +46,7 @@ from ..utils import (
 )
 from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache


 if is_transformers_available():
@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint, strict=False)
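`load_model_dict_into_meta` (a diffusers-internal helper) fills a model that was constructed on the meta device, placing each tensor straight onto the target device. A rough public-API equivalent using accelerate, for orientation only; `MyModel`, `config`, and `checkpoint` are hypothetical names, not part of this commit:

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    model = MyModel(config)  # hypothetical: parameters live on the meta device, no memory allocated

for name, tensor in checkpoint.items():  # hypothetical state dict loaded from disk
    # Replace each meta tensor with real data directly on the accelerator,
    # casting to the requested dtype on the way in.
    set_module_tensor_to_device(model, name, "cuda", value=tensor, dtype=torch.float16)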
@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint)
@@ -19,7 +19,7 @@ from ..models.embeddings import (
 )
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache


 if is_accelerate_available():
@@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
             empty_device_cache()
-            device_synchronize()

         return image_projection
@@ -158,7 +157,6 @@ class FluxTransformer2DLoadersMixin:
             key_id += 1

         empty_device_cache()
-        device_synchronize()

         return attn_procs
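One caveat worth keeping in mind when reading these hunks: non_blocking=True is only truly asynchronous when the host tensor lives in pinned (page-locked) memory; from ordinary pageable memory the copy degrades to a synchronous one. A small illustration of standard PyTorch behavior, independent of diffusers:

import torch

pageable = torch.randn(1024, 1024)           # ordinary pageable host memory
pinned = pageable.pin_memory()               # page-locked copy of the same data

a = pageable.to("cuda", non_blocking=True)   # flag has no effect: effectively synchronous
b = pinned.to("cuda", non_blocking=True)     # queued on the current stream, returns at once

# A single stream-level barrier drains everything still in flight, which is
# cheaper than synchronizing after every individual copy.
torch.cuda.current_stream().synchronize()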
@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache


 logger = logging.get_logger(__name__)
@@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin:
                 )

         empty_device_cache()
-        device_synchronize()

         return attn_procs
@@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin:
         device_map = {"": self.device}
         load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
         empty_device_cache()
-        device_synchronize()

         return image_proj
@@ -43,7 +43,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
@@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
             empty_device_cache()
-            device_synchronize()

         return image_projection
@@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin:
             key_id += 2

         empty_device_cache()
-        device_synchronize()

         return attn_procs
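Note that every hunk keeps `empty_device_cache()` and drops only the barrier. The helper's implementation isn't shown in this diff; presumably it dispatches a backend-appropriate cache flush, along the lines of this sketch (an assumption, not the actual torch_utils code):

import torch

def empty_device_cache_sketch() -> None:
    # Return cached allocator blocks to the driver so the freshly loaded
    # model starts from a tidy memory pool. Flushing the cache does not
    # require the device to be idle, so no synchronize is needed alongside it.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()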
@@ -62,7 +62,7 @@ from ..utils.hub_utils import (
     load_or_create_model_card,
     populate_model_card,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .model_loading_utils import (
     _caching_allocator_warmup,
     _determine_device_map,
@@ -1540,10 +1540,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
             error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)

-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()

         if offload_index is not None and len(offload_index) > 0:
             save_offload_index(offload_index, offload_folder)
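To see what the removed barrier actually cost, one can time a batch of non-blocking copies with and without a trailing synchronize. A hypothetical harness, not part of this commit; the "without" number is smaller precisely because control returns while copies are still draining:

import time
import torch

weights = [torch.randn(4096, 4096).pin_memory() for _ in range(32)]

def bench(sync_at_end: bool) -> float:
    torch.cuda.synchronize()            # drain prior work so runs don't overlap
    start = time.perf_counter()
    moved = [w.to("cuda", non_blocking=True) for w in weights]
    if sync_at_end:
        torch.cuda.synchronize()        # the barrier this commit removes
    return time.perf_counter() - start

print(f"with sync:    {bench(True):.4f}s")
print(f"without sync: {bench(False):.4f}s")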