
commit 43cae1a613
parent 1d9bf41cf9
Author: sayakpaul
Date:   2025-09-16 12:19:11 +05:30

    address reviewer feedback.

@@ -14,6 +14,7 @@ import torch
 import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
 from huggingface_hub.utils import is_jinja_available
+from parameterized import parameterized
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 import diffusers
@@ -2319,8 +2320,9 @@ class PipelineTesterMixin:
             component_device = next(component.parameters())[0].device
             self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)

+    @parameterized.expand([("block_level"), ("leaf_level")])
     @require_torch_accelerator
-    def test_group_offloading_inference(self):
+    def test_group_offloading_inference(self, offload_type: str = "block_level"):
         if not self.test_group_offloading:
             pytest.skip("`test_group_offloading` is disabled hence skipping.")
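
Note: `parameterized.expand` turns the single test above into one generated test per entry, so the body runs once with `offload_type="block_level"` and once with `"leaf_level"`. A minimal, self-contained sketch of the same pattern, for reference (the class and test names here are illustrative, not from this diff):

    import unittest

    from parameterized import parameterized


    class ExampleTests(unittest.TestCase):
        # Generates test_offload_type_0_block_level and
        # test_offload_type_1_leaf_level from the single body below.
        @parameterized.expand([("block_level",), ("leaf_level",)])
        def test_offload_type(self, offload_type):
            self.assertIn(offload_type, {"block_level", "leaf_level"})


    if __name__ == "__main__":
        unittest.main()
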
@@ -2331,26 +2333,24 @@ class PipelineTesterMixin:
             pipe.set_progress_bar_config(disable=None)
             return pipe

-        def enable_group_offload_on_component(pipe, group_offloading_kwargs):
+        def enable_group_offload(pipe, group_offloading_kwargs):
             # We intentionally don't test VAEs here. This is because some tests enable tiling on the VAE. If
             # tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
             # the layers is not traced correctly. This causes errors. To apply group offloading to the VAE, a
             # warmup forward pass (even with dummy small inputs) is recommended.
-            exclude_modules = []
-            for name, component in pipe.components.items():
-                for name in ["vae", "vqvae", "image_encoder"]:
-                    exclude_modules.append(name)
+            exclude_modules = {"vae", "vqvae", "image_encoder"}
+            exclude_modules = list(exclude_modules & set(pipe.components.keys()))
             pipe.enable_group_offload(
                 exclude_modules=exclude_modules, onload_device=torch_device, **group_offloading_kwargs
             )
             for component_name, component in pipe.components.items():
-                if component_name not in exclude_modules and isinstance(component, torch.nn.Module):
-                    self.assertTrue(
-                        all(
-                            module._diffusers_hook.get_hook("group_offloading") is not None
-                            for module in component.modules()
-                            if hasattr(module, "_diffusers_hook")
-                        )
-                    )
+                if component_name in exclude_modules:
+                    continue
+                elif isinstance(component, torch.nn.Module):
+                    assert all(
+                        module._diffusers_hook.get_hook("group_offloading") is not None
+                        for module in component.modules()
+                        if hasattr(module, "_diffusers_hook")
+                    )

         def run_forward(pipe):
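
Note: the comment in this hunk recommends a warmup forward pass when group-offloading a tiled VAE under accelerator streams, since the first pass is what traces layer execution order. A rough sketch of that recommendation at the model level, assuming a CUDA device (the checkpoint and dummy input shape are placeholders; `apply_group_offloading` is the hook-level entry point in `diffusers.hooks`):

    import torch
    from diffusers import AutoencoderKL
    from diffusers.hooks import apply_group_offloading

    vae = AutoencoderKL.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae"
    )
    vae.enable_tiling()

    apply_group_offloading(
        vae,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
    )

    # Warmup with a small dummy latent (too small to trigger tiling) so the
    # execution order is traced before a real, tiled forward pass runs.
    with torch.no_grad():
        vae.decode(torch.randn(1, 4, 8, 8, device="cuda"))
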
@@ -2362,20 +2362,19 @@ class PipelineTesterMixin:
         output_without_group_offloading = run_forward(pipe)

         pipe = create_pipe()
-        enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
-        output_with_group_offloading1 = run_forward(pipe)
-
-        pipe = create_pipe()
-        enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
-        output_with_group_offloading2 = run_forward(pipe)
+        if offload_type == "block_level":
+            offloading_kwargs = {"offload_type": "block_level", "num_blocks_per_group": 1}
+        else:
+            offloading_kwargs = {"offload_type": "leaf_level"}
+        enable_group_offload(pipe, offloading_kwargs)
+        output_with_group_offloading = run_forward(pipe)

         if torch.is_tensor(output_without_group_offloading):
             output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
-            output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
-            output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()
+            output_with_group_offloading = output_with_group_offloading.detach().cpu().numpy()

-        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
-        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
+        assert np.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4)


 @is_staging_test
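
Note: outside the test harness, the pipeline-level API exercised here would be used roughly as follows. This is a sketch: the model ID and prompt are placeholders, only kwargs visible in this diff are used, and the offload target is left at its default (CPU).

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
    )

    # Block-level grouping, matching the "block_level" parameterization above;
    # pass offload_type="leaf_level" (without num_blocks_per_group) instead
    # for the other case.
    pipe.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_type="block_level",
        num_blocks_per_group=1,
        exclude_modules=["vae", "image_encoder"],
    )

    image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
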