[bug] Precedence of operations in VAE should be slicing -> tiling (#9342)
* bugfix: precedence of operations should be slicing -> tiling
* fix typo
* fix another typo
* deprecate current implementation of tiled_encode and use new impl
* Update src/diffusers/models/autoencoders/autoencoder_kl.py
  Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/models/autoencoders/autoencoder_kl.py
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
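For context, the bug bites when batch slicing and tiling are enabled together: `encode` used to check tiling first, so a batch of large inputs went straight to `tiled_encode` and was never sliced. A minimal sketch of the scenario (the model id and input shape are illustrative, not taken from the PR):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    vae.enable_slicing()  # process the batch one sample at a time
    vae.enable_tiling()   # split each large sample into overlapping tiles

    # large enough to trigger tiling, batched enough to trigger slicing
    x = torch.randn(4, 3, 1024, 1024)
    with torch.no_grad():
        posterior = vae.encode(x).latent_dist
    latents = posterior.sample()  # with this fix, each slice is tiled individually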
src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -18,6 +18,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -245,6 +246,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         self.set_attn_processor(processor)
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -261,21 +274,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
         if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
             h = torch.cat(encoded_slices)
         else:
-            h = self.encoder(x)
+            h = self._encode(x)
 
-        if self.quant_conv is not None:
-            moments = self.quant_conv(h)
-        else:
-            moments = h
-
-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = DiagonalGaussianDistribution(h)
 
         if not return_dict:
             return (posterior,)
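The ordering after this hunk, paraphrased as a standalone sketch (illustrative, not code from the diff): `encode` slices the batch first, and `_encode` then decides per slice whether to tile.

    import torch

    def encode_slicing_then_tiling(vae, x: torch.Tensor) -> torch.Tensor:
        # Paraphrase of the fixed precedence: slice the batch first, then let
        # _encode tile each slice that exceeds tile_sample_min_size.
        if vae.use_slicing and x.shape[0] > 1:
            return torch.cat([vae._encode(s) for s in x.split(1)])
        return vae._encode(x)

Previously, `tiled_encode` ran on the whole batch, so the `use_slicing` flag was silently ignored whenever tiling kicked in.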
@@ -337,6 +342,54 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded images.
+        """
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
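To make the tile geometry above concrete, assume the common SD VAE defaults (an assumption, not stated in this diff): `tile_sample_min_size = 512`, `tile_latent_min_size = 64` (a 512-pixel tile maps to 64 latents under the 8x VAE downscale), and `tile_overlap_factor = 0.25`. The constants then work out as:

    tile_sample_min_size = 512   # assumed default
    tile_latent_min_size = 64    # assumed default: 512 / 8
    tile_overlap_factor = 0.25   # assumed default

    overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384, stride between tiles in pixel space
    blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16, latent rows/cols blended across seams
    row_limit = tile_latent_min_size - blend_extent                       # 48, latent rows/cols kept per tile
    assert (overlap_size, blend_extent, row_limit) == (384, 16, 48)

So adjacent tiles overlap by 512 - 384 = 128 pixels, and each tile contributes a 48x48 latent patch after blending.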
@@ -356,6 +409,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                 `tuple` is returned.
         """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
         row_limit = self.tile_latent_min_size - blend_extent
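Per the deprecation message, once `tiled_encode` adopts the `_tiled_encode` implementation, callers will build the posterior themselves. A forward-looking sketch of that migration (the import path is an assumption based on the current module layout):

    import torch
    from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

    def tiled_encode_migrated(vae, x: torch.Tensor) -> DiagonalGaussianDistribution:
        # _tiled_encode returns the raw moments (already passed through
        # quant_conv when the config uses one); wrap them manually.
        moments = vae._tiled_encode(x)
        return DiagonalGaussianDistribution(moments)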