From bc108e15333cb0e8a092647320cbb4d70d6d0f03 Mon Sep 17 00:00:00 2001 From: "39th president of the United States, probably" <110263573+AmericanPresidentJimmyCarter@users.noreply.github.com> Date: Sat, 1 Jun 2024 03:27:57 -0400 Subject: [PATCH] Fix DREAM training (#8302) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jimmy <39@🇺🇸.com> Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- src/diffusers/training_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index b617dd2eef..b2f561632d 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -157,19 +157,19 @@ def compute_dream_and_update_latents( with torch.no_grad(): pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - noisy_latents, target = (None, None) + _noisy_latents, _target = (None, None) if noise_scheduler.config.prediction_type == "epsilon": predicted_noise = pred delta_noise = (noise - predicted_noise).detach() delta_noise.mul_(dream_lambda) - noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise) - target = target.add(delta_noise) + _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise) + _target = target.add(delta_noise) elif noise_scheduler.config.prediction_type == "v_prediction": raise NotImplementedError("DREAM has not been implemented for v-prediction") else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - return noisy_latents, target + return _noisy_latents, _target def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: