diff --git a/examples/research_projects/diffusion_dpo/README.md b/examples/research_projects/diffusion_dpo/README.md
index c80d97ea26..a0b4d8ab9c 100644
--- a/examples/research_projects/diffusion_dpo/README.md
+++ b/examples/research_projects/diffusion_dpo/README.md
@@ -61,6 +61,33 @@ accelerate launch train_diffusion_dpo_sdxl.py \
   --push_to_hub
 ```
 
+## SDXL Turbo training command
+
+```bash
+accelerate launch train_diffusion_dpo_sdxl.py \
+  --pretrained_model_name_or_path=stabilityai/sdxl-turbo \
+  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+  --output_dir="diffusion-sdxl-turbo-dpo" \
+  --mixed_precision="fp16" \
+  --dataset_name=kashif/pickascore \
+  --train_batch_size=8 \
+  --gradient_accumulation_steps=2 \
+  --gradient_checkpointing \
+  --use_8bit_adam \
+  --rank=8 \
+  --learning_rate=1e-5 \
+  --report_to="wandb" \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=2000 \
+  --checkpointing_steps=500 \
+  --run_validation --validation_steps=50 \
+  --seed="0" \
+  --is_turbo --resolution 512 \
+  --push_to_hub
+```
+
+
 ## Acknowledgements
 
 This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
index 2c57a1d579..9ec3e99159 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
@@ -118,9 +118,16 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
     images = []
     context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
 
+    guidance_scale = 5.0
+    num_inference_steps = 25
+    if args.is_turbo:
+        guidance_scale = 0.0
+        num_inference_steps = 4
     for prompt in VALIDATION_PROMPTS:
         with context:
-            image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
+            image = pipeline(
+                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+            ).images[0]
             images.append(image)
 
     tracker_key = "test" if is_final_validation else "validation"
@@ -141,7 +148,10 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
     if is_final_validation:
         pipeline.disable_lora()
         no_lora_images = [
-            pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
+            pipeline(
+                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+            ).images[0]
+            for prompt in VALIDATION_PROMPTS
         ]
 
         for tracker in accelerator.trackers:
@@ -423,6 +433,11 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--is_turbo",
+        action="store_true",
+        help=("Use if tuning SDXL Turbo instead of SDXL"),
+    )
     parser.add_argument(
         "--rank",
         type=int,
@@ -444,6 +459,9 @@
     if args.dataset_name is None:
         raise ValueError("Must provide a `dataset_name`.")
 
+    if args.is_turbo:
+        assert "turbo" in args.pretrained_model_name_or_path
+
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -560,6 +578,36 @@ def main(args):
 
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+    def enforce_zero_terminal_snr(scheduler):
+        # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
+        # Original implementation https://arxiv.org/pdf/2305.08891.pdf
+        # Turbo needs zero terminal SNR
+        # Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
+        # Convert betas to alphas_bar_sqrt
+        alphas = 1 - scheduler.betas
+        alphas_bar = alphas.cumprod(0)
+        alphas_bar_sqrt = alphas_bar.sqrt()
+
+        # Store old values.
+        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+        # Shift so last timestep is zero.
+        alphas_bar_sqrt -= alphas_bar_sqrt_T
+        # Scale so first timestep is back to old value.
+        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+        alphas_bar = alphas_bar_sqrt**2
+        alphas = alphas_bar[1:] / alphas_bar[:-1]
+        alphas = torch.cat([alphas_bar[0:1], alphas])
+
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
+        scheduler.alphas_cumprod = alphas_cumprod
+        return
+
+    if args.is_turbo:
+        enforce_zero_terminal_snr(noise_scheduler)
+
     text_encoder_one = text_encoder_cls_one.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
@@ -909,6 +957,10 @@
                 timesteps = torch.randint(
                     0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
                 ).repeat(2)
+                if args.is_turbo:
+                    # Learn a 4 timestep schedule
+                    timesteps_0_to_3 = timesteps % 4
+                    timesteps = 250 * timesteps_0_to_3 + 249
 
                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
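Note (not part of the patch itself): the two Turbo-specific pieces above are the zero-terminal-SNR rescaling of the scheduler's `alphas_cumprod` and the remapping of uniformly sampled timesteps onto a 4-step grid. The sketch below reproduces both in isolation so their effect is easy to inspect. It is illustration only: it constructs a default `DDPMScheduler` (linear betas) instead of loading the SDXL Turbo scheduler config, purely to keep the snippet self-contained.

```python
# Illustration only -- not the training script. Assumes `torch` and `diffusers`
# are installed; uses a default DDPMScheduler rather than the SDXL Turbo one.
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
print(scheduler.alphas_cumprod[-1])  # > 0: terminal SNR is not yet zero

# Same math as `enforce_zero_terminal_snr` in the patch: shift and scale
# sqrt(alpha_bar) so its last value is exactly zero, then rebuild alphas_cumprod.
alphas = 1 - scheduler.betas
alphas_bar_sqrt = alphas.cumprod(0).sqrt()
a_0, a_T = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
alphas_bar_sqrt = (alphas_bar_sqrt - a_T) * a_0 / (a_0 - a_T)
alphas_bar = alphas_bar_sqrt**2
alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
scheduler.alphas_cumprod = torch.cumprod(alphas, dim=0)
print(scheduler.alphas_cumprod[-1])  # 0: terminal SNR is now zero

# Same remapping as the training loop: uniformly drawn timesteps collapse onto
# the 4-step grid {249, 499, 749, 999} that SDXL Turbo samples on.
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (8,), dtype=torch.long)
timesteps = 250 * (timesteps % 4) + 249
print(sorted(set(timesteps.tolist())))  # subset of [249, 499, 749, 999]
```

This mirrors the motivation in the zero-terminal-SNR paper linked in the patch comments: the noisiest training timestep should correspond to pure noise, which is what a few-step distilled sampler like Turbo assumes when it starts generation.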