mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Wan] Fix VAE sampling mode in WanVideoToVideoPipeline (#11639)
* fix: vae sampling mode * fix a typo
This commit is contained in:
@@ -419,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
|
||||
@@ -441,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
if hasattr(self.scheduler, "add_noise"):
|
||||
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
else:
|
||||
latents = self.scheduelr.scale_noise(init_latents, timestep, noise)
|
||||
latents = self.scheduler.scale_noise(init_latents, timestep, noise)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user