diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 608de25da5..d84a0861e9 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -913,38 +913,21 @@ def patchify(x, patch_size):
     if patch_size == 1:
         return x
 
-    if x.dim() == 4:
-        # x shape: [batch_size, channels, height, width]
-        batch_size, channels, height, width = x.shape
-
-        # Ensure height and width are divisible by patch_size
-        if height % patch_size != 0 or width % patch_size != 0:
-            raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
-
-        # Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
-        x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
-
-        # Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
-        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
-        x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
-
-    elif x.dim() == 5:
-        # x shape: [batch_size, channels, frames, height, width]
-        batch_size, channels, frames, height, width = x.shape
-
-        # Ensure height and width are divisible by patch_size
-        if height % patch_size != 0 or width % patch_size != 0:
-            raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
-
-        # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
-        x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
-
-        # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
-        x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
-        x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
-
-    else:
+    if x.dim() != 5:
         raise ValueError(f"Invalid input shape: {x.shape}")
 
+    # x shape: [batch_size, channels, frames, height, width]
+    batch_size, channels, frames, height, width = x.shape
+
+    # Ensure height and width are divisible by patch_size
+    if height % patch_size != 0 or width % patch_size != 0:
+        raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+
+    # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
+    x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
+
+    # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
+    x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
+    x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
 
     return x
 
@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
     if patch_size == 1:
         return x
 
-    if x.dim() == 4:
-        # x shape: [b, (c * patch_size * patch_size), h, w]
-        batch_size, c_patches, height, width = x.shape
-        channels = c_patches // (patch_size * patch_size)
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+
+    # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
+    batch_size, c_patches, frames, height, width = x.shape
+    channels = c_patches // (patch_size * patch_size)
 
-        # Reshape to [b, c, patch_size, patch_size, h, w]
-        x = x.view(batch_size, channels, patch_size, patch_size, height, width)
+    # Reshape to [b, c, patch_size, patch_size, f, h, w]
+    x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
 
-        # Rearrange to [b, c, h * patch_size, w * patch_size]
-        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
-        x = x.view(batch_size, channels, height * patch_size, width * patch_size)
-
-    elif x.dim() == 5:
-        # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
-        batch_size, c_patches, frames, height, width = x.shape
-        channels = c_patches // (patch_size * patch_size)
-
-        # Reshape to [b, c, patch_size, patch_size, f, h, w]
-        x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
-
-        # Rearrange to [b, c, f, h * patch_size, w * patch_size]
-        x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
-        x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
+    # Rearrange to [b, c, f, h * patch_size, w * patch_size]
+    x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
+    x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
 
     return x
 
@@ -1044,7 +1016,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         patch_size: Optional[int] = None,
         scale_factor_temporal: Optional[int] = 4,
         scale_factor_spatial: Optional[int] = 8,
-        clip_output: bool = True,
     ) -> None:
         super().__init__()
 
@@ -1244,10 +1215,11 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)
 
-        if self.config.clip_output:
-            out = torch.clamp(out, min=-1.0, max=1.0)
         if self.config.patch_size is not None:
             out = unpatchify(out, patch_size=self.config.patch_size)
+
+        out = torch.clamp(out, min=-1.0, max=1.0)
+
        self.clear_cache()
         if not return_dict:
             return (out,)
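
Note (not part of the patch): a minimal round-trip sketch for the 5D-only patchify/unpatchify helpers touched above. It assumes a source checkout of diffusers that contains this module so the two helpers are importable; the tensor shape (1, 12, 3, 32, 32) and patch_size=2 are arbitrary illustration values, not taken from any Wan checkpoint config. Because both helpers are pure view/permute rearrangements, the round trip is exact.

import torch

from diffusers.models.autoencoders.autoencoder_kl_wan import patchify, unpatchify

# Arbitrary example latent: [batch, channels, frames, height, width]
latent = torch.randn(1, 12, 3, 32, 32)

# patchify folds each 2x2 spatial patch into the channel dimension
packed = patchify(latent, patch_size=2)
assert packed.shape == (1, 12 * 2 * 2, 3, 16, 16)

# unpatchify undoes the packing exactly (no arithmetic, only rearrangement)
restored = unpatchify(packed, patch_size=2)
assert restored.shape == latent.shape
assert torch.equal(restored, latent)

# After this change, 4D (image-only) inputs are rejected instead of handled.
try:
    patchify(torch.randn(1, 12, 32, 32), patch_size=2)
except ValueError as exc:
    print(f"4D input rejected as expected: {exc}")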