mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Correction for non-integral image resolutions with quantizations other than float32 (#7356)
* Correction for non-integral image resolutions with quantizations other than float32. * Support for training, and use of diffusers-style casting.
This commit is contained in:
@@ -521,9 +521,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
orig_type = x.dtype
|
||||
x = torch.nn.functional.interpolate(
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = x.to(orig_type)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, skip, use_reentrant=False
|
||||
)
|
||||
@@ -547,9 +549,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
orig_type = x.dtype
|
||||
x = torch.nn.functional.interpolate(
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = x.to(orig_type)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = block(x, clip)
|
||||
|
||||
Reference in New Issue
Block a user