From 58358c2d003f7a25120aea9c4545571d6feefe21 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 13 May 2025 01:57:47 +0200 Subject: [PATCH] decode block, if skip decoding do not need to update latent --- .../stable_diffusion_xl/after_denoise.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py index 9746832506..6ce59b5c35 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py @@ -98,16 +98,17 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock): block_state = self.get_block_state(state) if not block_state.output_type == "latent": + latents = block_state.latents # make sure the VAE is in float32 mode, as it overflows in float16 block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast if block_state.needs_upcasting: self.upcast_vae(components) - block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) - elif block_state.latents.dtype != components.vae.dtype: + latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != components.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - components.vae = components.vae.to(block_state.latents.dtype) + components.vae = components.vae.to(latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None @@ -119,16 +120,16 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock): ) if block_state.has_latents_mean and block_state.has_latents_std: block_state.latents_mean = ( - torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) block_state.latents_std = ( - torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype) + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) - block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean else: - block_state.latents = block_state.latents / components.vae.config.scaling_factor + latents = latents / components.vae.config.scaling_factor - block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0] + block_state.images = components.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if block_state.needs_upcasting: @@ -186,6 +187,7 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): return components, state +# YiYi TODO: remove this, we don't need this in modular class StableDiffusionXLOutputStep(PipelineBlock): model_name = "stable-diffusion-xl"