1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/sd_hijack.py
Vladimir Mandic edefa8cb14 improve extensions compatibility
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-07-05 17:42:13 -04:00

73 lines
2.5 KiB
Python

from functools import wraps
import torch
import diffusers
from modules import devices # pylint: disable=ungrouped-imports
def model_hijack(): # a111 compatibility item
    """Do nothing; placeholder kept so the `model_hijack` name exists for a111 compatibility."""
    return None
def register_buffer(self, name, attr):
    """
    Store `attr` on `self` under `name` (fix for the register buffer bug on Mac OS).

    A plain `torch.Tensor` that is not already on `devices.device` is moved there
    first; when that device is of type 'mps' it is also converted to float32.
    Any other value is stored as-is.
    """
    # Exact type match on purpose: tensor subclasses are left untouched.
    if type(attr) is torch.Tensor and attr.device != devices.device:
        mps_dtype = torch.float32 if devices.device.type == 'mps' else None
        attr = attr.to(device=devices.device, dtype=mps_dtype)
    setattr(self, name, attr)
# Upcast BF16 to FP32
original_fft_fftn = torch.fft.fftn


@wraps(torch.fft.fftn)
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None): # pylint: disable=redefined-builtin
    """
    Call the original `torch.fft.fftn`, upcasting bfloat16 input to float32 first.

    Args:
        input: tensor to transform; a bfloat16 tensor is converted to float32.
        s, dim, norm, out: forwarded unchanged to the original `torch.fft.fftn`.

    Returns:
        The transform result. It is converted back to the dtype of `input` only
        when that conversion cannot drop an imaginary component; a complex result
        computed from a real-valued input is returned as-is.
    """
    return_dtype = input.dtype
    if return_dtype == torch.bfloat16:
        input = input.to(dtype=torch.float32)
    result = original_fft_fftn(input, s=s, dim=dim, norm=norm, out=out)
    # Casting a complex spectrum to a real dtype would silently discard the
    # imaginary part, so the complex result is kept for real-valued input.
    if result.is_complex() and not return_dtype.is_complex:
        return result
    return result.to(dtype=return_dtype)
# Upcast BF16 to FP32
original_fft_ifftn = torch.fft.ifftn


@wraps(torch.fft.ifftn)
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None): # pylint: disable=redefined-builtin
    """
    Call the original `torch.fft.ifftn`, upcasting bfloat16 input to float32 first.

    Args:
        input: tensor to transform; a bfloat16 tensor is converted to float32.
        s, dim, norm, out: forwarded unchanged to the original `torch.fft.ifftn`.

    Returns:
        The inverse transform result. It is converted back to the dtype of `input`
        only when that conversion cannot drop an imaginary component; a complex
        result computed from a real-valued input is returned as-is.
    """
    return_dtype = input.dtype
    if return_dtype == torch.bfloat16:
        input = input.to(dtype=torch.float32)
    result = original_fft_ifftn(input, s=s, dim=dim, norm=norm, out=out)
    # Casting a complex result to a real dtype would silently discard the
    # imaginary part, so the complex result is kept for real-valued input.
    if result.is_complex() and not return_dtype.is_complex:
        return result
    return result.to(dtype=return_dtype)
# Diffusers FreeU
# diffusers gets imported ahead of sd_hijacks, so its fourier_filter has to be hijacked as well
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter


@wraps(original_fourier_filter)
def fourier_filter(x_in, threshold, scale):
    """
    Run the original diffusers `fourier_filter`, upcasting bfloat16 input to float32.

    The filtered tensor is converted back to the dtype `x_in` had on entry.
    `threshold` and `scale` are forwarded unchanged.
    """
    source_dtype = x_in.dtype
    needs_upcast = source_dtype == torch.bfloat16
    filter_input = x_in.to(dtype=torch.float32) if needs_upcast else x_in
    filtered = original_fourier_filter(filter_input, threshold, scale)
    return filtered.to(dtype=source_dtype)
# IPEX always upcasts, so the upcasting wrappers are installed for every other backend only
if devices.backend != "ipex":
    diffusers.utils.torch_utils.fourier_filter = fourier_filter
    torch.fft.ifftn = fft_ifftn
    torch.fft.fftn = fft_fftn
# Works around the "torch is not defined" error seen on img2img pipelines when torch.compile is enabled for vae.encode.
# The replacement class differs from the one it overrides only by the added disable_compile decorator.
if torch.__version__.startswith("2.6"):
    from dataclasses import dataclass
    from torch.compiler import disable as disable_compile # pylint: disable=ungrouped-imports
    import diffusers.models.autoencoders.autoencoder_kl # pylint: disable=ungrouped-imports

    @dataclass
    @disable_compile
    class AutoencoderKLOutput(diffusers.utils.BaseOutput):
        # forward reference kept as a string; the name is not imported in this module
        latent_dist: "DiagonalGaussianDistribution" # noqa: F821

    # publish the patched class under the name the autoencoder_kl module looks up
    diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput = AutoencoderKLOutput