From 67ea2b7afaaa6f20afdc38183e2147123330b74f Mon Sep 17 00:00:00 2001
From: Isotr0py <41363108+Isotr0py@users.noreply.github.com>
Date: Fri, 18 Aug 2023 11:42:55 +0800
Subject: [PATCH] Support tiled encode/decode for `AutoencoderTiny` (#4627)

* Impl tae slicing and tiling

* add tae tiling test

* add parameterized test

* formatted code

* fix failed test

* style docs
---
 src/diffusers/models/autoencoder_tiny.py | 153 ++++++++++++++++++++++-
 tests/models/test_models_vae.py          |  17 +++
 2 files changed, 168 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py
index f6075ef4ce..fceaf0f130 100644
--- a/src/diffusers/models/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoder_tiny.py
@@ -137,6 +137,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self.latent_shift = latent_shift
         self.scaling_factor = scaling_factor
 
+        self.use_slicing = False
+        self.use_tiling = False
+
+        # only relevant if vae tiling is enabled
+        self.spatial_scale_factor = 2**out_channels
+        self.tile_overlap_factor = 0.125
+        self.tile_sample_min_size = 512
+        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
@@ -149,11 +158,147 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
 
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
+    def enable_tiling(self, use_tiling: bool = True):
+        r"""
+        Enable tiled VAE encoding and decoding. When this option is enabled, the VAE will split the input tensor
+        into tiles to compute encoding and decoding in several steps. This saves a large amount of memory and
+        allows processing larger images.
+        """
+        self.use_tiling = use_tiling
+
+    def disable_tiling(self):
+        r"""
+        Disable tiled VAE encoding and decoding. If `enable_tiling` was previously enabled, this method will go
+        back to computing encoding and decoding in one step.
+        """
+        self.enable_tiling(False)
+
+    def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When tiling is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+
+        Returns:
+            `torch.FloatTensor`: Encoded batch of images.
+        """
+ """ + # scale of encoder output relative to input + sf = self.spatial_scale_factor + tile_size = self.tile_sample_min_size + + # number of pixels to blend and to traverse between tile + blend_size = int(tile_size * self.tile_overlap_factor) + traverse_size = tile_size - blend_size + + # tiles index (up/left) + ti = range(0, x.shape[-2], traverse_size) + tj = range(0, x.shape[-1], traverse_size) + + # mask for blending + blend_masks = torch.stack( + torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij") + ) + blend_masks = blend_masks.clamp(0, 1).to(x.device) + + # output array + out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device) + for i in ti: + for j in tj: + tile_in = x[..., i : i + tile_size, j : j + tile_size] + # tile result + tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf] + tile = self.encoder(tile_in) + h, w = tile.shape[-2], tile.shape[-1] + # blend tile result into output + blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] + blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] + blend_mask = blend_mask_i * blend_mask_j + tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w] + tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) + return out + + def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + # scale of decoder output relative to input + sf = self.spatial_scale_factor + tile_size = self.tile_latent_min_size + + # number of pixels to blend and to traverse between tiles + blend_size = int(tile_size * self.tile_overlap_factor) + traverse_size = tile_size - blend_size + + # tiles index (up/left) + ti = range(0, x.shape[-2], traverse_size) + tj = range(0, x.shape[-1], traverse_size) + + # mask for blending + blend_masks = torch.stack( + torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij") + ) + blend_masks = blend_masks.clamp(0, 1).to(x.device) + + # output array + out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device) + for i in ti: + for j in tj: + tile_in = x[..., i : i + tile_size, j : j + tile_size] + # tile result + tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf] + tile = self.decoder(tile_in) + h, w = tile.shape[-2], tile.shape[-1] + # blend tile result into output + blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] + blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] + blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w] + tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) + return out + @apply_forward_hook def encode( self, x: torch.FloatTensor, return_dict: bool = True ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]: - output = self.encoder(x) + if self.use_slicing and x.shape[0] > 1: + output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)] + output = torch.cat(output) + else: + output = self._tiled_encode(x) if self.use_tiling else self.encoder(x) if not return_dict: return (output,) @@ -162,7 +307,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): @apply_forward_hook def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: - output = self.decoder(x) + if self.use_slicing and x.shape[0] > 1: + output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)] + output = torch.cat(output) + else: + output = self._tiled_decode(x) if self.use_tiling else self.decoder(x) # Refer to the following discussion to know why this is needed. # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854 output = output.mul_(2).sub_(1) diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 65c7d49f75..13c8652bf0 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -285,6 +285,23 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase): model.to(torch_device).eval() return model + @parameterized.expand( + [ + [(1, 4, 73, 97), (1, 3, 584, 776)], + [(1, 4, 97, 73), (1, 3, 776, 584)], + [(1, 4, 49, 65), (1, 3, 392, 520)], + [(1, 4, 65, 49), (1, 3, 520, 392)], + [(1, 4, 49, 49), (1, 3, 392, 392)], + ] + ) + def test_tae_tiling(self, in_shape, out_shape): + model = self.get_sd_vae_model() + model.enable_tiling() + with torch.no_grad(): + zeros = torch.zeros(in_shape).to(torch_device) + dec = model.decode(zeros).sample + assert dec.shape == out_shape + def test_stable_diffusion(self): model = self.get_sd_vae_model() image = self.get_sd_image(seed=33)