
Fix AutoencoderTiny encoder scaling convention (#4682)

* Fix AutoencoderTiny encoder scaling convention

  * Add [-1, 1] -> [0, 1] rescaling to EncoderTiny

  * Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny
    (i.e. immediately after the final conv, as early as possible)

  * Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward

  * Update AutoencoderTinyIntegrationTests to protect against scaling issues.
    The new test constructs a simple image, round-trips it through AutoencoderTiny,
    and confirms the decoded result is approximately equal to the source image.
    This test checks behavior with and without tiling enabled.
    This test will fail if new AutoencoderTiny scaling issues are introduced.

  * Context: Raw TAESD weights expect images in [0, 1], but diffusers'
    convention represents images with zero-centered values in [-1, 1],
    so AutoencoderTiny needs to scale / unscale images at the start of
    encoding and at the end of decoding in order to work with diffusers.

* Re-add existing AutoencoderTiny test, update golden values

* Add comments to AutoencoderTiny.forward
Author: Ollin Boer Bohan
Committed: 2023-08-22 20:08:37 -07:00 (via GitHub)
Commit: 052bf3280b (parent 80871ac597)
3 changed files with 35 additions and 7 deletions
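For context, a minimal standalone sketch of the convention mismatch described above (helper names are illustrative, not diffusers API): raw TAESD layers work on images in [0, 1], while diffusers pipelines pass images in [-1, 1], so AutoencoderTiny has to rescale on the way into the encoder and on the way out of the decoder.

import torch

def to_taesd_range(image: torch.Tensor) -> torch.Tensor:
    # diffusers convention [-1, 1] -> TAESD convention [0, 1]
    return image.add(1).div(2)

def to_diffusers_range(image: torch.Tensor) -> torch.Tensor:
    # TAESD convention [0, 1] -> diffusers convention [-1, 1]
    return image.mul(2).sub(1)

# the two rescalings are inverses, so an image survives the round trip
x = torch.rand(1, 3, 8, 8) * 2 - 1
assert torch.allclose(to_diffusers_range(to_taesd_range(x)), x, atol=1e-6)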


@@ -312,9 +312,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
-        # Refer to the following discussion to know why this is needed.
-        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
-        output = output.mul_(2).sub_(1)
         if not return_dict:
             return (output,)
@@ -333,8 +330,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
         """
         enc = self.encode(sample).latents
+        # scale latents to be in [0, 1], then quantize latents to a byte tensor,
+        # as if we were storing the latents in an RGBA uint8 image.
         scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
-        unscaled_enc = self.unscale_latents(scaled_enc)
+        # unquantize latents back into [0, 1], then unscale latents back to their original range,
+        # as if we were loading the latents from an RGBA uint8 image.
+        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
         dec = self.decode(unscaled_enc)
         if not return_dict:

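A small sketch of the uint8 round trip that forward simulates (hypothetical helper names, not the library's API), showing why the added division by 255 matters: without it, byte values in [0, 255] would be handed to unscale_latents, which expects inputs in [0, 1].

import torch

def quantize_like_rgba(latents_01: torch.Tensor) -> torch.Tensor:
    # [0, 1] floats -> uint8, as if storing the latents in an RGBA image
    return latents_01.mul(255).round().byte()

def dequantize_like_rgba(stored: torch.Tensor) -> torch.Tensor:
    # uint8 -> [0, 1] floats; this /255 step is the one the commit adds
    return stored.float() / 255.0

latents_01 = torch.rand(1, 4, 8, 8)
roundtripped = dequantize_like_rgba(quantize_like_rgba(latents_01))
# quantization error is at most half a bucket, i.e. about 0.002
assert torch.allclose(roundtripped, latents_01, atol=1.0 / 255)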

@@ -732,7 +732,8 @@ class EncoderTiny(nn.Module):
                 x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
         else:
-            x = self.layers(x)
+            # scale image from [-1, 1] to [0, 1] to match TAESD convention
+            x = self.layers(x.add(1).div(2))
         return x
@@ -790,4 +791,5 @@ class DecoderTiny(nn.Module):
         else:
             x = self.layers(x)
-        return x
+        # scale image from [0, 1] to [-1, 1] to match diffusers convention
+        return x.mul(2).sub(1)

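With both module changes in place, the rescaling sits symmetrically inside EncoderTiny and DecoderTiny. A schematic sketch of that data flow, using nn.Identity as a stand-in for the real TAESD layer stacks:

import torch
import torch.nn as nn

class EncoderTinySketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Identity()  # stand-in for the TAESD encoder blocks

    def forward(self, x):
        # scale image from [-1, 1] to [0, 1] before the TAESD layers
        return self.layers(x.add(1).div(2))

class DecoderTinySketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Identity()  # stand-in for the TAESD decoder blocks

    def forward(self, x):
        # scale image from [0, 1] back to [-1, 1] right after the final layer
        return self.layers(x).mul(2).sub(1)

image = torch.rand(1, 3, 8, 8) * 2 - 1  # diffusers-convention input
roundtrip = DecoderTinySketch()(EncoderTinySketch()(image))
assert torch.allclose(roundtrip, image, atol=1e-6)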

@@ -312,10 +312,32 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
         assert sample.shape == image.shape
         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
+        expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
         assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
+    @parameterized.expand([(True,), (False,)])
+    def test_tae_roundtrip(self, enable_tiling):
+        # load the autoencoder
+        model = self.get_sd_vae_model()
+        if enable_tiling:
+            model.enable_tiling()
+        # make a black image with a white square in the middle,
+        # which is large enough to split across multiple tiles
+        image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
+        image[..., 256:768, 256:768] = 1.0
+        # round-trip the image through the autoencoder
+        with torch.no_grad():
+            sample = model(image).sample
+        # the autoencoder reconstruction should match original image, sorta
+        def downscale(x):
+            return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
+        assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
 @slow
 class AutoencoderKLIntegrationTests(unittest.TestCase):
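The downscale-and-compare assertion in the new test is deliberately coarse: averaging over spatial_scale_factor-sized blocks tolerates ordinary per-pixel reconstruction error while still failing loudly if the output comes back in the wrong range. A rough standalone illustration with synthetic tensors (not the real autoencoder):

import torch
import torch.nn.functional as F

def downscale(x: torch.Tensor, factor: int = 8) -> torch.Tensor:
    return F.avg_pool2d(x, factor)

image = -torch.ones(1, 3, 64, 64)
image[..., 16:48, 16:48] = 1.0

noisy = image + 0.05 * torch.randn_like(image)  # small per-pixel error
wrong_range = image.add(1).div(2)               # output mistakenly left in [0, 1]

print(torch.allclose(downscale(noisy), downscale(image), atol=0.125))        # True
print(torch.allclose(downscale(wrong_range), downscale(image), atol=0.125))  # False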