1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Correction for non-integral image resolutions with quantizations other than float32 (#7356)

* Correction for non-integral image resolutions with quantizations other than float32.

* Support for training, and use of diffusers-style casting.
This commit is contained in:
lawfordp2017
2024-03-19 04:47:44 -06:00
committed by GitHub
parent e4546fd5bb
commit 139f707e6e

View File

@@ -521,9 +521,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
if isinstance(block, SDCascadeResBlock):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
orig_type = x.dtype
x = torch.nn.functional.interpolate(
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = x.to(orig_type)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, skip, use_reentrant=False
)
@@ -547,9 +549,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
if isinstance(block, SDCascadeResBlock):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
orig_type = x.dtype
x = torch.nn.functional.interpolate(
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = x.to(orig_type)
x = block(x, skip)
elif isinstance(block, SDCascadeAttnBlock):
x = block(x, clip)