diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 4794a166bc..8a476e4287 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn from huggingface_hub import ModelCard, delete_repo from huggingface_hub.utils import is_jinja_available +from parameterized import parameterized from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer import diffusers @@ -2319,8 +2320,9 @@ class PipelineTesterMixin: component_device = next(component.parameters())[0].device self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type) + @parameterized.expand([("block_level"), ("leaf_level")]) @require_torch_accelerator - def test_group_offloading_inference(self): + def test_group_offloading_inference(self, offload_type: str = "block_level"): if not self.test_group_offloading: pytest.skip("`test_group_offloading` is disabled hence skipping.") @@ -2331,26 +2333,24 @@ class PipelineTesterMixin: pipe.set_progress_bar_config(disable=None) return pipe - def enable_group_offload_on_component(pipe, group_offloading_kwargs): + def enable_group_offload(pipe, group_offloading_kwargs): # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If # tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a # warmup forward pass (even with dummy small inputs) is recommended. - exclude_modules = [] - for name, component in pipe.components.items(): - for name in ["vae", "vqvae", "image_encoder"]: - exclude_modules.append(name) + exclude_modules = {"vae", "vqvae", "image_encoder"} + exclude_modules = list(exclude_modules & set(pipe.components.keys())) pipe.enable_group_offload( exclude_modules=exclude_modules, onload_device=torch_device, **group_offloading_kwargs ) for component_name, component in pipe.components.items(): - if component_name not in exclude_modules and isinstance(component, torch.nn.Module): - self.assertTrue( - all( - module._diffusers_hook.get_hook("group_offloading") is not None - for module in component.modules() - if hasattr(module, "_diffusers_hook") - ) + if component_name in exclude_modules: + continue + elif isinstance(component, torch.nn.Module): + assert all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in component.modules() + if hasattr(module, "_diffusers_hook") ) def run_forward(pipe): @@ -2362,20 +2362,19 @@ class PipelineTesterMixin: output_without_group_offloading = run_forward(pipe) pipe = create_pipe() - enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1}) - output_with_group_offloading1 = run_forward(pipe) + if offload_type == "block_level": + offloading_kwargs = {"offload_type": "block_level", "num_blocks_per_group": 1} + else: + offloading_kwargs = {"offload_type": "leaf_level"} + enable_group_offload(pipe, offloading_kwargs) - pipe = create_pipe() - enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"}) - output_with_group_offloading2 = run_forward(pipe) + output_with_group_offloading = run_forward(pipe) if torch.is_tensor(output_without_group_offloading): output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy() - output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy() - output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy() + output_with_group_offloading = output_with_group_offloading.detach().cpu().numpy() - self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4)) - self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4)) + assert np.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4) @is_staging_test