mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
feat: allow offset_noise in dreambooth training example (#2826)
This commit is contained in:
@@ -417,6 +417,16 @@ def parse_args(input_args=None):
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--offset_noise",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Fine-tuning against a modified noise"
|
||||
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
|
||||
),
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
@@ -943,7 +953,12 @@ def main(args):
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
if args.offset_noise:
|
||||
noise = torch.randn_like(latents) + 0.1 * torch.randn(
|
||||
latents.shape[0], latents.shape[1], 1, 1, device=latents.device
|
||||
)
|
||||
else:
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
|
||||
Reference in New Issue
Block a user