1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update Pix2PixZero Auto-correlation Loss

This commit is contained in:
Clarence Chen
2023-03-25 18:12:22 -07:00
parent b94880e536
commit b218062fed

View File

@@ -750,23 +750,18 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
)
def auto_corr_loss(self, hidden_states, generator=None):
batch_size, channel, height, width = hidden_states.shape
if batch_size > 1:
raise ValueError("Only batch_size 1 is supported for now")
hidden_states = hidden_states.squeeze(0)
# hidden_states must be shape [C,H,W] now
reg_loss = 0.0
for i in range(hidden_states.shape[0]):
noise = hidden_states[i][None, None, :, :]
while True:
roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
for j in range(hidden_states.shape[1]):
noise = hidden_states[i : i + 1, j : j + 1, :, :]
while True:
roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
if noise.shape[2] <= 8:
break
noise = F.avg_pool2d(noise, kernel_size=2)
if noise.shape[2] <= 8:
break
noise = F.avg_pool2d(noise, kernel_size=2)
return reg_loss
def kl_divergence(self, hidden_states):