mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add tests
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Callable, Union
|
||||
|
||||
@@ -8,9 +10,16 @@ import torch
|
||||
import diffusers
|
||||
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
|
||||
from ..testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
@@ -335,3 +344,53 @@ class ModularGuiderTesterMixin:
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = torch.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def test_custom_requirements_save_load(self):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir)
|
||||
config_path = os.path.join(tmpdir, "modular_config.json")
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_warnings(self):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(tmpdir)
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
Reference in New Issue
Block a user