From 5dc34713800306f498f8ecff6158fcd3668032be Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 11 Jan 2024 20:08:18 +0530 Subject: [PATCH] [SVD] support generators that are created on GPU (#6484) * debug generator * fix? * fix? * fix * remove print. * revert none check --- .../pipeline_stable_video_diffusion.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index fa96f41cd8..e5360d37c6 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -429,15 +429,20 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): fps = fps - 1 # 4. Encode input image using VAE - image = self.image_processor.preprocess(image, height=height, width=width) - noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) + image = self.image_processor.preprocess(image, height=height, width=width).to(device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) image = image + noise_aug_strength * noise needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.vae.to(dtype=torch.float32) - image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance) + image_latents = self._encode_vae_image( + image, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) image_latents = image_latents.to(image_embeddings.dtype) # cast back to fp16 if needed