diff --git a/examples/training_ddpm.py b/examples/train_ddpm.py similarity index 68% rename from examples/training_ddpm.py rename to examples/train_ddpm.py index c29472fa75..7eb0b9d34e 100644 --- a/examples/training_ddpm.py +++ b/examples/train_ddpm.py @@ -1,10 +1,10 @@ +import argparse import os import torch -import PIL.Image -import argparse import torch.nn.functional as F +import PIL.Image from accelerate import Accelerator from datasets import load_dataset from diffusers import DDPM, DDPMScheduler, UNetModel @@ -31,44 +31,40 @@ def main(args): dropout=0.0, num_res_blocks=2, resamp_with_conv=True, - resolution=64, + resolution=args.resolution, ) noise_scheduler = DDPMScheduler(timesteps=1000) - optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) - - num_epochs = 100 - batch_size = 16 - gradient_accumulation_steps = 1 + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) augmentations = Compose( [ - Resize(64, interpolation=InterpolationMode.BILINEAR), - RandomCrop(64), + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + RandomCrop(args.resolution), RandomHorizontalFlip(), ToTensor(), Lambda(lambda x: x * 2 - 1), ] ) - dataset = load_dataset("huggan/pokemon", split="train") + dataset = load_dataset(args.dataset, split="train") def transforms(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] return {"input": images} dataset.set_transform(transforms) - train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, - num_warmup_steps=500, - num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, + num_warmup_steps=args.warmup_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, ) model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - for epoch in range(num_epochs): + for epoch in range(args.num_epochs): model.train() with tqdm(total=len(train_dataloader), unit="ba") as pbar: pbar.set_description(f"Epoch {epoch}") @@ -84,14 +80,15 @@ def main(args): noise_samples[idx] = noise noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx]) - if step % gradient_accumulation_steps != 0: + if step % args.gradient_accumulation_steps != 0: with accelerator.no_sync(model): output = model(noisy_images, timesteps) - # predict the noise + # predict the noise residual loss = F.mse_loss(output, noise_samples) accelerator.backward(loss) else: output = model(noisy_images, timesteps) + # predict the noise residual loss = F.mse_loss(output, noise_samples) accelerator.backward(loss) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) @@ -103,13 +100,18 @@ def main(args): optimizer.step() + # Generate a sample image for visual inspection torch.distributed.barrier() if args.local_rank in [-1, 0]: model.eval() with torch.no_grad(): - pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler) - generator = torch.Generator() - generator = generator.manual_seed(0) + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler) + else: + pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler) + pipeline.save_pretrained(args.output_path) + + generator = torch.manual_seed(0) # run pipeline in inference (sample random noise and denoise) image = pipeline(generator=generator) @@ -120,22 +122,31 @@ def main(args): image_pil = PIL.Image.fromarray(image_processed[0]) # save image - pipeline.save_pretrained("./pokemon-ddpm") - image_pil.save(f"./pokemon-ddpm/test_{epoch}.png") + test_dir = os.path.join(args.output_path, "test_samples") + os.makedirs(test_dir, exist_ok=True) + image_pil.save(f"{test_dir}/{epoch}.png") torch.distributed.barrier() if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Simple example of training script.") + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument("--local_rank", type=int) + parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories") + parser.add_argument("--resolution", type=int, default=64) + parser.add_argument("--output_path", type=str, default="ddpm-model") + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--gradient_accumulation_steps", type=int, default=2) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument( "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU.", + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU.", ) args = parser.parse_args() diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 81cbdf3641..b3dd5ef64a 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -214,6 +214,21 @@ class PipelineTesterMixin(unittest.TestCase): expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + @slow + def test_glide_text2img(self): + model_id = "fusing/glide-base" + glide = GLIDE.from_pretrained(model_id) + + prompt = "a pencil sketch of a corgi" + generator = torch.manual_seed(0) + image = glide(prompt, generator=generator, num_inference_steps_upscale=20) + + image_slice = image[0, :3, :3, -1].cpu() + + assert image.shape == (1, 256, 256, 3) + expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12) @@ -229,17 +244,3 @@ class PipelineTesterMixin(unittest.TestCase): _ = BDDM.from_pretrained(tmpdirname) # check if the same works using the DifusionPipeline class _ = DiffusionPipeline.from_pretrained(tmpdirname) - @slow - def test_glide_text2img(self): - model_id = "fusing/glide-base" - glide = GLIDE.from_pretrained(model_id) - - prompt = "a pencil sketch of a corgi" - generator = torch.manual_seed(0) - image = glide(prompt, generator=generator, num_inference_steps_upscale=20) - - image_slice = image[0, :3, :3, -1].cpu() - - assert image.shape == (1, 256, 256, 3) - expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2