1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/apg/__init__.py
2024-10-18 01:12:28 +03:00

59 lines
1.9 KiB
Python

# copied from paper: <https://arxiv.org/pdf/2410.02416>
import torch
import diffusers
from .pipeline_stable_diffision_xl_apg import StableDiffusionXLPipelineAPG
from .pipeline_stable_cascade_prior_apg import StableCascadePriorPipelineAPG
from .pipeline_stable_diffusion_apg import StableDiffusionPipelineAPG
class MomentumBuffer:
def __init__(self, momentum_val: float):
self.momentum = momentum_val
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
eta = 0
momentum = 0
threshold = 0
buffer: MomentumBuffer = None
orig_pipe: diffusers.DiffusionPipeline = None
def project(
v0: torch.Tensor, # [B, C, H, W]
v1: torch.Tensor, # [B, C, H, W]
):
device = v0.device
dtype = v0.dtype
if device.type == "xpu":
v0, v1 = v0.to("cpu"), v1.to("cpu")
v0, v1 = v0.double(), v1.double()
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel.to(device, dtype=dtype), v0_orthogonal.to(device, dtype=dtype)
def normalized_guidance(
pred_cond: torch.Tensor, # [B, C, H, W]
pred_uncond: torch.Tensor, # [B, C, H, W]
guidance_scale: float,
):
diff = pred_cond - pred_uncond
if buffer is not None:
buffer.update(diff)
diff = buffer.running_average
if threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, threshold / diff_norm)
diff = diff * scale_factor
diff_parallel, diff_orthogonal = project(diff, pred_cond)
normalized_update = diff_orthogonal + eta * diff_parallel
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
return pred_guided