mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Examples] fix unconditioning generation training example for mixed-precision training (#5407)
* fix: unconditional generation example * fix: float in loss. * apply styling.
This commit is contained in:
@@ -413,6 +413,14 @@ def main(args):
|
||||
model_config=model.config,
|
||||
)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
args.mixed_precision = accelerator.mixed_precision
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
args.mixed_precision = accelerator.mixed_precision
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
@@ -559,11 +567,9 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
clean_images = batch["input"]
|
||||
clean_images = batch["input"].to(weight_dtype)
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(
|
||||
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
|
||||
).to(clean_images.device)
|
||||
noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
|
||||
bsz = clean_images.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -579,15 +585,14 @@ def main(args):
|
||||
model_output = model(noisy_images, timesteps).sample
|
||||
|
||||
if args.prediction_type == "epsilon":
|
||||
loss = F.mse_loss(model_output, noise) # this could have different weights!
|
||||
loss = F.mse_loss(model_output.float(), noise.float()) # this could have different weights!
|
||||
elif args.prediction_type == "sample":
|
||||
alpha_t = _extract_into_tensor(
|
||||
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
|
||||
)
|
||||
snr_weights = alpha_t / (1 - alpha_t)
|
||||
loss = snr_weights * F.mse_loss(
|
||||
model_output, clean_images, reduction="none"
|
||||
) # use SNR weighting from distillation paper
|
||||
# use SNR weighting from distillation paper
|
||||
loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction="none")
|
||||
loss = loss.mean()
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
|
||||
|
||||
Reference in New Issue
Block a user