From e141f5cfd0bccf692920bef9ff54642864432db8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Sep 2025 11:49:54 +0530 Subject: [PATCH] add tests --- src/diffusers/pipelines/pipeline_utils.py | 6 +- tests/pipelines/test_pipelines_common.py | 68 +++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a6e87ca85f..04e33e655a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1425,10 +1425,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): """ from ..hooks import apply_group_offloading - if exclude_modules is not None and isinstance(exclude_modules, str): + if isinstance(exclude_modules, str): exclude_modules = [exclude_modules] + elif exclude_modules is None: + exclude_modules = [] - unknown = set(exclude_modules) - set(self.components.keys()) + unknown = set(exclude_modules) - self.components.keys() if unknown: logger.info( f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected." diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index dcef33897e..db8209835b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Union import numpy as np import PIL.Image +import pytest import torch import torch.nn as nn from huggingface_hub import ModelCard, delete_repo @@ -2362,6 +2363,73 @@ class PipelineTesterMixin: max_diff = np.abs(to_np(out) - to_np(loaded_out)).max() self.assertLess(max_diff, expected_max_difference) + @require_torch_accelerator + def test_pipeline_level_group_offloading_sanity_checks(self): + components = self.get_dummy_components() + pipe: DiffusionPipeline = self.pipeline_class(**components) + + for name, component in pipe.components.items(): + if hasattr(component, "_supports_group_offloading"): + if not component._supports_group_offloading: + pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.") + + module_names = sorted( + [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)] + ) + exclude_module_name = module_names[0] + offload_device = "cpu" + pipe.enable_group_offload( + onload_device=torch_device, + offload_device=offload_device, + offload_type="leaf_level", + exclude_modules=exclude_module_name, + ) + excluded_module = getattr(pipe, exclude_module_name) + self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type) + + for name, component in pipe.components.items(): + if name not in [exclude_module_name] and isinstance(component, torch.nn.Module): + # `component.device` prints the `onload_device` type. We should probably override the + # `device` property in `ModelMixin`. + component_device = next(component.parameters())[0].device + self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type) + + @require_torch_accelerator + def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + pipe: DiffusionPipeline = self.pipeline_class(**components) + + for name, component in pipe.components.items(): + if hasattr(component, "_supports_group_offloading"): + if not component._supports_group_offloading: + pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.") + + # Regular inference. + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + torch.manual_seed(0) + inputs = self.get_dummy_inputs(torch_device) + inputs["generator"] = torch.manual_seed(0) + out = pipe(**inputs)[0] + + pipe.to("cpu") + del pipe + + # Inference with offloading + pipe: DiffusionPipeline = self.pipeline_class(**components) + offload_device = "cpu" + pipe.enable_group_offload( + onload_device=torch_device, + offload_device=offload_device, + offload_type="leaf_level", + ) + pipe.set_progress_bar_config(disable=None) + inputs["generator"] = torch.manual_seed(0) + out_offload = pipe(**inputs)[0] + + max_diff = np.abs(to_np(out) - to_np(out_offload)).max() + self.assertLess(max_diff, expected_max_difference) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase):