mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Added input perturbation (#3292)
* Added input perturbation * Fixed spelling
This commit is contained in:
@@ -112,6 +112,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
@@ -801,7 +804,8 @@ def main():
|
||||
noise += args.noise_offset * torch.randn(
|
||||
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
|
||||
)
|
||||
|
||||
if args.input_pertubation:
|
||||
new_noise = noise + args.input_pertubation * torch.randn_like(noise)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
@@ -809,7 +813,10 @@ def main():
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
if args.input_pertubation:
|
||||
noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
|
||||
else:
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
Reference in New Issue
Block a user