
Support tiled encode/decode for AutoencoderTiny (#4627)

* Impl tae slicing and tiling

* add tae tiling test

* add parameterized test

* formatted code

* fix failed test

* style docs

Author: Isotr0py
Date: 2023-08-18 11:42:55 +08:00 (committed by GitHub)
Parent: a10107f92b
Commit: 67ea2b7afa
2 changed files with 168 additions and 2 deletions
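
A minimal usage sketch of the new switches (the madebyollin/taesd checkpoint and the 1024x1024 dummy input are illustrative assumptions, not part of this commit):

import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
vae.enable_tiling()   # split inputs larger than tile_sample_min_size (512 px) into overlapping tiles
vae.enable_slicing()  # additionally process batch elements one at a time

image = torch.rand(1, 3, 1024, 1024)  # dummy image batch; with the default config spatial_scale_factor == 8
with torch.no_grad():
    latents = vae.encode(image).latents          # -> (1, 4, 128, 128)
    reconstruction = vae.decode(latents).sample  # -> (1, 3, 1024, 1024)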


@@ -137,6 +137,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
self.latent_shift = latent_shift
self.scaling_factor = scaling_factor
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.spatial_scale_factor = 2**out_channels
self.tile_overlap_factor = 0.125
self.tile_sample_min_size = 512
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (EncoderTiny, DecoderTiny)):
module.gradient_checkpointing = value
@@ -149,11 +158,147 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE encoding and decoding. When this option is enabled, the VAE will split the input tensor into
tiles to compute encoding and decoding in several steps. This is useful to save a large amount of memory and to
allow processing of larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
plain `tuple` is returned.
"""
# scale of encoder output relative to input
sf = self.spatial_scale_factor
tile_size = self.tile_sample_min_size
# number of pixels to blend and to traverse between tiles
blend_size = int(tile_size * self.tile_overlap_factor)
traverse_size = tile_size - blend_size
# tiles index (up/left)
ti = range(0, x.shape[-2], traverse_size)
tj = range(0, x.shape[-1], traverse_size)
# mask for blending
blend_masks = torch.stack(
torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
)
blend_masks = blend_masks.clamp(0, 1).to(x.device)
# output array
out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
for i in ti:
for j in tj:
tile_in = x[..., i : i + tile_size, j : j + tile_size]
# tile result
tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
tile = self.encoder(tile_in)
h, w = tile.shape[-2], tile.shape[-1]
# blend tile result into output
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
blend_mask = blend_mask_i * blend_mask_j
tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
return out
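# Worked example of the encode-side arithmetic above (illustrative numbers, not part of the commit):
# with the defaults tile_sample_min_size = 512 and tile_overlap_factor = 0.125, blend_size is 64 input
# pixels and traverse_size is 448, so adjacent tiles overlap by 64 pixels (8 latent pixels at sf = 8).
# blend_masks ramps linearly from 0 to 1 over the first blend_size // sf = 8 latent rows/columns and is
# clamped to 1 beyond that, so each overlapping region is a weighted average of the freshly encoded tile
# and whatever earlier tiles already wrote into `out`.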
def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
# scale of decoder output relative to input
sf = self.spatial_scale_factor
tile_size = self.tile_latent_min_size
# number of pixels to blend and to traverse between tiles
blend_size = int(tile_size * self.tile_overlap_factor)
traverse_size = tile_size - blend_size
# tiles index (up/left)
ti = range(0, x.shape[-2], traverse_size)
tj = range(0, x.shape[-1], traverse_size)
# mask for blending
blend_masks = torch.stack(
torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
)
blend_masks = blend_masks.clamp(0, 1).to(x.device)
# output array
out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
for i in ti:
for j in tj:
tile_in = x[..., i : i + tile_size, j : j + tile_size]
# tile result
tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
tile = self.decoder(tile_in)
h, w = tile.shape[-2], tile.shape[-1]
# blend tile result into output
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
return out
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
if self.use_slicing and x.shape[0] > 1:
    output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)]
    output = torch.cat(output)
else:
    output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
if not return_dict:
return (output,)
@@ -162,7 +307,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
@apply_forward_hook
def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
if self.use_slicing and x.shape[0] > 1:
    output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)]
    output = torch.cat(output)
else:
    output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
# Refer to the following discussion to know why this is needed.
# https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
output = output.mul_(2).sub_(1)
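
A quick way to sanity-check the spatial bookkeeping in the tiled paths is to recompute the output shapes by hand. The sketch below is not part of the diffusers API (the helper name is ours); it assumes the default spatial_scale_factor of 8 and reproduces the shapes asserted by the tiling test added below:

def tiled_decode_shape(latent_shape, sf=8):
    # the tiled decoder writes into a zero buffer of shape (N, 3, H * sf, W * sf),
    # so the decoded image is exactly sf times larger in each spatial dimension
    n, _, h, w = latent_shape
    return (n, 3, h * sf, w * sf)

assert tiled_decode_shape((1, 4, 73, 97)) == (1, 3, 584, 776)
assert tiled_decode_shape((1, 4, 49, 65)) == (1, 3, 392, 520)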


@@ -285,6 +285,23 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
model.to(torch_device).eval()
return model
@parameterized.expand(
[
[(1, 4, 73, 97), (1, 3, 584, 776)],
[(1, 4, 97, 73), (1, 3, 776, 584)],
[(1, 4, 49, 65), (1, 3, 392, 520)],
[(1, 4, 65, 49), (1, 3, 520, 392)],
[(1, 4, 49, 49), (1, 3, 392, 392)],
]
)
def test_tae_tiling(self, in_shape, out_shape):
model = self.get_sd_vae_model()
model.enable_tiling()
with torch.no_grad():
zeros = torch.zeros(in_shape).to(torch_device)
dec = model.decode(zeros).sample
assert dec.shape == out_shape
def test_stable_diffusion(self):
model = self.get_sd_vae_model()
image = self.get_sd_image(seed=33)