mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix AutoencoderTiny encoder scaling convention (#4682)
* Fix AutoencoderTiny encoder scaling convention
* Add [-1, 1] -> [0, 1] rescaling to EncoderTiny
* Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny
(i.e. immediately after the final conv, as early as possible)
* Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward
* Update AutoencoderTinyIntegrationTests to protect against scaling issues.
The new test constructs a simple image, round-trips it through AutoencoderTiny,
and confirms the decoded result is approximately equal to the source image.
This test checks behavior with and without tiling enabled.
This test will fail if new AutoencoderTiny scaling issues are introduced.
* Context: Raw TAESD weights expect images in [0, 1], but diffusers'
convention represents images with zero-centered values in [-1, 1],
so AutoencoderTiny needs to scale / unscale images at the start of
encoding and at the end of decoding in order to work with diffusers.
* Re-add existing AutoencoderTiny test, update golden values
* Add comments to AutoencoderTiny.forward
This commit is contained in:
@@ -312,9 +312,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
output = torch.cat(output)
|
||||
else:
|
||||
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
|
||||
# Refer to the following discussion to know why this is needed.
|
||||
# https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
|
||||
output = output.mul_(2).sub_(1)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
@@ -333,8 +330,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
enc = self.encode(sample).latents
|
||||
|
||||
# scale latents to be in [0, 1], then quantize latents to a byte tensor,
|
||||
# as if we were storing the latents in an RGBA uint8 image.
|
||||
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
|
||||
unscaled_enc = self.unscale_latents(scaled_enc)
|
||||
|
||||
# unquantize latents back into [0, 1], then unscale latents back to their original range,
|
||||
# as if we were loading the latents from an RGBA uint8 image.
|
||||
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
|
||||
|
||||
dec = self.decode(unscaled_enc)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -732,7 +732,8 @@ class EncoderTiny(nn.Module):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
||||
|
||||
else:
|
||||
x = self.layers(x)
|
||||
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
||||
x = self.layers(x.add(1).div(2))
|
||||
|
||||
return x
|
||||
|
||||
@@ -790,4 +791,5 @@ class DecoderTiny(nn.Module):
|
||||
else:
|
||||
x = self.layers(x)
|
||||
|
||||
return x
|
||||
# scale image from [0, 1] to [-1, 1] to match diffusers convention
|
||||
return x.mul(2).sub(1)
|
||||
|
||||
@@ -312,10 +312,32 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
|
||||
expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
|
||||
|
||||
@parameterized.expand([(True,), (False,)])
|
||||
def test_tae_roundtrip(self, enable_tiling):
|
||||
# load the autoencoder
|
||||
model = self.get_sd_vae_model()
|
||||
if enable_tiling:
|
||||
model.enable_tiling()
|
||||
|
||||
# make a black image with a white square in the middle,
|
||||
# which is large enough to split across multiple tiles
|
||||
image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
|
||||
image[..., 256:768, 256:768] = 1.0
|
||||
|
||||
# round-trip the image through the autoencoder
|
||||
with torch.no_grad():
|
||||
sample = model(image).sample
|
||||
|
||||
# the autoencoder reconstruction should match original image, sorta
|
||||
def downscale(x):
|
||||
return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
|
||||
|
||||
assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user