1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

decode block, if skip decoding do not need to update latent

This commit is contained in:
yiyixuxu
2025-05-13 01:57:47 +02:00
parent 5cde77f915
commit 58358c2d00

View File

@@ -98,16 +98,17 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
block_state = self.get_block_state(state)
if not block_state.output_type == "latent":
latents = block_state.latents
# make sure the VAE is in float32 mode, as it overflows in float16
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
if block_state.needs_upcasting:
self.upcast_vae(components)
block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif block_state.latents.dtype != components.vae.dtype:
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != components.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
components.vae = components.vae.to(block_state.latents.dtype)
components.vae = components.vae.to(latents.dtype)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
@@ -119,16 +120,16 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
)
if block_state.has_latents_mean and block_state.has_latents_std:
block_state.latents_mean = (
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype)
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
block_state.latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype)
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
else:
block_state.latents = block_state.latents / components.vae.config.scaling_factor
latents = latents / components.vae.config.scaling_factor
block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0]
block_state.images = components.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if block_state.needs_upcasting:
@@ -186,6 +187,7 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
return components, state
# YiYi TODO: remove this, we don't need this in modular
class StableDiffusionXLOutputStep(PipelineBlock):
model_name = "stable-diffusion-xl"