
Fix: leaf_level offloading breaks after delete_adapters (#12639)

* Fix(peft): Re-apply group offloading after deleting adapters

* Test: Add regression test for group offloading + delete_adapters

* Test: Add assertions to verify output changes after deletion

* Test: Add try/finally to clean up group offloading hooks

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author:    Aditya Borate
Date:      2025-12-03 17:39:11 +05:30
Committer: GitHub
Parent:    d0c54e5563
Commit:    5ab5946931
2 changed files with 52 additions and 0 deletions
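For context, a minimal sketch of the failure scenario this commit addresses. The model id, prompt, and LoRA repo below are illustrative placeholders, not taken from the commit; only apply_group_offloading, load_lora_weights, and delete_adapters come from the diff itself:

    import torch
    from diffusers import DiffusionPipeline
    from diffusers.hooks.group_offloading import apply_group_offloading

    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16  # illustrative model
    )
    apply_group_offloading(
        pipe.transformer,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
    )
    pipe.load_lora_weights("user/some-lora", adapter_name="default")  # hypothetical repo
    image = pipe("a prompt", num_inference_steps=2).images[0]

    # Before this commit, deleting the adapter left the leaf-level offloading
    # hooks attached to the now-replaced LoRA submodules, breaking the next call.
    pipe.delete_adapters("default")
    image = pipe("a prompt", num_inference_steps=2).images[0]  # works after the fix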

src/diffusers/loaders/peft.py

@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
 
 import safetensors
 import torch
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
@@ -794,6 +795,8 @@ class PeftAdapterMixin:
         if hasattr(self, "peft_config"):
             self.peft_config.pop(adapter_name, None)
 
+        _maybe_remove_and_reapply_group_offloading(self)
+
     def enable_lora_hotswap(
         self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
     ) -> None:
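Why re-apply here: delete_adapters swaps the LoRA-wrapped submodules back to their base counterparts, so leaf-level hooks registered on the old submodules go stale. A minimal sketch of the remove-and-reapply pattern the helper's name describes; the two inner helpers are hypothetical names for illustration, not diffusers API:

    def remove_and_reapply_group_offloading(module):
        # Hypothetical accessor: recover the offloading config the module was
        # originally set up with; None means it was never group-offloaded.
        config = get_group_offloading_config(module)
        if config is None:
            return
        # Hypothetical teardown: detach hooks from the (possibly stale) tree.
        remove_group_offloading_hooks(module)
        # Re-register hooks on the current module tree with the same config.
        apply_group_offloading(module, **config)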

tests/lora/utils.py

@@ -28,6 +28,7 @@ from diffusers import (
     AutoencoderKL,
     UNet2DConditionModel,
 )
+from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
@@ -2367,3 +2368,51 @@ class PeftLoraLoaderMixinTests:
             output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+    @require_torch_accelerator
+    def test_lora_group_offloading_delete_adapters(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        try:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+                )
+
+                components, _, _ = self.get_dummy_components()
+                pipe = self.pipeline_class(**components)
+                denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+                pipe.to(torch_device)
+
+                # Enable Group Offloading (leaf_level for more granular testing)
+                apply_group_offloading(
+                    denoiser,
+                    onload_device=torch_device,
+                    offload_device="cpu",
+                    offload_type="leaf_level",
+                )
+
+                pipe.load_lora_weights(tmpdirname, adapter_name="default")
+                out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+                # Delete the adapter
+                pipe.delete_adapters("default")
+                out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3))
+        finally:
+            # Clean up the hooks to prevent state leak
+            if hasattr(denoiser, "_diffusers_hook"):
+                denoiser._diffusers_hook.remove_hook(_GROUP_OFFLOADING, recurse=True)
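A note on the try/finally: PeftLoraLoaderMixinTests is a shared mixin run against many pipeline classes, so a group-offloading hook left on the denoiser would leak into later tests. The finally block removes the _GROUP_OFFLOADING hook through the hook registry the hooks system installs on the module (the private _diffusers_hook attribute), regardless of whether the assertions pass.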