mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix Wan I2V Quality (#11087)
* fix_wan_i2v_quality * Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update pipeline_wan_i2v.py --------- Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: hlky <hlky@hlky.ac>
This commit is contained in:
@@ -108,31 +108,16 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor,
|
||||
latents_mean: torch.Tensor,
|
||||
latents_std: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
sample_mode: str = "sample",
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
|
||||
encoder_output.latent_dist.logvar = torch.clamp(
|
||||
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
|
||||
)
|
||||
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
|
||||
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
|
||||
encoder_output.latent_dist.logvar = torch.clamp(
|
||||
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
|
||||
)
|
||||
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
|
||||
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return (encoder_output.latents - latents_mean) * latents_std
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
@@ -412,13 +397,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
if isinstance(generator, list):
|
||||
latent_condition = [
|
||||
retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
|
||||
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
|
||||
]
|
||||
latent_condition = torch.cat(latent_condition)
|
||||
else:
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
|
||||
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
|
||||
Reference in New Issue
Block a user