From b218062fed08d6cc164206d6cb852b2b7b00847a Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Sat, 25 Mar 2023 18:12:22 -0700 Subject: [PATCH] Update Pix2PixZero Auto-correlation Loss --- .../pipeline_stable_diffusion_pix2pix_zero.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 4c2dbe6ff8..72336f72a1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -750,23 +750,18 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ) def auto_corr_loss(self, hidden_states, generator=None): - batch_size, channel, height, width = hidden_states.shape - if batch_size > 1: - raise ValueError("Only batch_size 1 is supported for now") - - hidden_states = hidden_states.squeeze(0) - # hidden_states must be shape [C,H,W] now reg_loss = 0.0 for i in range(hidden_states.shape[0]): - noise = hidden_states[i][None, None, :, :] - while True: - roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() - reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 - reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 - if noise.shape[2] <= 8: - break - noise = F.avg_pool2d(noise, kernel_size=2) + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) return reg_loss def kl_divergence(self, hidden_states):