From e27142ac644b1ed77d9d60c55432fe74659520db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Wed, 11 Jun 2025 11:49:23 +0300 Subject: [PATCH] [`Wan`] Fix VAE sampling mode in `WanVideoToVideoPipeline` (#11639) * fix: vae sampling mode * fix a typo --- src/diffusers/pipelines/wan/pipeline_wan_video2video.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index 1844f1b49b..a4a10d4655 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -419,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ) if latents is None: - if isinstance(generator, list): - init_latents = [ - retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) - ] - else: - init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video] init_latents = torch.cat(init_latents, dim=0).to(dtype) @@ -441,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): if hasattr(self.scheduler, "add_noise"): latents = self.scheduler.add_noise(init_latents, noise, timestep) else: - latents = self.scheduelr.scale_noise(init_latents, timestep, noise) + latents = self.scheduler.scale_noise(init_latents, timestep, noise) else: latents = latents.to(device)