Fix broken group offloading with block_level for models with standalone layers (#12692)
* fix: group offloading to support standalone computational layers in block-level offloading
* test: for models with standalone and deeply nested layers in block-level offloading
* feat: support for block-level offloading in group offloading config
* fix: group offload block modules to AutoencoderKL and AutoencoderKLWan
* fix: update group offloading tests to use AutoencoderKL and adjust input dimensions
* refactor: streamline block offloading logic
* Apply style fixes
* update tests
* update
* fix for failing tests
* clean up
* revert to use skip_keys
* clean up

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
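The change can be exercised end to end through the model-level API touched by this commit. The snippet below is a minimal sketch that mirrors the new tests: it builds a tiny AutoencoderKL, whose quant_conv/post_quant_conv/encoder/decoder are standalone modules rather than ModuleList blocks, and enables block-level group offloading. The tiny VAE config and the CUDA device are placeholder assumptions.

# Minimal sketch, mirroring the tests added in this commit. Assumes a CUDA (or other
# accelerator) device is available; the tiny VAE config is a placeholder.
import torch
from diffusers import AutoencoderKL

config = {
    "block_out_channels": [2, 4],
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": ["DownEncoderBlock2D"] * 2,
    "up_block_types": ["UpDecoderBlock2D"] * 2,
    "latent_channels": 4,
    "norm_num_groups": 2,
    "layers_per_block": 1,
}
model = AutoencoderKL(**config)

# Before this fix, block_level offloading could leave standalone layers such as
# quant_conv on the wrong device; they are now covered via _group_offload_block_modules.
model.enable_group_offload(
    torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)

with torch.no_grad():
    out = model(torch.randn(2, 3, 32, 32).to("cuda")).sample
print(out.shape)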
@@ -15,7 +15,7 @@
 import hashlib
 import os
 from contextlib import contextmanager, nullcontext
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple, Union

@@ -59,6 +59,9 @@ class GroupOffloadingConfig:
     num_blocks_per_group: Optional[int] = None
     offload_to_disk_path: Optional[str] = None
     stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+    block_modules: Optional[List[str]] = None
+    exclude_kwargs: Optional[List[str]] = None
+    module_prefix: Optional[str] = ""


 class ModuleGroup:
@@ -77,7 +80,7 @@ class ModuleGroup:
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
-        group_id: Optional[int] = None,
+        group_id: Optional[Union[int, str]] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -322,7 +325,21 @@ class GroupOffloadingHook(ModelHook):
             self.group.stream.synchronize()

         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
-        kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+
+        # Some Autoencoder models use a feature cache that is passed through submodules
+        # and modified in place. The `send_to_device` call returns a copy of this feature cache object
+        # which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features
+        exclude_kwargs = self.config.exclude_kwargs or []
+        if exclude_kwargs:
+            moved_kwargs = send_to_device(
+                {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
+                self.group.onload_device,
+                non_blocking=self.group.non_blocking,
+            )
+            kwargs.update(moved_kwargs)
+        else:
+            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

         return args, kwargs

     def post_forward(self, module: torch.nn.Module, output):
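The comment in this hunk is the crux of the Wan VAE failure: send_to_device rebuilds container objects, so a feature cache that submodules update in place silently diverges from the caller's copy. A toy illustration of that copy-versus-identity problem (plain list semantics, not diffusers code):

# Toy illustration of why mutable caches must be excluded from send_to_device-style
# copying: in-place writes land in the copy, not in the caller's object.
feat_cache = [None, None]

def move_like_send_to_device(kwargs):
    # stands in for send_to_device: containers are rebuilt, breaking object identity
    return {k: list(v) if isinstance(v, list) else v for k, v in kwargs.items()}

moved = move_like_send_to_device({"feat_cache": feat_cache})
moved["feat_cache"][0] = "filled by submodule"

print(feat_cache[0])                      # None -> the original cache never saw the update
print(moved["feat_cache"] is feat_cache)  # False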
@@ -455,6 +472,8 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    block_modules: Optional[List[str]] = None,
+    exclude_kwargs: Optional[List[str]] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -512,6 +531,13 @@
             If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
+        block_modules (`List[str]`, *optional*):
+            List of module names that should be treated as blocks for offloading. If provided, only these modules will
+            be considered for block-level offloading. If not provided, the default block detection logic will be used.
+        exclude_kwargs (`List[str]`, *optional*):
+            List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
+            caching lists that need to maintain their object identity across forward passes. If not provided, will be
+            inferred from the module's `_skip_keys` attribute if it exists.

     Example:
         ```python
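With these docstring additions, the new arguments can also be passed explicitly instead of being inferred from class attributes. Below is a hedged sketch of such a call; the tiny VAE is a stand-in, and the listed names mirror the defaults this commit adds to AutoencoderKL and AutoencoderKLWan (exclude_kwargs only matters for models that actually forward such kwargs, like the Wan VAE's feat_cache/feat_idx).

# Hedged sketch: explicit block_modules / exclude_kwargs instead of the class-attribute
# defaults. The model below is a placeholder.
import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

vae = AutoencoderKL(block_out_channels=[4], norm_num_groups=2, layers_per_block=1)

apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["quant_conv", "post_quant_conv", "encoder", "decoder"],
    exclude_kwargs=["feat_cache", "feat_idx"],
)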
@@ -553,6 +579,12 @@

     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

+    if block_modules is None:
+        block_modules = getattr(module, "_group_offload_block_modules", None)
+
+    if exclude_kwargs is None:
+        exclude_kwargs = getattr(module, "_skip_keys", None)
+
     config = GroupOffloadingConfig(
         onload_device=onload_device,
         offload_device=offload_device,
@@ -563,6 +595,8 @@
         record_stream=record_stream,
         low_cpu_mem_usage=low_cpu_mem_usage,
         offload_to_disk_path=offload_to_disk_path,
+        block_modules=block_modules,
+        exclude_kwargs=exclude_kwargs,
     )
     _apply_group_offloading(module, config)

@@ -578,46 +612,66 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConf
 def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
-    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
-    the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-    """
+    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly
+    defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is
+    done at the top-level blocks and modules specified in block_modules.
+
+    When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified
+    module, recursively apply block offloading to it.
+    """
     if config.stream is not None and config.num_blocks_per_group != 1:
         logger.warning(
             f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
         )
         config.num_blocks_per_group = 1

-    # Create module groups for ModuleList and Sequential blocks
+    block_modules = set(config.block_modules) if config.block_modules is not None else set()
+
+    # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
     modules_with_group_offloading = set()
     unmatched_modules = []
     matched_module_groups = []
-    for name, submodule in module.named_children():
-        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-            unmatched_modules.append((name, submodule))
-            modules_with_group_offloading.add(name)
-            continue
-
-        for i in range(0, len(submodule), config.num_blocks_per_group):
-            current_modules = submodule[i : i + config.num_blocks_per_group]
-            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
-            group = ModuleGroup(
-                modules=current_modules,
-                offload_device=config.offload_device,
-                onload_device=config.onload_device,
-                offload_to_disk_path=config.offload_to_disk_path,
-                offload_leader=current_modules[-1],
-                onload_leader=current_modules[0],
-                non_blocking=config.non_blocking,
-                stream=config.stream,
-                record_stream=config.record_stream,
-                low_cpu_mem_usage=config.low_cpu_mem_usage,
-                onload_self=True,
-                group_id=group_id,
-            )
-            matched_module_groups.append(group)
-            for j in range(i, i + len(current_modules)):
-                modules_with_group_offloading.add(f"{name}.{j}")
+    for name, submodule in module.named_children():
+        # Check if this is an explicitly defined block module
+        if name in block_modules:
+            # Track submodule using a prefix to avoid filename collisions during disk offload.
+            # Without this, submodules sharing the same model class would be assigned identical
+            # filenames (derived from the class name).
+            prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
+            submodule_config = replace(config, module_prefix=prefix)
+
+            _apply_group_offloading_block_level(submodule, submodule_config)
+            modules_with_group_offloading.add(name)
+
+        elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+            # Handle ModuleList and Sequential blocks as before
+            for i in range(0, len(submodule), config.num_blocks_per_group):
+                current_modules = list(submodule[i : i + config.num_blocks_per_group])
+                if len(current_modules) == 0:
+                    continue
+
+                group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
+                group = ModuleGroup(
+                    modules=current_modules,
+                    offload_device=config.offload_device,
+                    onload_device=config.onload_device,
+                    offload_to_disk_path=config.offload_to_disk_path,
+                    offload_leader=current_modules[-1],
+                    onload_leader=current_modules[0],
+                    non_blocking=config.non_blocking,
+                    stream=config.stream,
+                    record_stream=config.record_stream,
+                    low_cpu_mem_usage=config.low_cpu_mem_usage,
+                    onload_self=True,
+                    group_id=group_id,
+                )
+                matched_module_groups.append(group)
+                for j in range(i, i + len(current_modules)):
+                    modules_with_group_offloading.add(f"{name}.{j}")
+        else:
+            # This is an unmatched module
+            unmatched_modules.append((name, submodule))

     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
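The matching order above (explicit block_modules first, then ModuleList/Sequential grouping, everything else unmatched) can be seen on a toy module. The sketch below only mimics the dispatch in that loop; names and structure are made up and it does not touch the hook machinery.

# Sketch of the dispatch order in the loop above, on a made-up module.
import torch

class ToyVAE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)])
        self.quant_conv = torch.nn.Linear(8, 8)  # standalone computational layer
        self.norm = torch.nn.LayerNorm(8)        # standalone non-block layer

block_modules = {"encoder", "quant_conv"}
for name, child in ToyVAE().named_children():
    if name in block_modules:
        print(f"{name}: explicit block -> block-level offloading recurses into it")
    elif isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
        print(f"{name}: grouped {len(child)} block(s) per num_blocks_per_group")
    else:
        print(f"{name}: unmatched -> kept in the top-level unmatched group")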
@@ -632,28 +686,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
     parameters = [param for _, param in parameters]
     buffers = [buffer for _, buffer in buffers]

-    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
-    # device when the forward pass is called.
+    # Create a group for the remaining unmatched submodules of the top-level
+    # module so that they are on the correct device when the forward pass is called.
     unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
-    unmatched_group = ModuleGroup(
-        modules=unmatched_modules,
-        offload_device=config.offload_device,
-        onload_device=config.onload_device,
-        offload_to_disk_path=config.offload_to_disk_path,
-        offload_leader=module,
-        onload_leader=module,
-        parameters=parameters,
-        buffers=buffers,
-        non_blocking=False,
-        stream=None,
-        record_stream=False,
-        onload_self=True,
-        group_id=f"{module.__class__.__name__}_unmatched_group",
-    )
-    if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, config=config)
-    else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+    if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+        unmatched_group = ModuleGroup(
+            modules=unmatched_modules,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=module,
+            onload_leader=module,
+            parameters=parameters,
+            buffers=buffers,
+            non_blocking=False,
+            stream=None,
+            record_stream=False,
+            onload_self=True,
+            group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
+        )
+        if config.stream is None:
+            _apply_group_offloading_hook(module, unmatched_group, config=config)
+        else:
+            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
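The recursion relies on dataclasses.replace (newly imported at the top of this diff) to hand each explicit block a copy of the config whose only difference is module_prefix, which then flows into every group_id, including the unmatched group above. A minimal sketch of that copy-with-override behaviour, using a stand-in dataclass rather than GroupOffloadingConfig:

# Stand-in dataclass; only demonstrates the replace() semantics used for module_prefix.
from dataclasses import dataclass, replace

@dataclass
class StubConfig:
    num_blocks_per_group: int = 1
    module_prefix: str = ""

base = StubConfig()
encoder_cfg = replace(base, module_prefix="encoder.")
decoder_cfg = replace(base, module_prefix="decoder.")

# Prefixed group ids stay unique even when two submodules share a class name.
shared_cls = "WanResidualBlock"  # hypothetical shared class name
print(f"{encoder_cfg.module_prefix}{shared_cls}_unmatched_group")
print(f"{decoder_cfg.module_prefix}{shared_cls}_unmatched_group")
print(repr(base.module_prefix))  # '' -- the original config is untouched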
@@ -74,6 +74,7 @@ class AutoencoderKL(

     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

     @register_to_config
     def __init__(
@@ -619,6 +619,7 @@ class WanEncoder3d(nn.Module):
                 feat_idx[0] += 1
         else:
             x = self.conv_out(x)
+
         return x

@@ -961,6 +962,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
    """

    _supports_gradient_checkpointing = False
+   _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
    # keys to ignore when AlignDeviceHook moves inputs/outputs between devices
    # these are shared mutable state modified in-place
    _skip_keys = ["feat_cache", "feat_idx"]
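Because apply_group_offloading now falls back to _group_offload_block_modules and _skip_keys (see the getattr defaults earlier in this diff), other models can opt in the same way AutoencoderKL and AutoencoderKLWan do here, simply by declaring the attributes. A hedged sketch of a hypothetical custom model:

# Hypothetical custom model; the attribute names are the ones read by the new
# getattr fallbacks, the layer layout is made up.
import torch
from diffusers.models import ModelMixin

class MyTinyVAE(ModelMixin):
    _group_offload_block_modules = ["encoder", "decoder"]  # offloaded as whole blocks
    _skip_keys = ["feat_cache"]                            # kwargs left out of send_to_device

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(16, 4)
        self.decoder = torch.nn.Linear(4, 16)
        self.out_norm = torch.nn.LayerNorm(16)  # standalone layer, handled by the unmatched group

    def forward(self, x, feat_cache=None):
        return self.out_norm(self.decoder(self.encoder(x)))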
@@ -1414,6 +1416,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
        """
        x = sample
        posterior = self.encode(x).latent_dist
+
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
@@ -531,6 +531,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        record_stream: bool = False,
        low_cpu_mem_usage=False,
        offload_to_disk_path: Optional[str] = None,
+       block_modules: Optional[List[str]] = None,
+       exclude_kwargs: Optional[List[str]] = None,
    ) -> None:
        r"""
        Activates group offloading for the current model.
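Per-call arguments passed through enable_group_offload take precedence over the class-attribute defaults. Continuing the hypothetical MyTinyVAE sketch above (CUDA device assumed):

# Explicit overrides win over _group_offload_block_modules / _skip_keys.
import torch

model = MyTinyVAE()
model.enable_group_offload(
    torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["decoder"],      # only the decoder is treated as an explicit block
    exclude_kwargs=["feat_cache"],
)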
@@ -570,6 +572,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
                f"open an issue at https://github.com/huggingface/diffusers/issues."
            )

        apply_group_offloading(
            module=self,
            onload_device=onload_device,
@@ -581,6 +584,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            offload_to_disk_path=offload_to_disk_path,
+           block_modules=block_modules,
+           exclude_kwargs=exclude_kwargs,
        )

    def set_attention_backend(self, backend: str) -> None:
@@ -19,6 +19,7 @@ import unittest
 import torch
 from parameterized import parameterized

+from diffusers import AutoencoderKL
 from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -149,6 +150,74 @@ class LayerOutputTrackerHook(ModelHook):
         return output


+# Model with only standalone computational layers at top level
+class DummyModelWithStandaloneLayers(ModelMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
+        super().__init__()
+
+        self.layer1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.layer2 = torch.nn.Linear(hidden_features, hidden_features)
+        self.layer3 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.layer1(x)
+        x = self.activation(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        return x
+
+
+# Model with deeply nested structure
+class DummyModelWithDeeplyNestedBlocks(ModelMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
+        super().__init__()
+
+        self.input_layer = torch.nn.Linear(in_features, hidden_features)
+        self.container = ContainerWithNestedModuleList(hidden_features)
+        self.output_layer = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.input_layer(x)
+        x = self.container(x)
+        x = self.output_layer(x)
+        return x
+
+
+class ContainerWithNestedModuleList(torch.nn.Module):
+    def __init__(self, features: int) -> None:
+        super().__init__()
+
+        # Top-level computational layer
+        self.proj_in = torch.nn.Linear(features, features)
+
+        # Nested container with ModuleList
+        self.nested_container = NestedContainer(features)
+
+        # Another top-level computational layer
+        self.proj_out = torch.nn.Linear(features, features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj_in(x)
+        x = self.nested_container(x)
+        x = self.proj_out(x)
+        return x
+
+
+class NestedContainer(torch.nn.Module):
+    def __init__(self, features: int) -> None:
+        super().__init__()
+
+        self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)])
+        self.norm = torch.nn.LayerNorm(features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for block in self.blocks:
+            x = block(x)
+        x = self.norm(x)
+        return x
+
+
 @require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
@@ -340,7 +409,7 @@ class GroupOffloadTests(unittest.TestCase):
            out = model(x)
        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")

-       num_repeats = 4
+       num_repeats = 2
        for i in range(num_repeats):
            out_ref = model_ref(x)
            out = model(x)
@@ -362,3 +431,138 @@ class GroupOffloadTests(unittest.TestCase):
            self.assertLess(
                cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
            )

+    def test_vae_like_model_without_streams(self):
+        """Test VAE-like model with block-level offloading but without streams."""
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)
+
+        model_ref = AutoencoderKL(**config)
+        model_ref.load_state_dict(model.state_dict(), strict=True)
+        model_ref.to(torch_device)
+
+        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)
+
+        x = torch.randn(2, 3, 32, 32).to(torch_device)
+
+        with torch.no_grad():
+            out_ref = model_ref(x).sample
+            out = model(x).sample
+
+        self.assertTrue(
+            torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
+        )
+
+    def test_model_with_only_standalone_layers(self):
+        """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
+
+        model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
+        model_ref.load_state_dict(model.state_dict(), strict=True)
+        model_ref.to(torch_device)
+
+        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+        x = torch.randn(2, 64).to(torch_device)
+
+        with torch.no_grad():
+            for i in range(2):
+                out_ref = model_ref(x)
+                out = model(x)
+                self.assertTrue(
+                    torch.allclose(out_ref, out, atol=1e-5),
+                    f"Outputs do not match at iteration {i} for model with standalone layers.",
+                )
+
+    @parameterized.expand([("block_level",), ("leaf_level",)])
+    def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
+        """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)
+
+        model_ref = AutoencoderKL(**config)
+        model_ref.load_state_dict(model.state_dict(), strict=True)
+        model_ref.to(torch_device)
+
+        model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
+
+        x = torch.randn(2, 3, 32, 32).to(torch_device)
+
+        with torch.no_grad():
+            out_ref = model_ref(x).sample
+            out = model(x).sample
+
+        self.assertTrue(
+            torch.allclose(out_ref, out, atol=1e-5),
+            f"Outputs do not match for standalone Conv layers with {offload_type}.",
+        )
+
+    def test_multiple_invocations_with_vae_like_model(self):
+        """Test that multiple forward passes work correctly with VAE-like model."""
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)
+
+        model_ref = AutoencoderKL(**config)
+        model_ref.load_state_dict(model.state_dict(), strict=True)
+        model_ref.to(torch_device)
+
+        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+        x = torch.randn(2, 3, 32, 32).to(torch_device)
+
+        with torch.no_grad():
+            for i in range(2):
+                out_ref = model_ref(x).sample
+                out = model(x).sample
+                self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
+
+    def test_nested_container_parameters_offloading(self):
+        """Test that parameters from non-computational layers in nested containers are handled correctly."""
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
+
+        model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
+        model_ref.load_state_dict(model.state_dict(), strict=True)
+        model_ref.to(torch_device)
+
+        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+        x = torch.randn(2, 64).to(torch_device)
+
+        with torch.no_grad():
+            for i in range(2):
+                out_ref = model_ref(x)
+                out = model(x)
+                self.assertTrue(
+                    torch.allclose(out_ref, out, atol=1e-5),
+                    f"Outputs do not match at iteration {i} for nested parameters.",
+                )
+
+    def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
+        block_out_channels = block_out_channels or [2, 4]
+        norm_num_groups = norm_num_groups or 2
+        init_dict = {
+            "block_out_channels": block_out_channels,
+            "in_channels": 3,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+            "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+            "latent_channels": 4,
+            "norm_num_groups": norm_num_groups,
+            "layers_per_block": 1,
+        }
+        return init_dict
@@ -1791,7 +1791,6 @@ class ModelTesterMixin:
            return model(**inputs_dict)[0]

        model = self.model_class(**init_dict)
-
        model.to(torch_device)
        output_without_group_offloading = run_forward(model)
        output_without_group_offloading = normalize_output(output_without_group_offloading)
@@ -1916,6 +1915,9 @@ class ModelTesterMixin:
                    offload_to_disk_path=tmpdir,
                    offload_type=offload_type,
                    num_blocks_per_group=num_blocks_per_group,
+                   block_modules=model._group_offload_block_modules
+                   if hasattr(model, "_group_offload_block_modules")
+                   else None,
                )
                if not is_correct:
                    if extra_files:
@@ -1424,6 +1424,8 @@ if is_torch_available():
        offload_to_disk_path: str,
        offload_type: str,
        num_blocks_per_group: Optional[int] = None,
+       block_modules: Optional[List[str]] = None,
+       module_prefix: str = "",
    ) -> Set[str]:
        expected_files = set()

@@ -1435,23 +1437,36 @@ if is_torch_available():
            if num_blocks_per_group is None:
                raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")

            # Handle groups of ModuleList and Sequential blocks
+           block_modules_set = set(block_modules) if block_modules is not None else set()
+
            modules_with_group_offloading = set()
            unmatched_modules = []
            for name, submodule in module.named_children():
-               if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-                   unmatched_modules.append(module)
-                   continue
+               if name in block_modules_set:
+                   new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}."
+                   submodule_files = _get_expected_safetensors_files(
+                       submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix
+                   )
+                   expected_files.update(submodule_files)
+                   modules_with_group_offloading.add(name)

-               for i in range(0, len(submodule), num_blocks_per_group):
-                   current_modules = submodule[i : i + num_blocks_per_group]
-                   if not current_modules:
-                       continue
-                   group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
-                   expected_files.add(get_hashed_filename(group_id))
+               elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                   for i in range(0, len(submodule), num_blocks_per_group):
+                       current_modules = submodule[i : i + num_blocks_per_group]
+                       if not current_modules:
+                           continue
+                       group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
+                       expected_files.add(get_hashed_filename(group_id))
                        for j in range(i, i + len(current_modules)):
                            modules_with_group_offloading.add(f"{name}.{j}")
+               else:
+                   unmatched_modules.append(submodule)

            # Handle the group for unmatched top-level modules and parameters
-           for module in unmatched_modules:
-               expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
+           parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+           buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+           if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+               expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group"))

        elif offload_type == "leaf_level":
            # Handle leaf-level module groups
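The expected-file bookkeeping above keys everything on group_id strings, which the disk-offload path turns into per-group safetensors filenames. The sketch below only makes that mapping concrete; the digest choice and filename layout are assumptions, and the real get_hashed_filename helper may differ.

# Assumption: a deterministic digest of the group_id, shown only to illustrate the
# group_id -> filename bookkeeping; not the actual diffusers helper.
import hashlib

def hashed_filename_sketch(group_id: str, offload_to_disk_path: str = "/tmp/offload") -> str:
    digest = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
    return f"{offload_to_disk_path}/group_{digest}.safetensors"

# group ids now carry the module prefix, e.g. for a block reached via block_modules:
print(hashed_filename_sketch("encoder.down_blocks_0_0"))
print(hashed_filename_sketch("decoder.up_blocks_0_0"))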
@@ -1492,12 +1507,13 @@ if is_torch_available():
        offload_to_disk_path: str,
        offload_type: str,
        num_blocks_per_group: Optional[int] = None,
+       block_modules: Optional[List[str]] = None,
    ) -> bool:
        if not os.path.isdir(offload_to_disk_path):
            return False, None, None

        expected_files = _get_expected_safetensors_files(
-           module, offload_to_disk_path, offload_type, num_blocks_per_group
+           module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules
        )
        actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
        missing_files = expected_files - actual_files