1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] fix: disabling hooks when loading loras. (#11896)

fix: disabling hooks when loading loras.
This commit is contained in:
Sayak Paul
2025-07-10 10:30:10 +05:30
committed by GitHub
parent 9f4d997d8f
commit 265840a098
2 changed files with 32 additions and 1 deletions

View File

@@ -470,7 +470,7 @@ def _func_optionally_disable_offloading(_pipeline):
for _, component in _pipeline.components.items():
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
continue
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)

View File

@@ -2510,3 +2510,34 @@ class PeftLoraLoaderMixinTests:
# materializes the test methods on invocation which cannot be overridden.
return
self._test_group_offloading_inference_denoiser(offload_type, use_stream)
@require_torch_accelerator
def test_lora_loading_model_cpu_offload(self):
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
# reinitialize the pipeline to mimic the inference workflow.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe.enable_model_cpu_offload(device=torch_device)
pipe.load_lora_weights(tmpdirname)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))