mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[SVD] support generators that are created on GPU (#6484)
* debug generator * fix? * fix? * fix * remove print. * revert none check
This commit is contained in:
@@ -429,15 +429,20 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
fps = fps - 1
|
||||
|
||||
# 4. Encode input image using VAE
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
|
||||
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
|
||||
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
|
||||
image = image + noise_aug_strength * noise
|
||||
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
|
||||
image_latents = self._encode_vae_image(
|
||||
image,
|
||||
device=device,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
)
|
||||
image_latents = image_latents.to(image_embeddings.dtype)
|
||||
|
||||
# cast back to fp16 if needed
|
||||
|
||||
Reference in New Issue
Block a user