mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
stable diffusion depth batching fix (#2757)
This commit is contained in:
@@ -474,7 +474,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if depth_map.shape[0] < batch_size:
|
||||
depth_map = depth_map.repeat(batch_size, 1, 1, 1)
|
||||
repeat_by = batch_size // depth_map.shape[0]
|
||||
depth_map = depth_map.repeat(repeat_by, 1, 1, 1)
|
||||
|
||||
depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
|
||||
return depth_map
|
||||
|
||||
@@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
test_save_load_optional_components = False
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user