mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix DREAM training (#8302)
Co-authored-by: Jimmy <39@🇺🇸.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
committed by
GitHub
parent
86555c9f59
commit
bc108e1533
@@ -157,19 +157,19 @@ def compute_dream_and_update_latents(
|
||||
with torch.no_grad():
|
||||
pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
noisy_latents, target = (None, None)
|
||||
_noisy_latents, _target = (None, None)
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
predicted_noise = pred
|
||||
delta_noise = (noise - predicted_noise).detach()
|
||||
delta_noise.mul_(dream_lambda)
|
||||
noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
|
||||
target = target.add(delta_noise)
|
||||
_noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
|
||||
_target = target.add(delta_noise)
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
raise NotImplementedError("DREAM has not been implemented for v-prediction")
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
return noisy_latents, target
|
||||
return _noisy_latents, _target
|
||||
|
||||
|
||||
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user