mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
from functools import wraps
|
|
import torch
|
|
import diffusers
|
|
from modules import devices # pylint: disable=ungrouped-imports
|
|
|
|
|
|
def model_hijack(): # a111 compatibility item
|
|
pass
|
|
|
|
|
|
def register_buffer(self, name, attr):
|
|
"""
|
|
Fix register buffer bug for Mac OS.
|
|
"""
|
|
|
|
if type(attr) == torch.Tensor:
|
|
if attr.device != devices.device:
|
|
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
|
|
|
setattr(self, name, attr)
|
|
|
|
|
|
# Upcast BF16 to FP32
|
|
original_fft_fftn = torch.fft.fftn
|
|
@wraps(torch.fft.fftn)
|
|
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None): # pylint: disable=redefined-builtin
|
|
return_dtype = input.dtype
|
|
if input.dtype == torch.bfloat16:
|
|
input = input.to(dtype=torch.float32)
|
|
return original_fft_fftn(input, s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
|
|
|
|
|
|
# Upcast BF16 to FP32
|
|
original_fft_ifftn = torch.fft.ifftn
|
|
@wraps(torch.fft.ifftn)
|
|
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None): # pylint: disable=redefined-builtin
|
|
return_dtype = input.dtype
|
|
if input.dtype == torch.bfloat16:
|
|
input = input.to(dtype=torch.float32)
|
|
return original_fft_ifftn(input, s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
|
|
|
|
|
|
# Diffusers FreeU
|
|
# Diffusers is imported before sd_hijacks so fourier_filter needs hijacking too
|
|
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
|
|
@wraps(diffusers.utils.torch_utils.fourier_filter)
|
|
def fourier_filter(x_in, threshold, scale):
|
|
return_dtype = x_in.dtype
|
|
if x_in.dtype == torch.bfloat16:
|
|
x_in = x_in.to(dtype=torch.float32)
|
|
return original_fourier_filter(x_in, threshold, scale).to(dtype=return_dtype)
|
|
|
|
|
|
# IPEX always upcasts
|
|
if devices.backend != "ipex":
|
|
torch.fft.fftn = fft_fftn
|
|
torch.fft.ifftn = fft_ifftn
|
|
diffusers.utils.torch_utils.fourier_filter = fourier_filter
|
|
|
|
|
|
# Fix "torch is not defined" error on img2img pipelines when torch.compile for vae.encode is enabled:
|
|
# disable_compile for AutoencoderKLOutput is the only change
|
|
if torch.__version__.startswith("2.6"):
|
|
from dataclasses import dataclass
|
|
from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports
|
|
import diffusers.models.autoencoders.autoencoder_kl # pylint: disable=ungrouped-imports
|
|
|
|
@dataclass
|
|
@disable_compile
|
|
class AutoencoderKLOutput(diffusers.utils.BaseOutput):
|
|
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
|
|
diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput
|