Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[LoRA] fix: disabling hooks when loading loras. (#11896)
fix: disabling hooks when loading loras.
@@ -470,7 +470,7 @@ def _func_optionally_disable_offloading(_pipeline):
         for _, component in _pipeline.components.items():
             if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
                 continue
-            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
     return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
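For context on the hunk above: `_func_optionally_disable_offloading` walks the pipeline's components, records which accelerate offloading variant is active, and strips the hooks so that LoRA weights load into bare `nn.Module`s (the loader re-applies the offloading afterwards). Below is a minimal sketch of that detection-and-removal pattern using accelerate's public hook APIs (`CpuOffload`, `AlignDevicesHook`, `remove_hook_from_module`); the helper name and the simplified return value are illustrative, not the exact diffusers code.

import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

def detect_and_remove_offload_hooks(pipeline):
    # Illustrative helper, not the diffusers implementation.
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for _, component in pipeline.components.items():
        if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
            continue
        # enable_model_cpu_offload() attaches a CpuOffload hook per component;
        # enable_sequential_cpu_offload() attaches AlignDevicesHook hooks.
        is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
        is_sequential_cpu_offload = is_sequential_cpu_offload or isinstance(
            component._hf_hook, AlignDevicesHook
        )
        # Sequential offloading hooks submodules too, so removal must recurse.
        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return is_model_cpu_offload, is_sequential_cpu_offload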
@@ -2510,3 +2510,34 @@ class PeftLoraLoaderMixinTests:
             # materializes the test methods on invocation which cannot be overridden.
             return
         self._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
+    @require_torch_accelerator
+    def test_lora_loading_model_cpu_offload(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+            # reinitialize the pipeline to mimic the inference workflow.
+            components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe.enable_model_cpu_offload(device=torch_device)
+            pipe.load_lora_weights(tmpdirname)
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
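The user-facing workflow this new test locks down: enable model CPU offload, then load LoRA weights, and expect the same output as a non-offloaded run. A sketch of that workflow with public diffusers APIs; the checkpoint id, the LoRA path "path/to/lora", and the prompt are placeholders.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()         # installs accelerate CpuOffload hooks
pipe.load_lora_weights("path/to/lora")  # must first remove those hooks, then re-apply them
image = pipe("an astronaut riding a horse", generator=torch.manual_seed(0)).images[0]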