Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
SDXL Turbo support and example launch (#6473)
* support and example launch for sdxl turbo
* White space fixes
* Trailing whitespace character
* ruff format
* fix guidance_scale and steps for turbo mode

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Radames Ajna <radamajna@gmail.com>
@@ -61,6 +61,33 @@ accelerate launch train_diffusion_dpo_sdxl.py \
   --push_to_hub
 ```
 
+## SDXL Turbo training command
+
+```bash
+accelerate launch train_diffusion_dpo_sdxl.py \
+  --pretrained_model_name_or_path=stabilityai/sdxl-turbo \
+  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+  --output_dir="diffusion-sdxl-turbo-dpo" \
+  --mixed_precision="fp16" \
+  --dataset_name=kashif/pickascore \
+  --train_batch_size=8 \
+  --gradient_accumulation_steps=2 \
+  --gradient_checkpointing \
+  --use_8bit_adam \
+  --rank=8 \
+  --learning_rate=1e-5 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=2000 \
+  --checkpointing_steps=500 \
+  --run_validation --validation_steps=50 \
+  --seed="0" \
+  --report_to="wandb" \
+  --is_turbo --resolution 512 \
+  --push_to_hub
+```
+
+
 ## Acknowledgements
 
 This is based on the amazing work done by [Bram](https://github.com/bram-w) on Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
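SDXL Turbo is distilled for few-step sampling without classifier-free guidance, which is why the script changes below switch validation to 4 inference steps with `guidance_scale=0.0` when `--is_turbo` is set. A minimal inference sketch under those same settings (the prompt and output filename are illustrative):

```python
import torch
from diffusers import AutoPipelineForText2Image

# SDXL Turbo in fp16; the separate fp16-fix VAE in the command above exists
# because the stock SDXL VAE is numerically unstable in half precision.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# Turbo samples in very few steps and without classifier-free guidance,
# matching the guidance_scale=0.0 / 4-step validation settings below.
image = pipe(
    "a photo of a corgi riding a skateboard",  # illustrative prompt
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("turbo_sample.png")
```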
@@ -118,9 +118,16 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
     images = []
     context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
 
+    guidance_scale = 5.0
+    num_inference_steps = 25
+    if args.is_turbo:
+        guidance_scale = 0.0
+        num_inference_steps = 4
     for prompt in VALIDATION_PROMPTS:
         with context:
-            image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
+            image = pipeline(
+                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+            ).images[0]
             images.append(image)
 
     tracker_key = "test" if is_final_validation else "validation"
@@ -141,7 +148,10 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
     if is_final_validation:
         pipeline.disable_lora()
         no_lora_images = [
-            pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
+            pipeline(
+                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+            ).images[0]
+            for prompt in VALIDATION_PROMPTS
         ]
 
         for tracker in accelerator.trackers:
@@ -423,6 +433,11 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--is_turbo",
+        action="store_true",
+        help=("Use if tuning SDXL Turbo instead of SDXL"),
+    )
     parser.add_argument(
         "--rank",
         type=int,
@@ -444,6 +459,9 @@ def parse_args(input_args=None):
     if args.dataset_name is None:
         raise ValueError("Must provide a `dataset_name`.")
 
+    if args.is_turbo:
+        assert "turbo" in args.pretrained_model_name_or_path
+
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -560,6 +578,36 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 
+    def enforce_zero_terminal_snr(scheduler):
+        # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
+        # Original implementation: https://arxiv.org/pdf/2305.08891.pdf
+        # Turbo needs zero terminal SNR
+        # Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
+
+        # Convert betas to alphas_bar_sqrt
+        alphas = 1 - scheduler.betas
+        alphas_bar = alphas.cumprod(0)
+        alphas_bar_sqrt = alphas_bar.sqrt()
+
+        # Store old values.
+        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+        # Shift so the last timestep is zero.
+        alphas_bar_sqrt -= alphas_bar_sqrt_T
+        # Scale so the first timestep is back to its old value.
+        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+        alphas_bar = alphas_bar_sqrt**2
+        alphas = alphas_bar[1:] / alphas_bar[:-1]
+        alphas = torch.cat([alphas_bar[0:1], alphas])
+
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
+        scheduler.alphas_cumprod = alphas_cumprod
+        return
+
+    if args.is_turbo:
+        enforce_zero_terminal_snr(noise_scheduler)
+
     text_encoder_one = text_encoder_cls_one.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
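The rescaling above follows the zero-terminal-SNR recipe of arXiv:2305.08891: shift `sqrt(alphas_cumprod)` so its last entry is zero, then rescale so the first entry keeps its old value, leaving the final timestep as pure noise. A quick sanity-check sketch, assuming `diffusers` is installed and `enforce_zero_terminal_snr` from the hunk above is in scope:

```python
import torch  # enforce_zero_terminal_snr uses torch.cat / torch.cumprod
from diffusers import DDPMScheduler

# Same loading path as the training script uses for the noise scheduler.
sched = DDPMScheduler.from_pretrained("stabilityai/sdxl-turbo", subfolder="scheduler")
print(sched.alphas_cumprod[-1])  # small but nonzero: terminal SNR > 0

enforce_zero_terminal_snr(sched)  # function defined in the hunk above
print(sched.alphas_cumprod[-1])  # exactly 0: the last timestep is pure noise
```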
@@ -909,6 +957,10 @@ def main(args):
                 timesteps = torch.randint(
                     0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
                 ).repeat(2)
+                if args.is_turbo:
+                    # Learn a 4 timestep schedule
+                    timesteps_0_to_3 = timesteps % 4
+                    timesteps = 250 * timesteps_0_to_3 + 249
 
                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
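The modulo trick above collapses uniformly sampled timesteps onto SDXL Turbo's four training timesteps, evenly spaced across the 1000-step schedule. A quick check of the mapping (plain PyTorch; the batch size of 8 is arbitrary):

```python
import torch

# Uniform draws from the full 1000-step range, as in the training loop above.
timesteps = torch.randint(0, 1000, (8,))

# 250 * (t % 4) + 249 always lands on one of {249, 499, 749, 999}.
turbo_timesteps = 250 * (timesteps % 4) + 249
assert set(turbo_timesteps.tolist()) <= {249, 499, 749, 999}
print(turbo_timesteps)
```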