mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix dreambooth loss type with prior_preservation and fp16 (#826)
Fix dreambooth loss type with prior preservation
This commit is contained in:
@@ -544,7 +544,7 @@ def main():
|
||||
noise, noise_prior = torch.chunk(noise, 2, dim=0)
|
||||
|
||||
# Compute instance loss
|
||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
|
||||
|
||||
Reference in New Issue
Block a user