1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[tests] use parent class for monkey patching to not break other tests (#4088)

* [tests] use parent class for monkey patching to not break other tests

* fix
This commit is contained in:
Patrick von Platen
2023-07-14 13:44:44 +02:00
committed by GitHub
parent 692b7a907d
commit ee2f2775b2

View File

@@ -230,10 +230,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
pipe_2.unet.set_default_attn_processor()
def assert_run_mixture(num_steps, split, scheduler_cls):
def assert_run_mixture(num_steps, split, scheduler_cls_orig):
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = num_steps
class scheduler_cls(scheduler_cls_orig):
pass
pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
@@ -287,10 +290,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipe_3 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
pipe_3.unet.set_default_attn_processor()
def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls):
def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls_orig):
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = num_steps
class scheduler_cls(scheduler_cls_orig):
pass
pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
pipe_3.scheduler = scheduler_cls.from_config(pipe_3.scheduler.config)