1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Support dynamically loading/unloading loras with group offloading (#11804)

* update

* add test

* address review comments

* update

* fixes

* change decorator order to fix tests

* try fix

* fight tests
This commit is contained in:
Aryan
2025-06-27 23:20:53 +05:30
committed by GitHub
parent cdaf84a708
commit 76ec3d1fee
7 changed files with 289 additions and 175 deletions

View File

@@ -14,6 +14,8 @@
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
@@ -46,6 +48,24 @@ _SUPPORTED_PYTORCH_LAYERS = (
# fmt: on
class GroupOffloadingType(str, Enum):
BLOCK_LEVEL = "block_level"
LEAF_LEVEL = "leaf_level"
@dataclass
class GroupOffloadingConfig:
onload_device: torch.device
offload_device: torch.device
offload_type: GroupOffloadingType
non_blocking: bool
record_stream: bool
low_cpu_mem_usage: bool
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
class ModuleGroup:
def __init__(
self,
@@ -288,9 +308,12 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
def __init__(
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
) -> None:
self.group = group
self.next_group = next_group
self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
@@ -436,7 +459,7 @@ def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -478,7 +501,7 @@ def apply_group_offloading(
The device to which the group of modules are onloaded.
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
offload_type (`str`, defaults to "block_level"):
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -521,6 +544,8 @@ def apply_group_offloading(
```
"""
offload_type = GroupOffloadingType(offload_type)
stream = None
if use_stream:
if torch.cuda.is_available():
@@ -532,84 +557,45 @@ def apply_group_offloading(
if not use_stream and record_stream:
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
config = GroupOffloadingConfig(
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
)
_apply_group_offloading(module, config)
_apply_group_offloading_block_level(
module=module,
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
module=module,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
_apply_group_offloading_block_level(module, config)
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
_apply_group_offloading_leaf_level(module, config)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
assert False
def _apply_group_offloading_block_level(
module: torch.nn.Module,
num_blocks_per_group: int,
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
RAM environment settings where a reasonable speed-memory trade-off is desired.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
if stream is not None and num_blocks_per_group != 1:
if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
num_blocks_per_group = 1
config.num_blocks_per_group = 1
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -621,19 +607,19 @@ def _apply_group_offloading_block_level(
modules_with_group_offloading.add(name)
continue
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group = ModuleGroup(
modules=current_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
)
matched_module_groups.append(group)
@@ -643,7 +629,7 @@ def _apply_group_offloading_block_level(
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, None)
_apply_group_offloading_hook(group_module, group, None, config=config)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -658,9 +644,9 @@ def _apply_group_offloading_block_level(
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
@@ -670,54 +656,19 @@ def _apply_group_offloading_block_level(
record_stream=False,
onload_self=True,
)
if stream is None:
_apply_group_offloading_hook(module, unmatched_group, None)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
def _apply_group_offloading_leaf_level(
module: torch.nn.Module,
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
reduce memory usage without any performance degradation.
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
RAM environment settings where a reasonable speed-memory trade-off is desired.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
@@ -725,18 +676,18 @@ def _apply_group_offloading_leaf_level(
continue
group = ModuleGroup(
modules=[submodule],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
_apply_group_offloading_hook(submodule, group, None, config=config)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -767,33 +718,32 @@ def _apply_group_offloading_leaf_level(
parameters = parent_to_parameters.get(name, [])
buffers = parent_to_buffers.get(name, [])
parent_module = module_dict[name]
assert getattr(parent_module, "_diffusers_hook", None) is None
group = ModuleGroup(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=offload_to_disk_path,
offload_to_disk_path=config.offload_to_disk_path,
parameters=parameters,
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
_apply_group_offloading_hook(parent_module, group, None, config=config)
if stream is not None:
if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
# execution order and apply prefetching in the correct order.
unmatched_group = ModuleGroup(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
@@ -801,23 +751,25 @@ def _apply_group_offloading_leaf_level(
non_blocking=False,
stream=None,
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group)
hook = GroupOffloadingHook(group, next_group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
@@ -825,13 +777,15 @@ def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group)
hook = GroupOffloadingHook(group, next_group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -898,15 +852,48 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
)
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return True
return False
if hasattr(submodule, "_diffusers_hook"):
group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None:
return group_offloading_hook
return None
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
return top_level_group_offload_hook is not None
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
if top_level_group_offload_hook is not None:
return top_level_group_offload_hook.config.onload_device
raise ValueError("Group offloading is not enabled for the provided module.")
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
r"""
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
case where user has applied group offloading at multiple levels, this function will not work as expected.
There is some performance penalty associated with doing this when non-default streams are used, because we need to
retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
"""
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
if top_level_group_offload_hook is None:
return
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
_apply_group_offloading(module, top_level_group_offload_hook.config)

View File

@@ -25,6 +25,7 @@ import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
adapter_name = get_adapter_name(text_encoder)
# <Unsafe code
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
_pipeline
)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
if not isinstance(component, nn.Module):
continue
is_group_offload = is_group_offload or _is_group_offload_enabled(component)
if not hasattr(component, "_hf_hook"):
continue
is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
is_sequential_cpu_offload = is_sequential_cpu_offload or (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
if is_sequential_cpu_offload or is_model_cpu_offload:
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
if is_sequential_cpu_offload or is_model_cpu_offload:
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
for _, component in _pipeline.components.items():
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
continue
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
class LoraBaseMixin:

View File

@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
@@ -256,7 +257,9 @@ class PeftAdapterMixin:
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error.
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline
)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +350,10 @@ class PeftAdapterMixin:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
if prefix is not None and not state_dict:
@@ -687,6 +694,8 @@ class PeftAdapterMixin:
if hasattr(self, "peft_config"):
del self.peft_config
_maybe_remove_and_reapply_group_offloading(self)
def disable_lora(self):
"""
Disables the active LoRA layers of the underlying model.

View File

@@ -22,6 +22,7 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ class UNet2DConditionLoadersMixin:
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
if is_lora:
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ class UNet2DConditionLoadersMixin:
if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
state_dict=state_dict,
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
@@ -230,7 +232,9 @@ class UNet2DConditionLoadersMixin:
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
if is_custom_diffusion and _pipeline is not None:
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline=_pipeline
)
# only custom diffusion needs to set attn processors
self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ class UNet2DConditionLoadersMixin:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ class UNet2DConditionLoadersMixin:
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ class UNet2DConditionLoadersMixin:
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline
)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ class UNet2DConditionLoadersMixin:
if warn_msg:
logger.warning(warn_msg)
return is_model_cpu_offload, is_sequential_cpu_offload
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading

View File

@@ -16,6 +16,7 @@ import sys
import unittest
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
@@ -28,6 +29,7 @@ from diffusers import (
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
require_torch_accelerator,
)
@@ -127,6 +129,13 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

View File

@@ -18,10 +18,17 @@ import unittest
import numpy as np
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, GlmModel
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
require_torch_accelerator,
skip_mps,
torch_device,
)
sys.path.append(".")
@@ -141,6 +148,13 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"Loading from saved checkpoints should give same results.",
)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
@unittest.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

View File

@@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import (
is_torch_version,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_transformers_version_greater,
skip_mps,
torch_device,
@@ -2355,3 +2356,73 @@ class PeftLoraLoaderMixinTests:
pipe.load_lora_weights(tmpdirname)
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
onload_device = torch_device
offload_device = torch.device("cpu")
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
check_if_lora_correctly_set(denoiser)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
# Test group offloading with load_lora_weights
denoiser.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=1,
use_stream=use_stream,
)
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_1 is not None)
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
# Test group offloading after removing the lora
pipe.unload_lora_weights()
group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_2 is not None)
output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841
# Add the lora again and check if group offloading works
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
check_if_lora_correctly_set(denoiser)
group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_3 is not None)
output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
@parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
for cls in inspect.getmro(self.__class__):
if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
# Skip this test if it is overwritten by child class. We need to do this because parameterized
# materializes the test methods on invocation which cannot be overridden.
return
self._test_group_offloading_inference_denoiser(offload_type, use_stream)