From 9a92b8177cb3f8bf4b095fff55da3b45a3607960 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 30 Oct 2024 18:04:15 +0530
Subject: [PATCH] Allegro VAE fix (#9811)

fix
---
 .../models/autoencoders/autoencoder_kl_allegro.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
index 4836de7e16..922fd15c08 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -1091,8 +1091,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-        encoder_local_batch_size: int = 2,
-        decoder_local_batch_size: int = 2,
     ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
@@ -1103,18 +1101,14 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
             generator (`torch.Generator`, *optional*):
                 PyTorch random number generator.
-            encoder_local_batch_size (`int`, *optional*, defaults to 2):
-                Local batch size for the encoder's batch inference.
-            decoder_local_batch_size (`int`, *optional*, defaults to 2):
-                Local batch size for the decoder's batch inference.
         """
         x = sample
-        posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
+        posterior = self.encode(x).latent_dist
         if sample_posterior:
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
+        dec = self.decode(z).sample
 
         if not return_dict:
             return (dec,)
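
For context, a minimal usage sketch of the fixed `forward` path. After this patch, `AutoencoderKLAllegro.forward` no longer accepts `encoder_local_batch_size` / `decoder_local_batch_size`; it simply round-trips the input through `encode()` and `decode()`. The checkpoint id, input shape, and seed below are illustrative assumptions, not part of the patch:

```python
# Minimal sketch of calling the fixed forward(); the tensor shape and
# checkpoint are illustrative assumptions, not values from the patch.
import torch

from diffusers import AutoencoderKLAllegro

# Assumed Allegro VAE checkpoint location; substitute your own.
vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae")
vae.eval()

# Dummy video batch with layout (batch, channels, frames, height, width);
# real Allegro inputs are typically much larger.
sample = torch.randn(1, 3, 8, 64, 64)

with torch.no_grad():
    # forward() now takes only sample_posterior / return_dict / generator;
    # passing the removed *_local_batch_size kwargs would raise a TypeError.
    out = vae(sample, sample_posterior=True, generator=torch.manual_seed(0))
    reconstruction = out.sample  # DecoderOutput.sample when return_dict=True
```

The design point of the patch is visible in the signature: the removed kwargs were forwarded to `encode()`/`decode()` as `local_batch_size`, which those methods do not accept, so batching behavior is left to the autoencoder's internal tiling rather than to per-call arguments.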