1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Examples] fix unconditional generation training example for mixed-precision training (#5407)

* fix: unconditional generation example

* fix: float in loss.

* apply styling.
This commit is contained in:
Sayak Paul
2023-10-16 14:11:35 +05:30
committed by GitHub
parent 07b297e7de
commit 93df5bb670

View File

@@ -413,6 +413,14 @@ def main(args):
model_config=model.config,
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
args.mixed_precision = accelerator.mixed_precision
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
args.mixed_precision = accelerator.mixed_precision
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
@@ -559,11 +567,9 @@ def main(args):
progress_bar.update(1)
continue
clean_images = batch["input"]
clean_images = batch["input"].to(weight_dtype)
# Sample noise that we'll add to the images
noise = torch.randn(
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
).to(clean_images.device)
noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
@@ -579,15 +585,14 @@ def main(args):
model_output = model(noisy_images, timesteps).sample
if args.prediction_type == "epsilon":
loss = F.mse_loss(model_output, noise) # this could have different weights!
loss = F.mse_loss(model_output.float(), noise.float()) # this could have different weights!
elif args.prediction_type == "sample":
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
snr_weights = alpha_t / (1 - alpha_t)
loss = snr_weights * F.mse_loss(
model_output, clean_images, reduction="none"
) # use SNR weighting from distillation paper
# use SNR weighting from distillation paper
loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction="none")
loss = loss.mean()
else:
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")