1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

stable diffusion depth batching fix (#2757)

This commit is contained in:
Will Berman
2023-03-21 10:18:44 -07:00
committed by GitHub
parent b33bd91fae
commit ca1e40726e
2 changed files with 3 additions and 2 deletions

View File

@@ -474,7 +474,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if depth_map.shape[0] < batch_size:
depth_map = depth_map.repeat(batch_size, 1, 1, 1)
repeat_by = batch_size // depth_map.shape[0]
depth_map = depth_map.repeat(repeat_by, 1, 1, 1)
depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
return depth_map

View File

@@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
test_save_load_optional_components = False
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)