mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add tests
This commit is contained in:
@@ -1425,10 +1425,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
from ..hooks import apply_group_offloading
|
||||
|
||||
if exclude_modules is not None and isinstance(exclude_modules, str):
|
||||
if isinstance(exclude_modules, str):
|
||||
exclude_modules = [exclude_modules]
|
||||
elif exclude_modules is None:
|
||||
exclude_modules = []
|
||||
|
||||
unknown = set(exclude_modules) - set(self.components.keys())
|
||||
unknown = set(exclude_modules) - self.components.keys()
|
||||
if unknown:
|
||||
logger.info(
|
||||
f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import ModelCard, delete_repo
|
||||
@@ -2362,6 +2363,73 @@ class PipelineTesterMixin:
|
||||
max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_pipeline_level_group_offloading_sanity_checks(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: DiffusionPipeline = self.pipeline_class(**components)
|
||||
|
||||
for name, component in pipe.components.items():
|
||||
if hasattr(component, "_supports_group_offloading"):
|
||||
if not component._supports_group_offloading:
|
||||
pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
|
||||
|
||||
module_names = sorted(
|
||||
[name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
|
||||
)
|
||||
exclude_module_name = module_names[0]
|
||||
offload_device = "cpu"
|
||||
pipe.enable_group_offload(
|
||||
onload_device=torch_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
exclude_modules=exclude_module_name,
|
||||
)
|
||||
excluded_module = getattr(pipe, exclude_module_name)
|
||||
self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)
|
||||
|
||||
for name, component in pipe.components.items():
|
||||
if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
|
||||
# `component.device` prints the `onload_device` type. We should probably override the
|
||||
# `device` property in `ModelMixin`.
|
||||
component_device = next(component.parameters())[0].device
|
||||
self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
|
||||
components = self.get_dummy_components()
|
||||
pipe: DiffusionPipeline = self.pipeline_class(**components)
|
||||
|
||||
for name, component in pipe.components.items():
|
||||
if hasattr(component, "_supports_group_offloading"):
|
||||
if not component._supports_group_offloading:
|
||||
pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
|
||||
|
||||
# Regular inference.
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
torch.manual_seed(0)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["generator"] = torch.manual_seed(0)
|
||||
out = pipe(**inputs)[0]
|
||||
|
||||
pipe.to("cpu")
|
||||
del pipe
|
||||
|
||||
# Inference with offloading
|
||||
pipe: DiffusionPipeline = self.pipeline_class(**components)
|
||||
offload_device = "cpu"
|
||||
pipe.enable_group_offload(
|
||||
onload_device=torch_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs["generator"] = torch.manual_seed(0)
|
||||
out_offload = pipe(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class PipelinePushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user