# copied from paper: import torch import diffusers from .pipeline_stable_diffision_xl_apg import StableDiffusionXLPipelineAPG from .pipeline_stable_cascade_prior_apg import StableCascadePriorPipelineAPG from .pipeline_stable_diffusion_apg import StableDiffusionPipelineAPG class MomentumBuffer: def __init__(self, momentum_val: float): self.momentum = momentum_val self.running_average = 0 def update(self, update_value: torch.Tensor): new_average = self.momentum * self.running_average self.running_average = update_value + new_average eta = 0 momentum = 0 threshold = 0 buffer: MomentumBuffer = None orig_pipe: diffusers.DiffusionPipeline = None def project( v0: torch.Tensor, # [B, C, H, W] v1: torch.Tensor, # [B, C, H, W] ): device = v0.device dtype = v0.dtype if device.type == "xpu": v0, v1 = v0.to("cpu"), v1.to("cpu") v0, v1 = v0.double(), v1.double() v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel return v0_parallel.to(device, dtype=dtype), v0_orthogonal.to(device, dtype=dtype) def normalized_guidance( pred_cond: torch.Tensor, # [B, C, H, W] pred_uncond: torch.Tensor, # [B, C, H, W] guidance_scale: float, ): diff = pred_cond - pred_uncond if buffer is not None: buffer.update(diff) diff = buffer.running_average if threshold > 0: ones = torch.ones_like(diff) diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True) scale_factor = torch.minimum(ones, threshold / diff_norm) diff = diff * scale_factor diff_parallel, diff_orthogonal = project(diff, pred_cond) normalized_update = diff_orthogonal + eta * diff_parallel pred_guided = pred_cond + (guidance_scale - 1) * normalized_update return pred_guided