Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fix group offloading with block_level and use_stream=True (#11375)
* fix

* add tests

* add message check
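For context, this is roughly how a caller enables the code path this commit fixes. The model identifier and the `transformer` attribute below are illustrative placeholders; the `enable_group_offload` call mirrors the one in the test added by this PR.

import torch
from diffusers import DiffusionPipeline

# Sketch only: "some/model-id" and `pipe.transformer` are placeholders.
pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.bfloat16)
pipe.transformer.enable_group_offload(
    torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,  # with use_stream=True, only 1 is supported (see the check below)
    use_stream=True,
)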
src/diffusers/hooks/group_offloading.py
@@ -57,7 +57,7 @@ class ModuleGroup:
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
         record_stream: Optional[bool] = False,
-        low_cpu_mem_usage=False,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -498,6 +498,8 @@ def _apply_group_offloading_block_level(
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
+    if stream is not None and num_blocks_per_group != 1:
+        raise ValueError(f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}.")

     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
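The guard added above turns what used to be a confusing device-mismatch failure deep inside the forward pass into an immediate, explicit error. A minimal sketch of what a caller now sees (construction of `model`, assumed to be any diffusers ModelMixin, is elided):

# Sketch: `model` is assumed to be a diffusers ModelMixin instance.
model.enable_group_offload(
    torch.device("cuda"), offload_type="block_level", num_blocks_per_group=2, use_stream=True
)
# ValueError: Using streams is only supported for num_blocks_per_group=1. Got num_blocks_per_group=2.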
@@ -521,7 +523,7 @@ def _apply_group_offloading_block_level(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            onload_self=stream is None,
+            onload_self=True,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -529,12 +531,8 @@ def _apply_group_offloading_block_level(

     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
-        next_group = (
-            matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
-        )
-
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group)
+            _apply_group_offloading_hook(group_module, group, None)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
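The deleted `next_group` wiring assumed that groups run in the same order they were registered. The toy module below (modeled on the test model added later in this diff, not part of the PR itself) shows how registration order and invocation order can diverge:

import torch

class Toy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # single_blocks is registered first...
        self.single_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])
        # ...and double_blocks second.
        self.double_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Invocation happens in the opposite order, so "the next registered
        # group" is not "the next group to run" -- prefetching the former
        # onloads the wrong weights.
        x = self.double_blocks[0](x)
        return self.single_blocks[0](x)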
@@ -560,8 +558,10 @@ def _apply_group_offloading_block_level(
         record_stream=False,
         onload_self=True,
     )
-    next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group)
+    if stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, None)
+    else:
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


 def _apply_group_offloading_leaf_level(
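The new `else` branch defers prefetch wiring to a lazy hook: with `use_stream=True`, the correct prefetch order is only knowable after the model has actually run once. A rough sketch of the idea, with illustrative names rather than the diffusers implementation:

# Sketch: record invocation order on the first forward pass, then use it
# (instead of registration order) to pick each group's prefetch target.
class LazyPrefetchOrdering:
    def __init__(self) -> None:
        self.invocation_order = []

    def on_group_invoked(self, group) -> None:
        # Called from each group's pre-forward hook during the first pass.
        self.invocation_order.append(group)

    def finalize(self) -> None:
        # Chain groups by actual execution order for subsequent passes.
        for current, successor in zip(self.invocation_order, self.invocation_order[1:]):
            current.next_group = successor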
tests/hooks/test_group_offloading.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import gc
 import unittest
@@ -20,6 +21,7 @@ import torch
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
+from diffusers.utils.import_utils import compare_versions
 from diffusers.utils.testing_utils import require_torch_gpu, torch_device
@@ -58,6 +60,39 @@ class DummyModel(ModelMixin):
         return x


+# This model implementation contains one type of block (single_blocks) instantiated before another type of block
+# (double_blocks). The invocation order of these blocks, however, is first the double_blocks and then the
+# single_blocks. With the group offloading implementation before https://github.com/huggingface/diffusers/pull/11375,
+# such a modeling implementation would result in a device mismatch error because of the assumptions made by the code.
+# The failure case occurs when using: offload_type="block_level", num_blocks_per_group=2, use_stream=True
+# After the linked PR, the implementation works as expected.
+class DummyModelWithMultipleBlocks(ModelMixin):
+    def __init__(
+        self, in_features: int, hidden_features: int, out_features: int, num_layers: int, num_single_layers: int
+    ) -> None:
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.single_blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_single_layers)]
+        )
+        self.double_blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.double_blocks:
+            x = block(x)
+        for block in self.single_blocks:
+            x = block(x)
+        x = self.linear_2(x)
+        return x
+
+
 class DummyPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "model"
@@ -212,3 +247,23 @@ class GroupOffloadTests(unittest.TestCase):
         pipe.enable_sequential_cpu_offload()
         with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
             pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+
+    def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
+        if torch.device(torch_device).type != "cuda":
+            return
+        model = DummyModelWithMultipleBlocks(
+            in_features=self.in_features,
+            hidden_features=self.hidden_features,
+            out_features=self.out_features,
+            num_layers=self.num_layers,
+            num_single_layers=self.num_layers + 1,
+        )
+        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+        context = contextlib.nullcontext()
+        if compare_versions("diffusers", "<=", "0.33.0"):
+            # Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
+            context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
+
+        with context:
+            model(self.input)
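The `compare_versions` gate above lets one test encode both the pre-fix failure and the post-fix success. The same pattern in a self-contained form (the `BUG_PRESENT` flag and `run_workload` are illustrative stand-ins for the real version check and model call):

import contextlib
import unittest

# Illustrative stand-ins for diffusers' compare_versions check and the model forward.
BUG_PRESENT = False


def run_workload() -> None:
    if BUG_PRESENT:
        raise RuntimeError("Expected all tensors to be on the same device")


class RegressionPatternExample(unittest.TestCase):
    def test_behavior_across_versions(self) -> None:
        # nullcontext() when the fix is present; assertRaisesRegex otherwise.
        context = contextlib.nullcontext()
        if BUG_PRESENT:
            context = self.assertRaisesRegex(RuntimeError, "Expected all tensors")
        with context:
            run_workload()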