
clean the pipeline level group offloading tests

commit 1d9bf41cf9
parent f5c113e439
Author: sayakpaul
Date: 2025-09-15 13:51:57 +05:30


@@ -32,7 +32,6 @@ from diffusers import (
     UNet2DConditionModel,
     apply_faster_cache,
 )
-from diffusers.hooks import apply_group_offloading
 from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
 from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
 from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
@@ -2244,80 +2243,6 @@ class PipelineTesterMixin:
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)[0]
 
-    @require_torch_accelerator
-    def test_group_offloading_inference(self):
-        if not self.test_group_offloading:
-            return
-
-        def create_pipe():
-            torch.manual_seed(0)
-            components = self.get_dummy_components()
-            pipe = self.pipeline_class(**components)
-            pipe.set_progress_bar_config(disable=None)
-            return pipe
-
-        def enable_group_offload_on_component(pipe, group_offloading_kwargs):
-            # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
-            # tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
-            # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
-            # warmup forward pass (even with dummy small inputs) is recommended.
-            for component_name in [
-                "text_encoder",
-                "text_encoder_2",
-                "text_encoder_3",
-                "transformer",
-                "unet",
-                "controlnet",
-            ]:
-                if not hasattr(pipe, component_name):
-                    continue
-                component = getattr(pipe, component_name)
-                if not getattr(component, "_supports_group_offloading", True):
-                    continue
-                if hasattr(component, "enable_group_offload"):
-                    # For diffusers ModelMixin implementations
-                    component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs)
-                else:
-                    # For other models not part of diffusers
-                    apply_group_offloading(
-                        component, onload_device=torch.device(torch_device), **group_offloading_kwargs
-                    )
-                self.assertTrue(
-                    all(
-                        module._diffusers_hook.get_hook("group_offloading") is not None
-                        for module in component.modules()
-                        if hasattr(module, "_diffusers_hook")
-                    )
-                )
-            for component_name in ["vae", "vqvae", "image_encoder"]:
-                component = getattr(pipe, component_name, None)
-                if isinstance(component, torch.nn.Module):
-                    component.to(torch_device)
-
-        def run_forward(pipe):
-            torch.manual_seed(0)
-            inputs = self.get_dummy_inputs(torch_device)
-            return pipe(**inputs)[0]
-
-        pipe = create_pipe().to(torch_device)
-        output_without_group_offloading = run_forward(pipe)
-
-        pipe = create_pipe()
-        enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
-        output_with_group_offloading1 = run_forward(pipe)
-
-        pipe = create_pipe()
-        enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
-        output_with_group_offloading2 = run_forward(pipe)
-
-        if torch.is_tensor(output_without_group_offloading):
-            output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
-            output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
-            output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()
-
-        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
-        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
-
     def test_torch_dtype_dict(self):
         components = self.get_dummy_components()
         if not components:
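
For reference, the test removed above exercised group offloading one component at a time: diffusers models were configured through `enable_group_offload`, everything else through `apply_group_offloading` from `diffusers.hooks`. A minimal standalone sketch of that pattern follows; the checkpoint id is a placeholder, and only the arguments that actually appear in the removed test (`onload_device`, `offload_type`, `num_blocks_per_group`) are assumed.

import torch

from diffusers import UNet2DConditionModel
from diffusers.hooks import apply_group_offloading

# Placeholder checkpoint; any diffusers model exposing enable_group_offload works the same way.
unet = UNet2DConditionModel.from_pretrained("your/checkpoint", subfolder="unet")

onload_device = torch.device("cuda")  # device that runs the forward pass
# ModelMixin subclasses expose enable_group_offload directly: weights stay offloaded and are
# moved onto the accelerator in groups of `num_blocks_per_group` blocks during the forward pass.
unet.enable_group_offload(onload_device, offload_type="block_level", num_blocks_per_group=1)

# Modules that are not diffusers models (e.g. a transformers text encoder) go through the
# functional helper the removed test imported:
# apply_group_offloading(text_encoder, onload_device=onload_device, offload_type="leaf_level")

As the in-test comment notes, a VAE with tiling enabled should get a warmup forward pass before stream-based offloading is relied on.
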
@@ -2364,7 +2289,7 @@ class PipelineTesterMixin:
         self.assertLess(max_diff, expected_max_difference)
 
     @require_torch_accelerator
-    def test_pipeline_level_group_offloading_sanity_checks(self):
+    def test_group_offloading_sanity_checks(self):
         components = self.get_dummy_components()
         pipe: DiffusionPipeline = self.pipeline_class(**components)
@@ -2395,40 +2320,62 @@ class PipelineTesterMixin:
             self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
 
     @require_torch_accelerator
-    def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
-        components = self.get_dummy_components()
-        pipe: DiffusionPipeline = self.pipeline_class(**components)
+    def test_group_offloading_inference(self):
+        if not self.test_group_offloading:
+            pytest.skip("`test_group_offloading` is disabled hence skipping.")
 
-        for name, component in pipe.components.items():
-            if hasattr(component, "_supports_group_offloading"):
-                if not component._supports_group_offloading:
-                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+        def create_pipe():
+            torch.manual_seed(0)
+            components = self.get_dummy_components()
+            pipe = self.pipeline_class(**components)
+            pipe.set_progress_bar_config(disable=None)
+            return pipe
 
-        # Regular inference.
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        torch.manual_seed(0)
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["generator"] = torch.manual_seed(0)
-        out = pipe(**inputs)[0]
+        def enable_group_offload_on_component(pipe, group_offloading_kwargs):
+            # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
+            # tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
+            # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
+            # warmup forward pass (even with dummy small inputs) is recommended.
+            exclude_modules = []
+            for name, component in pipe.components.items():
+                for name in ["vae", "vqvae", "image_encoder"]:
+                    exclude_modules.append(name)
+            pipe.enable_group_offload(
+                exclude_modules=exclude_modules, onload_device=torch_device, **group_offloading_kwargs
+            )
+            for component_name, component in pipe.components.items():
+                if component_name not in exclude_modules and isinstance(component, torch.nn.Module):
+                    self.assertTrue(
+                        all(
+                            module._diffusers_hook.get_hook("group_offloading") is not None
+                            for module in component.modules()
+                            if hasattr(module, "_diffusers_hook")
+                        )
+                    )
 
-        pipe.to("cpu")
-        del pipe
+        def run_forward(pipe):
+            torch.manual_seed(0)
+            inputs = self.get_dummy_inputs(torch_device)
+            return pipe(**inputs)[0]
 
-        # Inference with offloading
-        pipe: DiffusionPipeline = self.pipeline_class(**components)
-        offload_device = "cpu"
-        pipe.enable_group_offload(
-            onload_device=torch_device,
-            offload_device=offload_device,
-            offload_type="leaf_level",
-        )
-        pipe.set_progress_bar_config(disable=None)
-        inputs["generator"] = torch.manual_seed(0)
-        out_offload = pipe(**inputs)[0]
+        pipe = create_pipe().to(torch_device)
+        output_without_group_offloading = run_forward(pipe)
 
-        max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
-        self.assertLess(max_diff, expected_max_difference)
+        pipe = create_pipe()
+        enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
+        output_with_group_offloading1 = run_forward(pipe)
+
+        pipe = create_pipe()
+        enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
+        output_with_group_offloading2 = run_forward(pipe)
+
+        if torch.is_tensor(output_without_group_offloading):
+            output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
+            output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
+            output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()
+
+        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
+        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
 
 
@is_staging_test
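
The replacement test relies on the pipeline-level entry point instead: a single `pipe.enable_group_offload(...)` call covers every supported component, and individual components can be opted out via `exclude_modules`. A hedged usage sketch outside the test harness, with a placeholder checkpoint id and only arguments that appear in this diff:

import torch

from diffusers import DiffusionPipeline

# Placeholder checkpoint id; any pipeline whose components support group offloading applies.
pipe = DiffusionPipeline.from_pretrained("your/pipeline-checkpoint", torch_dtype=torch.float16)

# One call configures all supported components. The VAE is excluded for the same reason the
# test excludes it: with tiling enabled, stream-based offloading needs a warmup pass to trace
# the layer execution order correctly.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    exclude_modules=["vae"],
)

image = pipe("a prompt", num_inference_steps=2).images[0]  # typical text-to-image call

Compared to the per-component loop in the removed test, this keeps the exclusion list and the offload granularity (`block_level` with `num_blocks_per_group`, or `leaf_level`) in one place.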