1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Remove device synchronization when loading weights (#11927)

* update

* make style
This commit is contained in:
Aryan
2025-07-15 21:40:57 +05:30
committed by GitHub
parent 06fd427797
commit b73c738392
6 changed files with 6 additions and 24 deletions

View File

@@ -24,7 +24,7 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -431,10 +431,7 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

View File

@@ -46,7 +46,7 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
if is_transformers_available():
@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint)

View File

@@ -19,7 +19,7 @@ from ..models.embeddings import (
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
if is_accelerate_available():
@@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_projection
@@ -158,7 +157,6 @@ class FluxTransformer2DLoadersMixin:
key_id += 1
empty_device_cache()
device_synchronize()
return attn_procs

View File

@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
logger = logging.get_logger(__name__)
@@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin:
)
empty_device_cache()
device_synchronize()
return attn_procs
@@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_proj

View File

@@ -43,7 +43,7 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()
return image_projection
@@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin:
key_id += 2
empty_device_cache()
device_synchronize()
return attn_procs

View File

@@ -62,7 +62,7 @@ from ..utils.hub_utils import (
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
@@ -1540,10 +1540,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)