mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix potential type mismatch errors in SDXL pipelines (#4796)
* Fix potential type conversion errors in SDXL pipelines
* Make sure the VAE stays in fp16

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1184,13 +1184,19 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
@@ -772,13 +772,19 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
@@ -1183,13 +1183,19 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
@@ -751,15 +751,20 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
# post-processing
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
|
||||
image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -870,13 +870,19 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
@@ -1027,13 +1027,19 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
@@ -1333,13 +1333,19 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
return StableDiffusionXLPipelineOutput(images=latents)
|
||||
|
||||
|
||||
@@ -908,13 +908,19 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
Reference in New Issue
Block a user