From 443aa14e415533031aa8b9761d1808ec68acda1e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=2E=20Tolga=20Cang=C3=B6z?= <46008593+standardAI@users.noreply.github.com>
Date: Tue, 26 Mar 2024 15:29:08 +0300
Subject: [PATCH] Fix Tiling in `ConsistencyDecoderVAE` (#7290)

* Fix typos

* Add docstring to `decode` method in `ConsistencyDecoderVAE`

* Fix tiling

* Enable tiled VAE decoding with customizable tile sample size and overlap factor

* Revert "Enable tiled VAE decoding with customizable tile sample size and overlap factor"

This reverts commit 181049675e83cea7b33ae2bbeba2aff7ae1b1761.

* Add VAE tiling test for `ConsistencyDecoderVAE`

---------

Co-authored-by: Sayak Paul
---
 .../autoencoders/consistency_decoder_vae.py   | 33 +++++++++++++++++--
 tests/models/autoencoders/test_models_vae.py  | 30 +++++++++++++++++
 2 files changed, 60 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
index 72c512da98..7287cbd43f 100644
--- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py
+++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
@@ -63,7 +63,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         ...     "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
         ... ).to("cuda")
 
-        >>> pipe("horse", generator=torch.manual_seed(0)).images
+        >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
+        >>> image
         ```
     """
 
@@ -72,6 +73,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self,
         scaling_factor: float = 0.18215,
         latent_channels: int = 4,
+        sample_size: int = 32,
         encoder_act_fn: str = "silu",
         encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         encoder_double_z: bool = True,
@@ -153,6 +155,16 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False
 
+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
     # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
     def enable_tiling(self, use_tiling: bool = True):
         r"""
@@ -272,7 +284,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         Args:
             x (`torch.FloatTensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
+                Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
                 tuple.
 
         Returns:
@@ -305,6 +317,19 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         return_dict: bool = True,
         num_inference_steps: int = 2,
     ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+        """
+        Decodes the input latent vector `z` using the consistency decoder VAE model.
+
+        Args:
+            z (torch.FloatTensor): The input latent vector.
+            generator (Optional[torch.Generator]): The random number generator. Default is None.
+            return_dict (bool): Whether to return the output as a dictionary. Default is True.
+            num_inference_steps (int): The number of inference steps. Default is 2.
+
+        Returns:
+            Union[DecoderOutput, Tuple[torch.FloatTensor]]: The decoded output.
+
+        """
         z = (z * self.config.scaling_factor - self.means) / self.stds
 
         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
@@ -345,7 +370,9 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
                 b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
+    def tiled_encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py
index 41db7fc2cf..ef9dca9fef 100644
--- a/tests/models/autoencoders/test_models_vae.py
+++ b/tests/models/autoencoders/test_models_vae.py
@@ -1116,3 +1116,33 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
         )
 
         assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+    def test_vae_tiling(self):
+        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        out_1 = pipe(
+            "horse",
+            num_inference_steps=2,
+            output_type="pt",
+            generator=torch.Generator("cpu").manual_seed(0),
+        ).images[0]
+
+        # make sure tiled vae decode yields the same result
+        pipe.enable_vae_tiling()
+        out_2 = pipe(
+            "horse",
+            num_inference_steps=2,
+            output_type="pt",
+            generator=torch.Generator("cpu").manual_seed(0),
+        ).images[0]
+
+        assert torch_all_close(out_1, out_2, atol=5e-3)
+
+        # test that tiled decode works with various shapes
+        shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
+        for shape in shapes:
+            image = torch.zeros(shape, device=torch_device)
+            pipe.vae.decode(image)