diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index 070183b924..da6832cebd 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -199,6 +199,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         image = image.to(device=device)
         image_latents = self.vae.encode(image).latent_dist.mode()
 
+        # duplicate image_latents for each generation per prompt, using mps friendly method
+        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
+
         if do_classifier_free_guidance:
             negative_image_latents = torch.zeros_like(image_latents)
 
@@ -207,9 +210,6 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
             # to avoid doing two forward passes
             image_latents = torch.cat([negative_image_latents, image_latents])
 
-        # duplicate image_latents for each generation per prompt, using mps friendly method
-        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
-
         return image_latents
 
     def _get_add_time_ids(
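
The patch moves the `repeat` of `image_latents` so it happens before the classifier-free-guidance concatenation rather than after it. The sketch below is not part of the patch; it is a minimal, self-contained illustration of the batch-ordering difference, with made-up tensor shapes and assuming the downstream code splits the batch in half into (unconditional, conditional), as is typical for CFG.

```python
import torch

num_videos_per_prompt = 2
pos = torch.ones(1, 4, 8, 8)       # stand-in for the encoded image latents
neg = torch.zeros_like(pos)        # stand-in for the negative (zero) latents

# Old order: cat first, then repeat -> the [neg, pos] pair is tiled,
# giving an interleaved batch [neg, pos, neg, pos].
old = torch.cat([neg, pos]).repeat(num_videos_per_prompt, 1, 1, 1)

# New order: repeat first, then cat -> a grouped batch [neg, neg, pos, pos].
pos_rep = pos.repeat(num_videos_per_prompt, 1, 1, 1)
new = torch.cat([torch.zeros_like(pos_rep), pos_rep])

# A chunk(2) split into (unconditional, conditional) only matches the grouped layout.
uncond, cond = new.chunk(2)
assert uncond.abs().sum() == 0 and cond.abs().sum() > 0

old_uncond, old_cond = old.chunk(2)
assert old_uncond.abs().sum() > 0  # interleaved layout mixes neg and pos in each half
```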