
Fix broken group offloading with block_level for models with standalone layers (#12692)

* fix: group offloading to support standalone computational layers in block-level offloading

* test: for models with standalone and deeply nested layers in block-level offloading

* feat: support for block-level offloading in group offloading config

* fix: group offload block modules to AutoencoderKL and AutoencoderKLWan

* fix: update group offloading tests to use AutoencoderKL and adjust input dimensions

* refactor: streamline block offloading logic

* Apply style fixes

* update tests

* update

* fix for failing tests

* clean up

* revert to use skip_keys

* clean up

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Author: swappy
Date: 2025-12-05 18:54:05 +05:30
Committed by: GitHub
Parent: 8d415a6f48
Commit: f12d161d67
7 changed files with 355 additions and 69 deletions
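For context, the usage this commit fixes looks roughly like the following minimal sketch, mirroring the new tests (it assumes a CUDA device is available; `AutoencoderKL` now declares `_group_offload_block_modules`, so block-level offloading also covers its standalone layers):

```python
import torch
from diffusers import AutoencoderKL

# Small config taken from the new tests; any AutoencoderKL works the same way.
model = AutoencoderKL(
    block_out_channels=[2, 4],
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D"] * 2,
    up_block_types=["UpDecoderBlock2D"] * 2,
    latent_channels=4,
    norm_num_groups=2,
    layers_per_block=1,
)

# Block-level group offloading now handles standalone layers (quant_conv, conv_in, ...)
# instead of only ModuleList/Sequential children.
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)

with torch.no_grad():
    out = model(torch.randn(2, 3, 32, 32, device="cuda")).sample
```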

View File

@@ -15,7 +15,7 @@
import hashlib
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from dataclasses import dataclass, replace
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
@@ -59,6 +59,9 @@ class GroupOffloadingConfig:
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
block_modules: Optional[List[str]] = None
exclude_kwargs: Optional[List[str]] = None
module_prefix: Optional[str] = ""
class ModuleGroup:
@@ -77,7 +80,7 @@ class ModuleGroup:
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
group_id: Optional[Union[int, str]] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -322,7 +325,21 @@ class GroupOffloadingHook(ModelHook):
self.group.stream.synchronize()
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
# Some Autoencoder models use a feature cache that is passed through submodules
# and modified in place. The `send_to_device` call returns a copy of this feature cache object
# which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features
exclude_kwargs = self.config.exclude_kwargs or []
if exclude_kwargs:
moved_kwargs = send_to_device(
{k: v for k, v in kwargs.items() if k not in exclude_kwargs},
self.group.onload_device,
non_blocking=self.group.non_blocking,
)
kwargs.update(moved_kwargs)
else:
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return args, kwargs
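The comment above is the motivation for `exclude_kwargs`: device transfer rebuilds containers, so a cache that submodules mutate in place loses its object identity. A minimal sketch of that failure mode, using `accelerate.utils.send_to_device` (assumed here to behave like the transfer helper referenced above):

```python
import torch
from accelerate.utils import send_to_device

# Illustrative only, not part of the commit: send_to_device rebuilds containers,
# so the callee would mutate a copy of the cache rather than the caller's object.
feat_cache = [torch.zeros(1)]
kwargs = {"hidden_states": torch.randn(1), "feat_cache": feat_cache}

moved = send_to_device(kwargs, "cpu")
assert moved["feat_cache"] is not feat_cache  # identity lost, in-place updates are dropped

# Excluding the cache key, as the hook now does, keeps the original object shared.
safe = {k: v for k, v in kwargs.items() if k != "feat_cache"}
kwargs.update(send_to_device(safe, "cpu"))
assert kwargs["feat_cache"] is feat_cache  # identity preserved
```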
def post_forward(self, module: torch.nn.Module, output):
@@ -455,6 +472,8 @@ def apply_group_offloading(
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
exclude_kwargs: Optional[List[str]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -512,6 +531,13 @@ def apply_group_offloading(
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
block_modules (`List[str]`, *optional*):
List of module names that should be treated as blocks for offloading. If provided, only these modules will
be considered for block-level offloading. If not provided, the default block detection logic will be used.
exclude_kwargs (`List[str]`, *optional*):
List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
caching lists that need to maintain their object identity across forward passes. If not provided, will be
inferred from the module's `_skip_keys` attribute if it exists.
Example:
```python
@@ -553,6 +579,12 @@ def apply_group_offloading(
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
if block_modules is None:
block_modules = getattr(module, "_group_offload_block_modules", None)
if exclude_kwargs is None:
exclude_kwargs = getattr(module, "_skip_keys", None)
config = GroupOffloadingConfig(
onload_device=onload_device,
offload_device=offload_device,
@@ -563,6 +595,8 @@ def apply_group_offloading(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
)
_apply_group_offloading(module, config)
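When the new arguments are omitted, they are picked up from the model's class attributes as shown above. Passing them explicitly is roughly equivalent; a sketch (assuming a CUDA onload device):

```python
import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

vae = AutoencoderKL()
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    # Same list as AutoencoderKL._group_offload_block_modules.
    block_modules=["quant_conv", "post_quant_conv", "encoder", "decoder"],
    # AutoencoderKL defines no _skip_keys, so nothing needs to be excluded.
    exclude_kwargs=None,
)
```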
@@ -578,46 +612,66 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly
defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is
done at the top-level blocks and modules specified in block_modules.
When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified
module, recursively apply block offloading to it.
"""
if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
config.num_blocks_per_group = 1
# Create module groups for ModuleList and Sequential blocks
block_modules = set(config.block_modules) if config.block_modules is not None else set()
# Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
modules_with_group_offloading = set()
unmatched_modules = []
matched_module_groups = []
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
unmatched_modules.append((name, submodule))
modules_with_group_offloading.add(name)
continue
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
for name, submodule in module.named_children():
# Check if this is an explicitly defined block module
if name in block_modules:
# Track submodule using a prefix to avoid filename collisions during disk offload.
# Without this, submodules sharing the same model class would be assigned identical
# filenames (derived from the class name).
prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
submodule_config = replace(config, module_prefix=prefix)
_apply_group_offloading_block_level(submodule, submodule_config)
modules_with_group_offloading.add(name)
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Handle ModuleList and Sequential blocks as before
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue
group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
# This is an unmatched module
unmatched_modules.append((name, submodule))
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
@@ -632,28 +686,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]
# Create a group for the unmatched submodules of the top-level module so that they are on the correct
# device when the forward pass is called.
# Create a group for the remaining unmatched submodules of the top-level
# module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:

View File

@@ -74,6 +74,7 @@ class AutoencoderKL(
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
@register_to_config
def __init__(

View File

@@ -619,6 +619,7 @@ class WanEncoder3d(nn.Module):
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
@@ -961,6 +962,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""
_supports_gradient_checkpointing = False
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
# keys to ignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
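These two class attributes are the opt-in surface this commit wires into group offloading: `_group_offload_block_modules` marks children to recurse into as blocks, and `_skip_keys` marks kwargs that must keep their identity. A hypothetical minimal model (illustrative only, not from the commit) declaring both:

```python
import torch
from diffusers.models import ModelMixin


class TinyVAELike(ModelMixin):
    # Treat these children as blocks; their inner layers are grouped recursively.
    _group_offload_block_modules = ["encoder", "decoder"]
    # Never run these kwargs through send_to_device; they are mutated in place.
    _skip_keys = ["feat_cache"]

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
        self.quant_conv = torch.nn.Linear(4, 4)  # standalone layer, previously missed by block_level
        self.decoder = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 8))

    def forward(self, x, feat_cache=None):
        h = self.quant_conv(self.encoder(x))
        if feat_cache is not None:
            feat_cache.append(h)  # in-place update that must reach the caller's list
        return self.decoder(h)
```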
@@ -1414,6 +1416,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:

View File

@@ -531,6 +531,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[str] = None,
exclude_kwargs: Optional[str] = None,
) -> None:
r"""
Activates group offloading for the current model.
@@ -570,6 +572,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
module=self,
onload_device=onload_device,
@@ -581,6 +584,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
)
def set_attention_backend(self, backend: str) -> None:

View File

@@ -19,6 +19,7 @@ import unittest
import torch
from parameterized import parameterized
from diffusers import AutoencoderKL
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -149,6 +150,74 @@ class LayerOutputTrackerHook(ModelHook):
return output
# Model with only standalone computational layers at top level
class DummyModelWithStandaloneLayers(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()
self.layer1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.layer2 = torch.nn.Linear(hidden_features, hidden_features)
self.layer3 = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
x = self.layer3(x)
return x
# Model with deeply nested structure
class DummyModelWithDeeplyNestedBlocks(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()
self.input_layer = torch.nn.Linear(in_features, hidden_features)
self.container = ContainerWithNestedModuleList(hidden_features)
self.output_layer = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.input_layer(x)
x = self.container(x)
x = self.output_layer(x)
return x
class ContainerWithNestedModuleList(torch.nn.Module):
def __init__(self, features: int) -> None:
super().__init__()
# Top-level computational layer
self.proj_in = torch.nn.Linear(features, features)
# Nested container with ModuleList
self.nested_container = NestedContainer(features)
# Another top-level computational layer
self.proj_out = torch.nn.Linear(features, features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj_in(x)
x = self.nested_container(x)
x = self.proj_out(x)
return x
class NestedContainer(torch.nn.Module):
def __init__(self, features: int) -> None:
super().__init__()
self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)])
self.norm = torch.nn.LayerNorm(features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for block in self.blocks:
x = block(x)
x = self.norm(x)
return x
@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
@@ -340,7 +409,7 @@ class GroupOffloadTests(unittest.TestCase):
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
num_repeats = 4
num_repeats = 2
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
@@ -362,3 +431,138 @@ class GroupOffloadTests(unittest.TestCase):
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)
def test_vae_like_model_without_streams(self):
"""Test VAE-like model with block-level offloading but without streams."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
config = self.get_autoencoder_kl_config()
model = AutoencoderKL(**config)
model_ref = AutoencoderKL(**config)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)
x = torch.randn(2, 3, 32, 32).to(torch_device)
with torch.no_grad():
out_ref = model_ref(x).sample
out = model(x).sample
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
)
def test_model_with_only_standalone_layers(self):
"""Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
x = torch.randn(2, 64).to(torch_device)
with torch.no_grad():
for i in range(2):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for model with standalone layers.",
)
@parameterized.expand([("block_level",), ("leaf_level",)])
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
"""Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
config = self.get_autoencoder_kl_config()
model = AutoencoderKL(**config)
model_ref = AutoencoderKL(**config)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
x = torch.randn(2, 3, 32, 32).to(torch_device)
with torch.no_grad():
out_ref = model_ref(x).sample
out = model(x).sample
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match for standalone Conv layers with {offload_type}.",
)
def test_multiple_invocations_with_vae_like_model(self):
"""Test that multiple forward passes work correctly with VAE-like model."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
config = self.get_autoencoder_kl_config()
model = AutoencoderKL(**config)
model_ref = AutoencoderKL(**config)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
x = torch.randn(2, 3, 32, 32).to(torch_device)
with torch.no_grad():
for i in range(2):
out_ref = model_ref(x).sample
out = model(x).sample
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
def test_nested_container_parameters_offloading(self):
"""Test that parameters from non-computational layers in nested containers are handled correctly."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
x = torch.randn(2, 64).to(torch_device)
with torch.no_grad():
for i in range(2):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for nested parameters.",
)
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
block_out_channels = block_out_channels or [2, 4]
norm_num_groups = norm_num_groups or 2
init_dict = {
"block_out_channels": block_out_channels,
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
"up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
"latent_channels": 4,
"norm_num_groups": norm_num_groups,
"layers_per_block": 1,
}
return init_dict

View File

@@ -1791,7 +1791,6 @@ class ModelTesterMixin:
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
model.to(torch_device)
output_without_group_offloading = run_forward(model)
output_without_group_offloading = normalize_output(output_without_group_offloading)
@@ -1916,6 +1915,9 @@ class ModelTesterMixin:
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
block_modules=model._group_offload_block_modules
if hasattr(model, "_group_offload_block_modules")
else None,
)
if not is_correct:
if extra_files:

View File

@@ -1424,6 +1424,8 @@ if is_torch_available():
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: Optional[int] = None,
block_modules: Optional[List[str]] = None,
module_prefix: str = "",
) -> Set[str]:
expected_files = set()
@@ -1435,23 +1437,36 @@ if is_torch_available():
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
# Handle groups of ModuleList and Sequential blocks
block_modules_set = set(block_modules) if block_modules is not None else set()
modules_with_group_offloading = set()
unmatched_modules = []
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
unmatched_modules.append(module)
continue
if name in block_modules_set:
new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}."
submodule_files = _get_expected_safetensors_files(
submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix
)
expected_files.update(submodule_files)
modules_with_group_offloading.add(name)
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
if not current_modules:
continue
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
expected_files.add(get_hashed_filename(group_id))
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
if not current_modules:
continue
group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
expected_files.add(get_hashed_filename(group_id))
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
unmatched_modules.append(submodule)
# Handle the group for unmatched top-level modules and parameters
for module in unmatched_modules:
expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group"))
elif offload_type == "leaf_level":
# Handle leaf-level module groups
@@ -1492,12 +1507,13 @@ if is_torch_available():
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: Optional[int] = None,
block_modules: Optional[List[str]] = None,
) -> bool:
if not os.path.isdir(offload_to_disk_path):
return False, None, None
expected_files = _get_expected_safetensors_files(
module, offload_to_disk_path, offload_type, num_blocks_per_group
module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules
)
actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
missing_files = expected_files - actual_files
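The test helper above predicts disk-offload filenames by rebuilding the same prefixed group ids that the block-level hook produces. A hypothetical sketch of that naming scheme (`sketch_group_ids` is not a library function, just an illustration of the ids the hashed filenames are derived from):

```python
import torch


def sketch_group_ids(model, block_modules=(), prefix="", num_blocks_per_group=1):
    """Illustrative restatement of the group-id naming scheme (not library API)."""
    ids = []
    for name, child in model.named_children():
        if name in block_modules:
            # Recursing with a prefix keeps ids unique when different submodules
            # (e.g. encoder and decoder) contain identically named block lists.
            ids += sketch_group_ids(child, block_modules, f"{prefix}{name}.", num_blocks_per_group)
        elif isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
            for i in range(0, len(child), num_blocks_per_group):
                n = len(child[i : i + num_blocks_per_group])
                ids.append(f"{prefix}{name}_{i}_{i + n - 1}")
    # The real implementation only creates this group when something is left unmatched.
    ids.append(f"{prefix}{model.__class__.__name__}_unmatched_group")
    return ids
```

For an AutoencoderKL-like model with `block_modules=["encoder", "decoder"]`, this yields ids roughly like `encoder.down_blocks_0_0` and `decoder.up_blocks_0_0`, which hash to distinct safetensors filenames instead of colliding on the shared class name.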