From e7f3a7378677a8a43cfaf0dd9665e5cfcd22aba5 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 21 Apr 2025 23:48:50 +0530
Subject: [PATCH] Fix Wan I2V prepare_latents dtype (#11371)

update
---
 src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index 86d5496d16..10fa7b55c3 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -409,7 +409,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
                 dim=2,
             )
-        video_condition = video_condition.to(device=device, dtype=dtype)
+        video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
 
         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
@@ -429,6 +429,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
             latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
+        latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std
 
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
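
The pattern the patch implements, sketched standalone below: encode the
conditioning video in the VAE's own dtype (since the VAE may be kept in
float32 while the rest of the pipeline runs in bfloat16), then cast the
resulting latents back to the pipeline dtype before normalization. ToyVAE
here is a hypothetical stand-in for the Wan VAE, not the diffusers API;
only the ordering of the casts mirrors the patch.

import torch

class ToyVAE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 1x1x1 conv standing in for the real video encoder.
        self.proj = torch.nn.Conv3d(3, 4, kernel_size=1)

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def encode(self, x):
        return self.proj(x)

vae = ToyVAE()             # VAE weights stay in float32 for precision
dtype = torch.bfloat16     # pipeline (transformer) compute dtype

video_condition = torch.randn(1, 3, 5, 8, 8)

# Before the fix, the input was cast to `dtype` (bfloat16), mismatching the
# float32 VAE weights. The fix encodes in the VAE's own dtype instead:
video_condition = video_condition.to(dtype=vae.dtype)
latent_condition = vae.encode(video_condition)

# ...and only then casts the latents to the pipeline dtype, matching the
# `latent_condition.to(dtype)` line the patch adds:
latent_condition = latent_condition.to(dtype)
assert latent_condition.dtype == torch.bfloat16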