1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix Wan I2V prepare_latents dtype (#11371)

update
This commit is contained in:
Aryan
2025-04-21 23:48:50 +05:30
committed by GitHub
parent 7a4a126db8
commit e7f3a73786

View File

@@ -409,7 +409,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
dim=2,
)
-video_condition = video_condition.to(device=device, dtype=dtype)
+video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
@@ -429,6 +429,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+latent_condition = latent_condition.to(dtype)
latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)