diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index 9f81e50241..197ddeec75 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -521,9 +521,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin): if isinstance(block, SDCascadeResBlock): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + orig_type = x.dtype x = torch.nn.functional.interpolate( x.float(), skip.shape[-2:], mode="bilinear", align_corners=True ) + x = x.to(orig_type) x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, skip, use_reentrant=False ) @@ -547,9 +549,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin): if isinstance(block, SDCascadeResBlock): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + orig_type = x.dtype x = torch.nn.functional.interpolate( x.float(), skip.shape[-2:], mode="bilinear", align_corners=True ) + x = x.to(orig_type) x = block(x, skip) elif isinstance(block, SDCascadeAttnBlock): x = block(x, clip)