From a38dd795120e1884e3396d41bf44e44fd9b1eba0 Mon Sep 17 00:00:00 2001
From: Yushu
Date: Mon, 29 Apr 2024 03:54:16 -0700
Subject: [PATCH] [Pipeline] Fix error of SVD pipeline when num_videos_per_prompt > 1 (#7786)

Swap the order of the do_classifier_free_guidance concat and the per-prompt repeat.

Co-authored-by: Sayak Paul
Co-authored-by: Dhruv Nair
---
 .../pipeline_stable_video_diffusion.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index 070183b924..da6832cebd 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -199,6 +199,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         image = image.to(device=device)
         image_latents = self.vae.encode(image).latent_dist.mode()
 
+        # duplicate image_latents for each generation per prompt, using mps friendly method
+        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
+
         if do_classifier_free_guidance:
             negative_image_latents = torch.zeros_like(image_latents)
 
@@ -207,9 +210,6 @@
             # to avoid doing two forward passes
             image_latents = torch.cat([negative_image_latents, image_latents])
 
-        # duplicate image_latents for each generation per prompt, using mps friendly method
-        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
-
         return image_latents
 
     def _get_add_time_ids(
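
Why the swap matters, shown as a minimal standalone sketch rather than pipeline code (the toy shapes, the all-ones "positive" latent, and the split-in-half convention for classifier-free guidance are assumptions for illustration): when torch.cat runs before repeat, the repeat tiles the whole [negative, positive] pair and the batch comes out interleaved; repeating first and concatenating afterwards keeps every unconditional latent in the front half, which is the layout a later split of the guidance batch into two halves expects.

import torch

# Toy setup: one input image, num_videos_per_prompt = 2, and a tiny 1x1x2x2
# latent where the conditional ("positive") latent is all ones and the
# unconditional ("negative") latent is all zeros.
num_videos_per_prompt = 2
image_latents = torch.ones(1, 1, 2, 2)
negative = torch.zeros_like(image_latents)

# Old order (before the patch): concat for classifier-free guidance, then
# repeat. repeat() tiles the whole [neg, pos] pair along the batch dimension,
# giving an interleaved batch: [neg, pos, neg, pos].
old = torch.cat([negative, image_latents]).repeat(num_videos_per_prompt, 1, 1, 1)

# New order (after the patch): repeat per generation first, then concat.
# All unconditional latents come first, then all conditional ones:
# [neg, neg, pos, pos].
rep = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
new = torch.cat([torch.zeros_like(rep), rep])

print(old.flatten(1).mean(dim=1))  # tensor([0., 1., 0., 1.])  interleaved
print(new.flatten(1).mean(dim=1))  # tensor([0., 0., 1., 1.])  unconditional half first

With the repeat done first, image_latents presumably ends up in the same [all-unconditional, all-conditional] ordering as the other inputs that are duplicated per prompt before the guidance concat, so generating more than one video per prompt no longer mixes the two halves.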