diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 3f4dd19c9b..947d57a7be 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -230,10 +230,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device) pipe_2.unet.set_default_attn_processor() - def assert_run_mixture(num_steps, split, scheduler_cls): + def assert_run_mixture(num_steps, split, scheduler_cls_orig): inputs = self.get_dummy_inputs(torch_device) inputs["num_inference_steps"] = num_steps + class scheduler_cls(scheduler_cls_orig): + pass + pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config) pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config) @@ -287,10 +290,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest pipe_3 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device) pipe_3.unet.set_default_attn_processor() - def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls): + def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls_orig): inputs = self.get_dummy_inputs(torch_device) inputs["num_inference_steps"] = num_steps + class scheduler_cls(scheduler_cls_orig): + pass + pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config) pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config) pipe_3.scheduler = scheduler_cls.from_config(pipe_3.scheduler.config)