Support tiled encode/decode for AutoencoderTiny (#4627)
* Impl tae slicing and tiling
* add tae tiling test
* add parameterized test
* formatted code
* fix failed test
* style docs
@@ -137,6 +137,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.spatial_scale_factor = 2**out_channels
        self.tile_overlap_factor = 0.125
        self.tile_sample_min_size = 512
        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value
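For the default AutoencoderTiny configuration, where out_channels is assumed to be 3, these attributes work out as in the short sketch below (a worked illustration only, not code from the PR):

    # a minimal sketch of the default tiling parameters, assuming out_channels = 3
    out_channels = 3
    spatial_scale_factor = 2**out_channels                                # 8: one latent pixel spans 8x8 image pixels
    tile_sample_min_size = 512                                            # image-space tile edge in pixels
    tile_latent_min_size = tile_sample_min_size // spatial_scale_factor   # 64 latent pixels per tile edge
    blend_size = int(tile_sample_min_size * 0.125)                        # 64 image pixels of overlap between tiles
    print(spatial_scale_factor, tile_latent_min_size, blend_size)         # 8 64 64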
@@ -149,11 +158,147 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
        """[0, 1] -> raw latents"""
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

    def enable_slicing(self):
        r"""
        Enable sliced VAE encoding and decoding. When this option is enabled, the VAE splits the input tensor into
        slices and computes encoding or decoding in several steps. This is useful to save some memory and allow
        larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE encoding and decoding. If `enable_slicing` was previously invoked, this method goes back
        to computing encoding or decoding in one step.
        """
        self.use_slicing = False
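A minimal usage sketch of the slicing toggle (the checkpoint name and batch size are purely illustrative):

    import torch
    from diffusers import AutoencoderTiny

    vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
    vae.enable_slicing()
    latents = torch.randn(4, 4, 64, 64)        # a batch of 4 latents
    with torch.no_grad():
        images = vae.decode(latents).sample    # each slice of the batch is decoded separately
    vae.disable_slicing()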
    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE encoding and decoding. When this option is enabled, the VAE splits the input tensor into
        tiles and computes encoding and decoding in several steps. This is useful for saving a large amount of memory
        and for processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE encoding and decoding. If `enable_tiling` was previously invoked, this method goes back to
        computing encoding and decoding in one step.
        """
        self.enable_tiling(False)
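A matching sketch of the tiling toggle, again with an illustrative checkpoint and latent size; with tiling enabled, peak memory is bounded by the tile size rather than by the image size:

    import torch
    from diffusers import AutoencoderTiny

    vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
    vae.enable_tiling()
    latents = torch.randn(1, 4, 128, 160)      # larger than the 64x64 latent tile
    with torch.no_grad():
        image = vae.decode(latents).sample     # decoded tile by tile and blended at the seams
    vae.disable_tiling()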
    def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        r"""Encode a batch of images using a tiled encoder.

        When tiling is enabled, the VAE splits the input tensor into tiles and computes encoding in several steps.
        This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the tiles
        overlap and are blended together to form a smooth output.

        Args:
            x (`torch.FloatTensor`): Input batch of images.

        Returns:
            `torch.FloatTensor`: Encoded batch of images.
        """
        # scale of encoder output relative to input
        sf = self.spatial_scale_factor
        tile_size = self.tile_sample_min_size

        # number of pixels to blend and to traverse between tiles
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # tile indices (top/left)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # mask for blending
        blend_masks = torch.stack(
            torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
        )
        blend_masks = blend_masks.clamp(0, 1).to(x.device)

        # output array
        out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
        for i in ti:
            for j in tj:
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # tile result
                tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
                tile = self.encoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # blend tile result into output
                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                blend_mask = blend_mask_i * blend_mask_j
                tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
        return out
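To make the seam blending concrete, here is a standalone sketch of the mask construction used above, with a much smaller tile and blend size chosen purely for illustration:

    import torch

    tile_size, blend_size = 8, 2   # illustrative; the real encoder uses 512 and 64 in image space
    masks = torch.stack(torch.meshgrid([torch.arange(tile_size) / (blend_size - 1)] * 2, indexing="ij"))
    masks = masks.clamp(0, 1)
    # masks[0] ramps from 0 to 1 over the first rows, masks[1] over the first columns; their product
    # fades each new tile in over the strip where it overlaps tiles that were already written.
    print(masks[0, :, 0])   # tensor([0., 1., 1., 1., 1., 1., 1., 1.])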
    def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        r"""Decode a batch of latents using a tiled decoder.

        When tiling is enabled, the VAE splits the input tensor into tiles and computes decoding in several steps.
        This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the tiles
        overlap and are blended together to form a smooth output.

        Args:
            x (`torch.FloatTensor`): Input batch of latents.

        Returns:
            `torch.FloatTensor`: Decoded batch of images.
        """
        # scale of decoder output relative to input
        sf = self.spatial_scale_factor
        tile_size = self.tile_latent_min_size

        # number of pixels to blend and to traverse between tiles
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # tile indices (top/left)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # mask for blending
        blend_masks = torch.stack(
            torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
        )
        blend_masks = blend_masks.clamp(0, 1).to(x.device)

        # output array
        out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
        for i in ti:
            for j in tj:
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # tile result
                tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
                tile = self.decoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # blend tile result into output
                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
        return out
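As a quick sanity check on the decoder's shape bookkeeping, a worked example assuming the default spatial scale factor of 8 (the numbers match the first case of the tiling test added below):

    # illustrative shape arithmetic for the tiled decoder, assuming sf = 8
    sf = 8
    latent_shape = (1, 4, 73, 97)
    out_shape = (latent_shape[0], 3, latent_shape[2] * sf, latent_shape[3] * sf)
    assert out_shape == (1, 3, 584, 776)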
    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
        if self.use_slicing and x.shape[0] > 1:
            output = [
                self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
            ]
            output = torch.cat(output)
        else:
            output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)

        if not return_dict:
            return (output,)
@@ -162,7 +307,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
    @apply_forward_hook
    def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        if self.use_slicing and x.shape[0] > 1:
            output = [
                self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
            ]
            output = torch.cat(output)
        else:
            output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
        # Refer to the following discussion to know why this is needed.
        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
        output = output.mul_(2).sub_(1)
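The final in-place rescale maps the raw decoder output from [0, 1] into the [-1, 1] range the rest of the library expects (see the linked discussion); a one-line illustration:

    import torch

    y = torch.tensor([0.0, 0.5, 1.0])   # raw decoder output range
    print(y.mul(2).sub(1))              # tensor([-1.,  0.,  1.]), the range diffusers pipelines expect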
@@ -285,6 +285,23 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
        model.to(torch_device).eval()
        return model

    @parameterized.expand(
        [
            [(1, 4, 73, 97), (1, 3, 584, 776)],
            [(1, 4, 97, 73), (1, 3, 776, 584)],
            [(1, 4, 49, 65), (1, 3, 392, 520)],
            [(1, 4, 65, 49), (1, 3, 520, 392)],
            [(1, 4, 49, 49), (1, 3, 392, 392)],
        ]
    )
    def test_tae_tiling(self, in_shape, out_shape):
        model = self.get_sd_vae_model()
        model.enable_tiling()
        with torch.no_grad():
            zeros = torch.zeros(in_shape).to(torch_device)
            dec = model.decode(zeros).sample
        assert dec.shape == out_shape

    def test_stable_diffusion(self):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed=33)
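The new test only exercises the decode direction; a hypothetical encode-direction counterpart (not part of this PR, shown only to illustrate the expected latent shapes) could look like:

    # hypothetical companion check, reusing the class's get_sd_vae_model() helper and torch_device
    model = self.get_sd_vae_model()
    model.enable_tiling()
    with torch.no_grad():
        image = torch.zeros(1, 3, 584, 776).to(torch_device)
        enc = model.encode(image).latents
    assert enc.shape == (1, 4, 73, 97)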