
denormalize latents with the mean and std if available (#7111)

* denormalize latents with the mean and std if available

* fix denormalize

* add latent mean and std in vae config

* address Sayak's comment
Author: Suraj Patil
Date: 2024-02-27 16:20:05 +05:30
Committed by: GitHub
Parent: ac49f97a75
Commit: d71ecad8cd
2 changed files with 18 additions and 1 deletion


@@ -80,6 +80,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        latents_mean: Optional[Tuple[float]] = None,
+        latents_std: Optional[Tuple[float]] = None,
         force_upcast: float = True,
     ):
         super().__init__()
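
For context, here is a minimal sketch of how the two new config entries could be supplied and read back. The per-channel values below are placeholders, not statistics from any released checkpoint; in practice they would come from a checkpoint whose VAE config publishes latents_mean and latents_std.

```python
import torch
from diffusers import AutoencoderKL

# Placeholder per-channel latent statistics (illustrative only).
vae = AutoencoderKL(
    latent_channels=4,
    latents_mean=(0.1, -0.2, 0.3, 0.0),
    latents_std=(5.0, 4.5, 6.0, 5.5),
)

# Both entries are registered to the model config, which is what the
# pipeline change below checks via hasattr(...) / is not None.
print(vae.config.latents_mean, vae.config.latents_std)
```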


@@ -1313,7 +1313,22 @@ class StableDiffusionXLPipeline(
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]
 
             # cast back to fp16 if needed
             if needs_upcasting:
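
As a sanity check on the arithmetic in the new branch, here is a small standalone sketch. The statistics are placeholders and the variable names are illustrative; it shows that the expression latents * latents_std / scaling_factor + latents_mean inverts a normalization of the form (latents - mean) * scaling_factor / std.

```python
import torch

scaling_factor = 0.13025                                              # e.g. the SDXL VAE scaling factor
latents_mean = torch.tensor([0.1, -0.2, 0.3, 0.0]).view(1, 4, 1, 1)   # placeholder stats
latents_std = torch.tensor([5.0, 4.5, 6.0, 5.5]).view(1, 4, 1, 1)

z = torch.randn(1, 4, 128, 128)                                 # raw VAE latents
z_norm = (z - latents_mean) * scaling_factor / latents_std      # normalized latents the denoiser works on
z_back = z_norm * latents_std / scaling_factor + latents_mean   # the denormalization applied before vae.decode

assert torch.allclose(z, z_back, atol=1e-5)

# Without latents_mean/latents_std the else branch keeps the old behaviour:
# z_back = z_norm / scaling_factor, i.e. plain unscaling.
```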