From ee2f2775b25ebbbff7314b72c64ed07461fdd4b6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 14 Jul 2023 13:44:44 +0200 Subject: [PATCH] [tests] use parent class for monkey patching to not break other tests (#4088) * [tests] use parent class for monkey patching to not break other tests * fix --- .../stable_diffusion_xl/test_stable_diffusion_xl.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 3f4dd19c9b..947d57a7be 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -230,10 +230,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device) pipe_2.unet.set_default_attn_processor() - def assert_run_mixture(num_steps, split, scheduler_cls): + def assert_run_mixture(num_steps, split, scheduler_cls_orig): inputs = self.get_dummy_inputs(torch_device) inputs["num_inference_steps"] = num_steps + class scheduler_cls(scheduler_cls_orig): + pass + pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config) pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config) @@ -287,10 +290,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest pipe_3 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device) pipe_3.unet.set_default_attn_processor() - def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls): + def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls_orig): inputs = self.get_dummy_inputs(torch_device) inputs["num_inference_steps"] = num_steps + class scheduler_cls(scheduler_cls_orig): + pass + pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config) pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config) pipe_3.scheduler = scheduler_cls.from_config(pipe_3.scheduler.config)