1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix DREAM training (#8302)

Co-authored-by: Jimmy <39@🇺🇸.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
39th president of the United States, probably
2024-06-01 03:27:57 -04:00
committed by GitHub
parent 86555c9f59
commit bc108e1533

View File

@@ -157,19 +157,19 @@ def compute_dream_and_update_latents(
with torch.no_grad():
pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
noisy_latents, target = (None, None)
_noisy_latents, _target = (None, None)
if noise_scheduler.config.prediction_type == "epsilon":
predicted_noise = pred
delta_noise = (noise - predicted_noise).detach()
delta_noise.mul_(dream_lambda)
noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
target = target.add(delta_noise)
_noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
_target = target.add(delta_noise)
elif noise_scheduler.config.prediction_type == "v_prediction":
raise NotImplementedError("DREAM has not been implemented for v-prediction")
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
return noisy_latents, target
return _noisy_latents, _target
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: