From e0d8c9ef838d0a7372a4807cd978e032bd26c572 Mon Sep 17 00:00:00 2001 From: Haofan Wang Date: Thu, 23 Mar 2023 12:06:17 +0800 Subject: [PATCH] Support for Offset Noise in examples (#2753) * add noise offset * make style --- examples/text_to_image/train_text_to_image.py | 7 +++++++ examples/text_to_image/train_text_to_image_lora.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 06a847e6ca..6139a0e651 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -297,6 +297,7 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -705,6 +706,12 @@ def main(): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 43bbd8ebf4..3b54cc2866 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -333,6 +333,7 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -718,6 +719,12 @@ def main(): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)