From 7a4324cce3f84d14afe8e5cfd47fb67701ce2fd3 Mon Sep 17 00:00:00 2001 From: stano Date: Mon, 2 Oct 2023 20:17:34 +0300 Subject: [PATCH] Add a docstring for the AutoencoderKL's encode (#5239) * Add docstring for the AutoencoderKL's encode #5229 * Support Python 3.8 syntax in AutoencoderKL.decode type hints Co-authored-by: Patrick von Platen * Follow the style guidelines in AutoencoderKL's encode #5230 --------- Co-authored-by: stano <> Co-authored-by: Patrick von Platen --- src/diffusers/models/autoencoder_kl.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 7e3b925df7..80d2cccd53 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -249,7 +249,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): self.set_attn_processor(processor, _remove_lora=True) @apply_forward_hook - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): return self.tiled_encode(x, return_dict=return_dict)