From 052bf3280b9a3e83cdb06ab3a07645e23371e749 Mon Sep 17 00:00:00 2001 From: Ollin Boer Bohan Date: Tue, 22 Aug 2023 20:08:37 -0700 Subject: [PATCH] Fix AutoencoderTiny encoder scaling convention (#4682) * Fix AutoencoderTiny encoder scaling convention * Add [-1, 1] -> [0, 1] rescaling to EncoderTiny * Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny (i.e. immediately after the final conv, as early as possible) * Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward * Update AutoencoderTinyIntegrationTests to protect against scaling issues. The new test constructs a simple image, round-trips it through AutoencoderTiny, and confirms the decoded result is approximately equal to the source image. This test checks behavior with and without tiling enabled. This test will fail if new AutoencoderTiny scaling issues are introduced. * Context: Raw TAESD weights expect images in [0, 1], but diffusers' convention represents images with zero-centered values in [-1, 1], so AutoencoderTiny needs to scale / unscale images at the start of encoding and at the end of decoding in order to work with diffusers. * Re-add existing AutoencoderTiny test, update golden values * Add comments to AutoencoderTiny.forward --- src/diffusers/models/autoencoder_tiny.py | 12 ++++++++---- src/diffusers/models/vae.py | 6 ++++-- tests/models/test_models_vae.py | 24 +++++++++++++++++++++++- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py index fceaf0f130..ad36b7a2ce 100644 --- a/src/diffusers/models/autoencoder_tiny.py +++ b/src/diffusers/models/autoencoder_tiny.py @@ -312,9 +312,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): output = torch.cat(output) else: output = self._tiled_decode(x) if self.use_tiling else self.decoder(x) - # Refer to the following discussion to know why this is needed. - # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854 - output = output.mul_(2).sub_(1) if not return_dict: return (output,) @@ -333,8 +330,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ enc = self.encode(sample).latents + + # scale latents to be in [0, 1], then quantize latents to a byte tensor, + # as if we were storing the latents in an RGBA uint8 image. scaled_enc = self.scale_latents(enc).mul_(255).round_().byte() - unscaled_enc = self.unscale_latents(scaled_enc) + + # unquantize latents back into [0, 1], then unscale latents back to their original range, + # as if we were loading the latents from an RGBA uint8 image. + unscaled_enc = self.unscale_latents(scaled_enc / 255.0) + dec = self.decode(unscaled_enc) if not return_dict: diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index bd9562909d..d3c512b073 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -732,7 +732,8 @@ class EncoderTiny(nn.Module): x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) else: - x = self.layers(x) + # scale image from [-1, 1] to [0, 1] to match TAESD convention + x = self.layers(x.add(1).div(2)) return x @@ -790,4 +791,5 @@ class DecoderTiny(nn.Module): else: x = self.layers(x) - return x + # scale image from [0, 1] to [-1, 1] to match diffusers convention + return x.mul(2).sub(1) diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 13c8652bf0..fe38b4fc21 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -312,10 +312,32 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase): assert sample.shape == image.shape output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() - expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604]) + expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382]) assert torch_all_close(output_slice, expected_output_slice, atol=3e-3) + @parameterized.expand([(True,), (False,)]) + def test_tae_roundtrip(self, enable_tiling): + # load the autoencoder + model = self.get_sd_vae_model() + if enable_tiling: + model.enable_tiling() + + # make a black image with a white square in the middle, + # which is large enough to split across multiple tiles + image = -torch.ones(1, 3, 1024, 1024, device=torch_device) + image[..., 256:768, 256:768] = 1.0 + + # round-trip the image through the autoencoder + with torch.no_grad(): + sample = model(image).sample + + # the autoencoder reconstruction should match original image, sorta + def downscale(x): + return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor) + + assert torch_all_close(downscale(sample), downscale(image), atol=0.125) + @slow class AutoencoderKLIntegrationTests(unittest.TestCase):