1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix crash if tiling mode is enabled (#12521)

* fix crash in tiling mode is enabled

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fmt

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Wang, Yi
2025-10-28 11:59:20 +08:00
committed by GitHub
parent ecfbc8f952
commit dc622a95d0

View File

@@ -1337,9 +1337,18 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
tile_sample_stride_height = self.tile_sample_stride_height
tile_sample_stride_width = self.tile_sample_stride_width
if self.config.patch_size is not None:
sample_height = sample_height // self.config.patch_size
sample_width = sample_width // self.config.patch_size
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
else:
blend_height = self.tile_sample_min_height - tile_sample_stride_height
blend_width = self.tile_sample_min_width - tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
@@ -1353,7 +1362,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
decoded = self.decoder(
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
@@ -1369,11 +1380,15 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if self.config.patch_size is not None:
dec = unpatchify(dec, patch_size=self.config.patch_size)
dec = torch.clamp(dec, min=-1.0, max=1.0)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)