diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 97ea0d9e58..e089d202ee 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -363,7 +363,7 @@ class LoraLoaderMixin: is_model_cpu_offload = False is_sequential_cpu_offload = False - if _pipeline is not None: + if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py index a9b9a9aae0..b6e1545e16 100644 --- a/src/diffusers/loaders/textual_inversion.py +++ b/src/diffusers/loaders/textual_inversion.py @@ -419,19 +419,20 @@ class TextualInversionLoaderMixin: # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again is_model_cpu_offload = False is_sequential_cpu_offload = False - for _, component in self.components.items(): - if isinstance(component, nn.Module): - if hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = ( - isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - logger.info( - "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." - ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + if self.hf_device_map is None: + for _, component in self.components.items(): + if isinstance(component, nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = ( + isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + logger.info( + "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) # 7.2 save expected device and dtype device = text_encoder.device