1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Allegro VAE fix (#9811)

fix
This commit is contained in:
Aryan
2024-10-30 18:04:15 +05:30
committed by GitHub
parent 0d1d267b12
commit 9a92b8177c

View File

@@ -1091,8 +1091,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
encoder_local_batch_size: int = 2,
decoder_local_batch_size: int = 2,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
@@ -1103,18 +1101,14 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*):
PyTorch random number generator.
encoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the encoder's batch inference.
decoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the decoder's batch inference.
"""
x = sample
posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
dec = self.decode(z).sample
if not return_dict:
return (dec,)