Mirror of https://github.com/huggingface/diffusers.git
Implement SD3 loss weighting (#8528)
* Add lognorm and cosmap weighting
* Implement mode sampling
* Update examples/dreambooth/train_dreambooth_lora_sd3.py
* Update examples/dreambooth/train_dreambooth_sd3.py
* Keep timestep sampling fully on CPU

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
examples/dreambooth/train_dreambooth_lora_sd3.py

@@ -1462,7 +1462,18 @@ def main(args):
                 bsz = model_input.shape[0]

                 # Sample a random timestep for each image
-                indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
+                # for weighting schemes where we sample timesteps non-uniformly
+                if args.weighting_scheme == "logit_normal":
+                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
+                    u = torch.nn.functional.sigmoid(u)
+                elif args.weighting_scheme == "mode":
+                    u = torch.rand(size=(bsz,), device="cpu")
+                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                else:
+                    u = torch.rand(size=(bsz,), device="cpu")
+
+                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                 # Add noise according to flow matching.
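In short, the hunk swaps uniform `torch.randint` timestep sampling for scheme-dependent sampling of a density value `u` in [0, 1], computed entirely on CPU. A minimal standalone sketch of that sampling logic (the helper name `sample_density` and its default values are illustrative, not part of this commit):

```python
import math

import torch


def sample_density(weighting_scheme, batch_size, logit_mean=0.0, logit_std=1.0, mode_scale=1.29):
    """Sketch of SD3-style timestep-density sampling; draws u in [0, 1] per sample.

    logit_normal: sigmoid of a Gaussian, concentrating samples at middle timesteps
    mode: the mode-sampling map from Sec. 3.1 of the SD3 paper
    anything else: plain uniform sampling
    """
    if weighting_scheme == "logit_normal":
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u


# Example: map the density samples onto a 1000-step scheduler's timestep indices.
u = sample_density("logit_normal", batch_size=4)
indices = (u * 1000).long()
```

Sampling `u` on CPU matches the last bullet of the commit message; only the resulting integer indices need to reach the accelerator device.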
@@ -1483,16 +1494,15 @@ def main(args):
                 model_pred = model_pred * (-sigmas) + noisy_model_input

-                # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
+                # these weighting schemes use a uniform timestep sampling
+                # and instead post-weight the loss
                 if args.weighting_scheme == "sigma_sqrt":
                     weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
-                    weighting = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    # See sec 3.1 in the SD3 paper (20).
-                    u = torch.rand(size=(bsz,), device=accelerator.device)
-                    weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                elif args.weighting_scheme == "cosmap":
+                    bot = 1 - 2 * sigmas + 2 * sigmas**2
+                    weighting = 2 / (math.pi * bot)
+                else:
+                    weighting = torch.ones_like(sigmas)

                 # simplified flow matching aka 0-rectified flow matching loss
                 # target = model_input - noise
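After this change, the remaining weighting branches act on `sigmas` post hoc rather than reshaping the timestep distribution: `sigma_sqrt` weights by σ⁻², `cosmap` by 2 / (π · (1 − 2σ + 2σ²)), and everything else (including the schemes now handled at sampling time) gets a weight of 1. A self-contained sketch of just this branch, with an illustrative helper name:

```python
import math

import torch


def loss_weighting(weighting_scheme, sigmas):
    """Sketch of the post-hoc loss weights kept by this commit.

    sigmas: per-sample noise levels, already broadcast to the loss shape.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()  # weight by 1 / sigma^2
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2  # denominator of the cosine-map weight
        weighting = 2 / (math.pi * bot)
    else:
        # logit_normal / mode are now applied at timestep sampling, so weight = 1
        weighting = torch.ones_like(sigmas)
    return weighting


# Example: cosmap weights for a batch of four noise levels.
print(loss_weighting("cosmap", torch.tensor([0.1, 0.4, 0.7, 0.9])))
```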
examples/dreambooth/train_dreambooth_sd3.py

@@ -1526,7 +1526,18 @@ def main(args):
                 bsz = model_input.shape[0]

                 # Sample a random timestep for each image
-                indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
+                # for weighting schemes where we sample timesteps non-uniformly
+                if args.weighting_scheme == "logit_normal":
+                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
+                    u = torch.nn.functional.sigmoid(u)
+                elif args.weighting_scheme == "mode":
+                    u = torch.rand(size=(bsz,), device="cpu")
+                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                else:
+                    u = torch.rand(size=(bsz,), device="cpu")
+
+                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                 # Add noise according to flow matching.
@@ -1560,18 +1571,15 @@ def main(args):
                 # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
                 # Preconditioning of the model outputs.
                 model_pred = model_pred * (-sigmas) + noisy_model_input

-                # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
+                # these weighting schemes use a uniform timestep sampling
+                # and instead post-weight the loss
                 if args.weighting_scheme == "sigma_sqrt":
                     weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
-                    weighting = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    # See sec 3.1 in the SD3 paper (20).
-                    u = torch.rand(size=(bsz,), device=accelerator.device)
-                    weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                elif args.weighting_scheme == "cosmap":
+                    bot = 1 - 2 * sigmas + 2 * sigmas**2
+                    weighting = 2 / (math.pi * bot)
+                else:
+                    weighting = torch.ones_like(sigmas)

                 # simplified flow matching aka 0-rectified flow matching loss
                 # target = model_input - noise
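Both scripts consume `weighting` the same way just below these hunks: the 0-rectified-flow target is `model_input - noise`, and the weight scales the squared error before the per-sample reduction. A sketch of that downstream use, with made-up tensor shapes (16-channel SD3 latents assumed here; not part of this diff):

```python
import torch

# Illustrative shapes: a batch of 4 latents.
model_input = torch.randn(4, 16, 64, 64)   # clean latents
noise = torch.randn_like(model_input)       # Gaussian noise added by flow matching
model_pred = torch.randn_like(model_input)  # preconditioned model output
weighting = torch.ones(4, 1, 1, 1)          # from the weighting scheme above

# Flow-matching target and weighted MSE, reduced per sample then averaged,
# mirroring the loss pattern used in the SD3 DreamBooth scripts.
target = model_input - noise
loss = torch.mean(
    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
    1,
)
loss = loss.mean()
```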