Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

Support dynamically loading/unloading loras with group offloading (#11804)

* update
* add test
* address review comments
* update
* fixes
* change decorator order to fix tests
* try fix
* fight tests
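Before the diff itself, a minimal usage sketch of what this change enables: LoRA adapters can now be loaded, unloaded, and reloaded on a model that already has group offloading applied, without manually tearing the offloading down first. The repository IDs, the adapter name, and the transformer-based pipeline are placeholder assumptions, not part of this commit.

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint/LoRA IDs; substitute a pipeline and adapter you actually use.
pipe = DiffusionPipeline.from_pretrained("some/pipeline-checkpoint", torch_dtype=torch.bfloat16)

# Group-offload the denoiser: weights stay on the CPU and are moved to the GPU block by block.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

# With this PR, LoRAs can be swapped while group offloading stays active; the offloading
# hooks are removed and re-applied around the in-place weight changes.
pipe.load_lora_weights("some/lora-repo", adapter_name="style")
image = pipe("a prompt", num_inference_steps=20).images[0]

pipe.unload_lora_weights()  # hooks are refreshed after the LoRA layers are removed
pipe.load_lora_weights("some/lora-repo", adapter_name="style")  # and again after reloading
```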
@@ -14,6 +14,8 @@
 
 import os
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple, Union
 
 import safetensors.torch
@@ -46,6 +48,24 @@ _SUPPORTED_PYTORCH_LAYERS = (
 # fmt: on
 
 
+class GroupOffloadingType(str, Enum):
+    BLOCK_LEVEL = "block_level"
+    LEAF_LEVEL = "leaf_level"
+
+
+@dataclass
+class GroupOffloadingConfig:
+    onload_device: torch.device
+    offload_device: torch.device
+    offload_type: GroupOffloadingType
+    non_blocking: bool
+    record_stream: bool
+    low_cpu_mem_usage: bool
+    num_blocks_per_group: Optional[int] = None
+    offload_to_disk_path: Optional[str] = None
+    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+
+
 class ModuleGroup:
     def __init__(
         self,
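For orientation: the new dataclass simply bundles the options that were previously threaded through the block-level and leaf-level helpers as individual keyword arguments, and it is what the hooks hold on to afterwards. A rough sketch of constructing it directly (these are internal helpers; the sketch only illustrates the refactor and is not the public entry point):

```python
import torch
from diffusers.hooks.group_offloading import GroupOffloadingConfig, GroupOffloadingType

config = GroupOffloadingConfig(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type=GroupOffloadingType.BLOCK_LEVEL,
    non_blocking=False,
    record_stream=False,
    low_cpu_mem_usage=False,
    num_blocks_per_group=1,
)
# Hooks created from this config keep a reference to it, which is what later lets
# _maybe_remove_and_reapply_group_offloading() rebuild exactly the same setup.
```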
@@ -288,9 +308,12 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
+    def __init__(
+        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
+    ) -> None:
         self.group = group
         self.next_group = next_group
+        self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -436,7 +459,7 @@ def apply_group_offloading(
     module: torch.nn.Module,
     onload_device: torch.device,
    offload_device: torch.device = torch.device("cpu"),
-    offload_type: str = "block_level",
+    offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -478,7 +501,7 @@ def apply_group_offloading(
             The device to which the group of modules are onloaded.
         offload_device (`torch.device`, defaults to `torch.device("cpu")`):
             The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
-        offload_type (`str`, defaults to "block_level"):
+        offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
             The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
             "block_level".
         offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -521,6 +544,8 @@ def apply_group_offloading(
     ```
     """
 
+    offload_type = GroupOffloadingType(offload_type)
+
     stream = None
     if use_stream:
         if torch.cuda.is_available():
@@ -532,84 +557,45 @@ def apply_group_offloading(
 
     if not use_stream and record_stream:
         raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+    if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
+        raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
-    if offload_type == "block_level":
-        if num_blocks_per_group is None:
-            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
-
-        _apply_group_offloading_block_level(
-            module=module,
-            num_blocks_per_group=num_blocks_per_group,
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
-    elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(
-            module=module,
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
-    else:
-        raise ValueError(f"Unsupported offload_type: {offload_type}")
+    config = GroupOffloadingConfig(
+        onload_device=onload_device,
+        offload_device=offload_device,
+        offload_type=offload_type,
+        num_blocks_per_group=num_blocks_per_group,
+        non_blocking=non_blocking,
+        stream=stream,
+        record_stream=record_stream,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+        offload_to_disk_path=offload_to_disk_path,
+    )
+    _apply_group_offloading(module, config)
 
 
-def _apply_group_offloading_block_level(
-    module: torch.nn.Module,
-    num_blocks_per_group: int,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-    offload_to_disk_path: Optional[str] = None,
-) -> None:
+def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
+    if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
+        _apply_group_offloading_block_level(module, config)
+    elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
+        _apply_group_offloading_leaf_level(module, config)
+    else:
+        assert False
+
+
+def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
     the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        offload_to_disk_path (`str`, *optional*, defaults to `None`):
-            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
-            RAM environment settings where a reasonable speed-memory trade-off is desired.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
+
-    if stream is not None and num_blocks_per_group != 1:
+    if config.stream is not None and config.num_blocks_per_group != 1:
         logger.warning(
-            f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
+            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
         )
-        num_blocks_per_group = 1
+        config.num_blocks_per_group = 1
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
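A small sketch of the public-facing effect of the hunk above: `offload_type` may now be passed either as a string or as the new `GroupOffloadingType` enum, since `apply_group_offloading` normalizes it with `GroupOffloadingType(offload_type)` before building the config. The toy two-layer module is an assumption for illustration; any `torch.nn.Module` (plus a CUDA device to onload to) would do.

```python
import torch
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.group_offloading import GroupOffloadingType

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

# Equivalent to offload_type="leaf_level"; the string is coerced to the enum internally.
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type=GroupOffloadingType.LEAF_LEVEL,
)
```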
@@ -621,19 +607,19 @@ def _apply_group_offloading_block_level(
             modules_with_group_offloading.add(name)
             continue
 
-        for i in range(0, len(submodule), num_blocks_per_group):
-            current_modules = submodule[i : i + num_blocks_per_group]
+        for i in range(0, len(submodule), config.num_blocks_per_group):
+            current_modules = submodule[i : i + config.num_blocks_per_group]
             group = ModuleGroup(
                 modules=current_modules,
-                offload_device=offload_device,
-                onload_device=onload_device,
-                offload_to_disk_path=offload_to_disk_path,
+                offload_device=config.offload_device,
+                onload_device=config.onload_device,
+                offload_to_disk_path=config.offload_to_disk_path,
                 offload_leader=current_modules[-1],
                 onload_leader=current_modules[0],
-                non_blocking=non_blocking,
-                stream=stream,
-                record_stream=record_stream,
-                low_cpu_mem_usage=low_cpu_mem_usage,
+                non_blocking=config.non_blocking,
+                stream=config.stream,
+                record_stream=config.record_stream,
+                low_cpu_mem_usage=config.low_cpu_mem_usage,
                 onload_self=True,
             )
             matched_module_groups.append(group)
@@ -643,7 +629,7 @@ def _apply_group_offloading_block_level(
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None)
+            _apply_group_offloading_hook(group_module, group, None, config=config)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -658,9 +644,9 @@ def _apply_group_offloading_block_level(
     unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
     unmatched_group = ModuleGroup(
         modules=unmatched_modules,
-        offload_device=offload_device,
-        onload_device=onload_device,
-        offload_to_disk_path=offload_to_disk_path,
+        offload_device=config.offload_device,
+        onload_device=config.onload_device,
+        offload_to_disk_path=config.offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -670,54 +656,19 @@ def _apply_group_offloading_block_level(
         record_stream=False,
         onload_self=True,
     )
-    if stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None)
+    if config.stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
 
 
-def _apply_group_offloading_leaf_level(
-    module: torch.nn.Module,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-    offload_to_disk_path: Optional[str] = None,
-) -> None:
+def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
     requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
     synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
     reduce memory usage without any performance degradation.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        offload_to_disk_path (`str`, *optional*, defaults to `None`):
-            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
-            RAM environment settings where a reasonable speed-memory trade-off is desired.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -725,18 +676,18 @@ def _apply_group_offloading_leaf_level(
            continue
        group = ModuleGroup(
            modules=[submodule],
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=submodule,
            onload_leader=submodule,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
        )
-        _apply_group_offloading_hook(submodule, group, None)
+        _apply_group_offloading_hook(submodule, group, None, config=config)
        modules_with_group_offloading.add(name)
 
    # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -767,33 +718,32 @@ def _apply_group_offloading_leaf_level(
        parameters = parent_to_parameters.get(name, [])
        buffers = parent_to_buffers.get(name, [])
        parent_module = module_dict[name]
        assert getattr(parent_module, "_diffusers_hook", None) is None
        group = ModuleGroup(
            modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_to_disk_path=config.offload_to_disk_path,
            parameters=parameters,
            buffers=buffers,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
        )
-        _apply_group_offloading_hook(parent_module, group, None)
+        _apply_group_offloading_hook(parent_module, group, None, config=config)
 
-    if stream is not None:
+    if config.stream is not None:
        # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
        # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
        # execution order and apply prefetching in the correct order.
        unmatched_group = ModuleGroup(
            modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
@@ -801,23 +751,25 @@ def _apply_group_offloading_leaf_level(
            non_blocking=False,
            stream=None,
            record_stream=False,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
        )
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
 
 
 def _apply_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(module)
 
    # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, next_group, config=config)
        registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
@@ -825,13 +777,15 @@ def _apply_lazy_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(module)
 
    # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, next_group, config=config)
        registry.register_hook(hook, _GROUP_OFFLOADING)
 
    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -898,15 +852,48 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
        )
 
 
-def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
    for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
-            return True
-    return False
+        if hasattr(submodule, "_diffusers_hook"):
+            group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
+            if group_offloading_hook is not None:
+                return group_offloading_hook
+    return None
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    return top_level_group_offload_hook is not None
 
 
 def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
-    for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
-            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    if top_level_group_offload_hook is not None:
+        return top_level_group_offload_hook.config.onload_device
    raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
+    r"""
+    Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+    modified in-place and the group offloading hook's references to tensors need to be updated. The in-place
+    modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
+
+    In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
+    and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
+    case where the user has applied group offloading at multiple levels, this function will not work as expected.
+
+    There is some performance penalty associated with doing this when non-default streams are used, because we need to
+    retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
+    """
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+
+    if top_level_group_offload_hook is None:
+        return
+
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
+    registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
+    registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
+
+    _apply_group_offloading(module, top_level_group_offload_hook.config)
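The helper added above is what the LoRA/adapter loaders in this PR call after mutating a group-offloaded module in place. A condensed, self-contained sketch of the pattern (private helpers, shown only to illustrate when the re-application happens; the toy module is an assumption):

```python
import torch
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.group_offloading import _maybe_remove_and_reapply_group_offloading

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
apply_group_offloading(model, onload_device=torch.device("cuda"), offload_type="leaf_level")

# ... modify the module in place here, e.g. inject or strip LoRA layers ...

# The existing hooks now reference stale tensors, so they are removed and rebuilt from the
# config stored on the top-level GroupOffloadingHook. If group offloading was never enabled
# on `model`, the call simply returns without doing anything.
_maybe_remove_and_reapply_group_offloading(model)
```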
@@ -25,6 +25,7 @@ import torch.nn as nn
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
 
+from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
    USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
        adapter_name = get_adapter_name(text_encoder)
 
    # <Unsafe code
-    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+        _pipeline
+    )
    # inject LoRA layers and load the state dict
    # in transformers we automatically check whether the adapter name is already in use or not
    text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
    # Unsafe code />
 
    if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):
 
    Returns:
        tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
    """
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
+    is_group_offload = False
 
    if _pipeline is not None and _pipeline.hf_device_map is None:
        for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
-                        isinstance(component._hf_hook, AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
+                isinstance(component._hf_hook, AlignDevicesHook)
+                or hasattr(component._hf_hook, "hooks")
+                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+            )
 
-                logger.info(
-                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                )
-                if is_sequential_cpu_offload or is_model_cpu_offload:
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+        if is_sequential_cpu_offload or is_model_cpu_offload:
+            logger.info(
+                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+            )
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
 
 class LoraBaseMixin:
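Call sites now unpack three flags instead of two, and the new `is_group_offload` branch re-applies the offloading hooks rather than re-enabling an accelerate scheme. A hedged sketch of the consuming pattern (the `reload` callback and the pipeline object are hypothetical stand-ins, not part of this change):

```python
import torch
from diffusers.hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from diffusers.loaders.lora_base import _func_optionally_disable_offloading


def load_with_offload_handling(pipeline, reload):
    # Detect (and, for the accelerate schemes, temporarily remove) any offloading on the pipeline.
    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(pipeline)
    reload(pipeline)  # e.g. inject LoRA layers while the accelerate hooks are detached
    if is_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    elif is_group_offload:
        # Group offloading hooks were only detected, not removed; refresh them so they
        # track the newly loaded weights.
        for component in pipeline.components.values():
            if isinstance(component, torch.nn.Module):
                _maybe_remove_and_reapply_group_offloading(component)
```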
@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
 import safetensors
 import torch
 
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..utils import (
    MIN_PEFT_VERSION,
    USE_PEFT_BACKEND,
@@ -256,7 +257,9 @@ class PeftAdapterMixin:
 
        # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
        # otherwise loading LoRA weights will lead to an error.
-        is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+            _pipeline
+        )
        peft_kwargs = {}
        if is_peft_version(">=", "0.13.1"):
            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +350,10 @@ class PeftAdapterMixin:
                _pipeline.enable_model_cpu_offload()
            elif is_sequential_cpu_offload:
                _pipeline.enable_sequential_cpu_offload()
+            elif is_group_offload:
+                for component in _pipeline.components.values():
+                    if isinstance(component, torch.nn.Module):
+                        _maybe_remove_and_reapply_group_offloading(component)
        # Unsafe code />
 
        if prefix is not None and not state_dict:
@@ -687,6 +694,8 @@ class PeftAdapterMixin:
        if hasattr(self, "peft_config"):
            del self.peft_config
 
+        _maybe_remove_and_reapply_group_offloading(self)
+
    def disable_lora(self):
        """
        Disables the active LoRA layers of the underlying model.
@@ -22,6 +22,7 @@ import torch
 import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args
 
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..models.embeddings import (
    ImageProjection,
    IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ class UNet2DConditionLoadersMixin:
        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
+        is_group_offload = False
 
        if is_lora:
            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ class UNet2DConditionLoadersMixin:
        if is_custom_diffusion:
            attn_processors = self._process_custom_diffusion(state_dict=state_dict)
        elif is_lora:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
                state_dict=state_dict,
                unet_identifier_key=self.unet_name,
                network_alphas=network_alphas,
@@ -230,7 +232,9 @@ class UNet2DConditionLoadersMixin:
 
        # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
        if is_custom_diffusion and _pipeline is not None:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline=_pipeline
+            )
 
            # only custom diffusion needs to set attn processors
            self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ class UNet2DConditionLoadersMixin:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
        # Unsafe code />
 
    def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ class UNet2DConditionLoadersMixin:
 
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
+        is_group_offload = False
        state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
 
        if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ class UNet2DConditionLoadersMixin:
 
            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline
+            )
            peft_kwargs = {}
            if is_peft_version(">=", "0.13.1"):
                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ class UNet2DConditionLoadersMixin:
            if warn_msg:
                logger.warning(warn_msg)
 
-        return is_model_cpu_offload, is_sequential_cpu_offload
+        return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
 
    @classmethod
    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
@@ -16,6 +16,7 @@ import sys
 import unittest
 
 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import (
@@ -28,6 +29,7 @@ from diffusers import (
 from diffusers.utils.testing_utils import (
    floats_tensor,
    require_peft_backend,
+    require_torch_accelerator,
 )
 
 
@@ -127,6 +129,13 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_lora_scale_kwargs_match_fusion(self):
        super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
 
+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
    @unittest.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass
@@ -18,10 +18,17 @@ import unittest
 
 import numpy as np
 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, GlmModel
 
 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_peft_backend,
+    require_torch_accelerator,
+    skip_mps,
+    torch_device,
+)
 
 
 sys.path.append(".")
@@ -141,6 +148,13 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
            "Loading from saved checkpoints should give same results.",
        )
 
+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
    @unittest.skip("Not supported in CogView4.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass
@@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import (
    is_torch_version,
    require_peft_backend,
    require_peft_version_greater,
+    require_torch_accelerator,
    require_transformers_version_greater,
    skip_mps,
    torch_device,
@@ -2355,3 +2356,73 @@ class PeftLoraLoaderMixinTests:
            pipe.load_lora_weights(tmpdirname)
            output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
+
+        onload_device = torch_device
+        offload_device = torch.device("cpu")
+
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            check_if_lora_correctly_set(denoiser)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            # Test group offloading with load_lora_weights
+            denoiser.enable_group_offload(
+                onload_device=onload_device,
+                offload_device=offload_device,
+                offload_type=offload_type,
+                num_blocks_per_group=1,
+                use_stream=use_stream,
+            )
+            group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_1 is not None)
+            output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            # Test group offloading after removing the lora
+            pipe.unload_lora_weights()
+            group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_2 is not None)
+            output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841
+
+            # Add the lora again and check if group offloading works
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            check_if_lora_correctly_set(denoiser)
+            group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_3 is not None)
+            output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
+
+    @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
+                # Skip this test if it is overwritten by child class. We need to do this because parameterized
+                # materializes the test methods on invocation which cannot be overridden.
+                return
+        self._test_group_offloading_inference_denoiser(offload_type, use_stream)