1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add test for solver_type

This test currently fails in main. When switching from DEIS to UniPC,
solver_type is "logrho" (the default value from DEIS), which gets
translated to "bh1" by UniPC. This is different to the default value for
UniPC: "bh2". This is where the translation happens: 36d22d0709/src/diffusers/schedulers/scheduling_unipc_multistep.py (L171)
This commit is contained in:
Pedro Cuenca
2023-07-04 13:30:16 +02:00
parent c6fe8b0b66
commit 981cf960db

View File

@@ -25,11 +25,13 @@ import torch
import diffusers
from diffusers import (
DDIMScheduler,
DEISMultistepScheduler,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
IPNDMScheduler,
LMSDiscreteScheduler,
UniPCMultistepScheduler,
VQDiffusionScheduler,
logging,
)
@@ -229,6 +231,19 @@ class SchedulerBaseTests(unittest.TestCase):
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.timestep_spacing == "trailing"
def test_default_solver_type_after_switch(self):
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16
)
assert pipe.scheduler.__class__ == DDIMScheduler
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.solver_type == "logrho"
# Switch to UniPC, verify the solver is the default
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
assert pipe.scheduler.config.solver_type == "bh2"
class SchedulerCommonTest(unittest.TestCase):
scheduler_classes = ()