mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Core] fix offload behaviour when device_map is enabled. (#7919)
fix offload behaviour when device_map is enabled.
This commit is contained in:
@@ -363,7 +363,7 @@ class LoraLoaderMixin:
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None:
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
|
||||
@@ -419,19 +419,20 @@ class TextualInversionLoaderMixin:
|
||||
# 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
for _, component in self.components.items():
|
||||
if isinstance(component, nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
if self.hf_device_map is None:
|
||||
for _, component in self.components.items():
|
||||
if isinstance(component, nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
# 7.2 save expected device and dtype
|
||||
device = text_encoder.device
|
||||
|
||||
Reference in New Issue
Block a user