diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py
index fceaf0f130..ad36b7a2ce 100644
--- a/src/diffusers/models/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoder_tiny.py
@@ -312,9 +312,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
-        # Refer to the following discussion to know why this is needed.
-        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
-        output = output.mul_(2).sub_(1)

         if not return_dict:
             return (output,)
@@ -333,8 +330,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
         """
         enc = self.encode(sample).latents
+
+        # scale latents to be in [0, 1], then quantize latents to a byte tensor,
+        # as if we were storing the latents in an RGBA uint8 image.
         scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
-        unscaled_enc = self.unscale_latents(scaled_enc)
+
+        # unquantize latents back into [0, 1], then unscale latents back to their original range,
+        # as if we were loading the latents from an RGBA uint8 image.
+        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
+
         dec = self.decode(unscaled_enc)

         if not return_dict:
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index bd9562909d..d3c512b073 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -732,7 +732,8 @@ class EncoderTiny(nn.Module):
                 x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)

         else:
-            x = self.layers(x)
+            # scale image from [-1, 1] to [0, 1] to match TAESD convention
+            x = self.layers(x.add(1).div(2))

         return x

@@ -790,4 +791,5 @@ class DecoderTiny(nn.Module):
         else:
             x = self.layers(x)

-        return x
+        # scale image from [0, 1] to [-1, 1] to match diffusers convention
+        return x.mul(2).sub(1)
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index 13c8652bf0..fe38b4fc21 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -312,10 +312,32 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
         assert sample.shape == image.shape

         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
+        expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])

         assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

+    @parameterized.expand([(True,), (False,)])
+    def test_tae_roundtrip(self, enable_tiling):
+        # load the autoencoder
+        model = self.get_sd_vae_model()
+        if enable_tiling:
+            model.enable_tiling()
+
+        # make a black image with a white square in the middle,
+        # which is large enough to split across multiple tiles
+        image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
+        image[..., 256:768, 256:768] = 1.0
+
+        # round-trip the image through the autoencoder
+        with torch.no_grad():
+            sample = model(image).sample
+
+        # the autoencoder reconstruction should match original image, sorta
+        def downscale(x):
+            return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
+
+        assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
+

 @slow
 class AutoencoderKLIntegrationTests(unittest.TestCase):
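
For anyone who wants to exercise the new convention locally, here is a minimal round-trip sketch in Python. It is not part of the diff: the madebyollin/taesd checkpoint and the 512x512 test image are illustrative assumptions, and the quantize/unquantize step simply mirrors what AutoencoderTiny.forward documents above.

import torch
from diffusers import AutoencoderTiny

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device)

# a black image with a white square, expressed in the usual diffusers [-1, 1] range
image = -torch.ones(1, 3, 512, 512, device=device)
image[..., 128:384, 128:384] = 1.0

with torch.no_grad():
    latents = vae.encode(image).latents

    # optional: quantize latents to uint8 and back, as forward() now documents
    quantized = vae.scale_latents(latents).mul(255).round().byte()
    latents = vae.unscale_latents(quantized / 255.0)

    # reconstruction comes back in [-1, 1]
    reconstruction = vae.decode(latents).sample

print(reconstruction.min().item(), reconstruction.max().item())

The point of the sketch is that callers keep working in the diffusers [-1, 1] range on both ends; the TAESD-internal [0, 1] convention is now handled inside EncoderTiny and DecoderTiny rather than in AutoencoderTiny.decode.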