From 139f707e6edc0cf54531f72ca731d674ca5d7760 Mon Sep 17 00:00:00 2001 From: lawfordp2017 Date: Tue, 19 Mar 2024 04:47:44 -0600 Subject: [PATCH] Correction for non-integral image resolutions with quantizations other than float32 (#7356) * Correction for non-integral image resolutions with quantizations other than float32. * Support for training, and use of diffusers-style casting. --- src/diffusers/models/unets/unet_stable_cascade.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index 9f81e50241..197ddeec75 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -521,9 +521,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin): if isinstance(block, SDCascadeResBlock): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + orig_type = x.dtype x = torch.nn.functional.interpolate( x.float(), skip.shape[-2:], mode="bilinear", align_corners=True ) + x = x.to(orig_type) x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, skip, use_reentrant=False ) @@ -547,9 +549,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin): if isinstance(block, SDCascadeResBlock): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + orig_type = x.dtype x = torch.nn.functional.interpolate( x.float(), skip.shape[-2:], mode="bilinear", align_corners=True ) + x = x.to(orig_type) x = block(x, skip) elif isinstance(block, SDCascadeAttnBlock): x = block(x, clip)