1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix dreambooth loss type with prior_preservation and fp16 (#826)

Fix dreambooth loss type with prior preservation
This commit is contained in:
Anton Lozhkov
2022-10-13 15:41:19 +02:00
committed by GitHub
parent 0a09af2f0a
commit e001fededf

View File

@@ -544,7 +544,7 @@ def main():
noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")