mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -1091,8 +1091,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
encoder_local_batch_size: int = 2,
|
||||
decoder_local_batch_size: int = 2,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
@@ -1103,18 +1101,14 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
PyTorch random number generator.
|
||||
encoder_local_batch_size (`int`, *optional*, defaults to 2):
|
||||
Local batch size for the encoder's batch inference.
|
||||
decoder_local_batch_size (`int`, *optional*, defaults to 2):
|
||||
Local batch size for the decoder's batch inference.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
|
||||
dec = self.decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
Reference in New Issue
Block a user