From ca1e40726e0e6b770e62ba98e4c6f74ea6bf9647 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 21 Mar 2023 10:18:44 -0700 Subject: [PATCH] stable diffusion depth batching fix (#2757) --- .../stable_diffusion/pipeline_stable_diffusion_depth2img.py | 3 ++- .../stable_diffusion_2/test_stable_diffusion_depth.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 9087064ae0..b66cfe9b43 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -474,7 +474,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if depth_map.shape[0] < batch_size: - depth_map = depth_map.repeat(batch_size, 1, 1, 1) + repeat_by = batch_size // depth_map.shape[0] + depth_map = depth_map.repeat(repeat_by, 1, 1, 1) depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map return depth_map diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 110dbbd7f8..12e7113399 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te test_save_load_optional_components = False params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} - batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS def get_dummy_components(self): torch.manual_seed(0)