1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[wan2.2] fix vae patches (#12041)

up
This commit is contained in:
YiYi Xu
2025-07-31 23:43:42 -10:00
committed by GitHub
parent 20e0740b88
commit 58d2b10a2e

View File

@@ -913,38 +913,21 @@ def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
# x shape: [batch_size, channels, height, width]
batch_size, channels, height, width = x.shape
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
# Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
# Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
elif x.dim() == 5:
# x shape: [batch_size, channels, frames, height, width]
batch_size, channels, frames, height, width = x.shape
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
else:
if x.dim() != 5:
raise ValueError(f"Invalid input shape: {x.shape}")
# x shape: [batch_size, channels, frames, height, width]
batch_size, channels, frames, height, width = x.shape
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
return x
@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
# x shape: [b, (c * patch_size * patch_size), h, w]
batch_size, c_patches, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
if x.dim() != 5:
raise ValueError(f"Invalid input shape: {x.shape}")
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
batch_size, c_patches, frames, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
# Reshape to [b, c, patch_size, patch_size, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, height, width)
# Reshape to [b, c, patch_size, patch_size, f, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
# Rearrange to [b, c, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(batch_size, channels, height * patch_size, width * patch_size)
elif x.dim() == 5:
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
batch_size, c_patches, frames, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
# Reshape to [b, c, patch_size, patch_size, f, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
return x
@@ -1044,7 +1016,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
patch_size: Optional[int] = None,
scale_factor_temporal: Optional[int] = 4,
scale_factor_spatial: Optional[int] = 8,
clip_output: bool = True,
) -> None:
super().__init__()
@@ -1244,10 +1215,11 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
if self.config.clip_output:
out = torch.clamp(out, min=-1.0, max=1.0)
if self.config.patch_size is not None:
out = unpatchify(out, patch_size=self.config.patch_size)
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()
if not return_dict:
return (out,)