mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Decode block: if decoding is skipped, the latents do not need to be updated
This commit is contained in:
@@ -98,16 +98,17 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if not block_state.output_type == "latent":
|
||||
latents = block_state.latents
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
|
||||
|
||||
if block_state.needs_upcasting:
|
||||
self.upcast_vae(components)
|
||||
block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
|
||||
elif block_state.latents.dtype != components.vae.dtype:
|
||||
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
|
||||
elif latents.dtype != components.vae.dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
components.vae = components.vae.to(block_state.latents.dtype)
|
||||
components.vae = components.vae.to(latents.dtype)
|
||||
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
@@ -119,16 +120,16 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
|
||||
)
|
||||
if block_state.has_latents_mean and block_state.has_latents_std:
|
||||
block_state.latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype)
|
||||
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
block_state.latents_std = (
|
||||
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype)
|
||||
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
|
||||
latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
|
||||
else:
|
||||
block_state.latents = block_state.latents / components.vae.config.scaling_factor
|
||||
latents = latents / components.vae.config.scaling_factor
|
||||
|
||||
block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0]
|
||||
block_state.images = components.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if block_state.needs_upcasting:
|
||||
@@ -186,6 +187,7 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
# YiYi TODO: remove this, we don't need this in modular
|
||||
class StableDiffusionXLOutputStep(PipelineBlock):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user