1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Core] fix offload behaviour when device_map is enabled. (#7919)

fix offload behaviour when device_map is enabled.
This commit is contained in:
Sayak Paul
2024-05-12 13:29:43 +02:00
committed by GitHub
parent ec9e88139a
commit 5bb38586a9
2 changed files with 15 additions and 14 deletions

View File

@@ -363,7 +363,7 @@ class LoraLoaderMixin:
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:

View File

@@ -419,19 +419,20 @@ class TextualInversionLoaderMixin:
# 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = (
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
if self.hf_device_map is None:
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = (
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# 7.2 save expected device and dtype
device = text_encoder.device