From 8b4371f70fb5e791f4467a30375ef226bc5186a9 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Wed, 20 Jul 2022 17:28:06 +0200 Subject: [PATCH] Refactor pipeline outputs, return LDM guidance_scale (#110) --- src/diffusers/models/vae.py | 2 +- src/diffusers/pipeline_utils.py | 2 + src/diffusers/pipelines/ddim/pipeline_ddim.py | 3 + src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 3 + .../pipeline_latent_diffusion.py | 61 +++++++----- .../pipeline_latent_diffusion_uncond.py | 23 +++-- src/diffusers/pipelines/pndm/pipeline_pndm.py | 3 + .../score_sde_ve/pipeline_score_sde_ve.py | 13 ++- tests/test_modeling_utils.py | 93 ++++++++----------- 9 files changed, 105 insertions(+), 98 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 95b22fcd59..d16ab792f5 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -145,7 +145,7 @@ class Decoder(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index b0ff25c339..cbf2252c18 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -120,6 +120,7 @@ class DiffusionPipeline(ConfigMixin): proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -131,6 +132,7 @@ class DiffusionPipeline(ConfigMixin): proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, + revision=revision, ) else: cached_folder = pretrained_model_name_or_path diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 8dba08f728..513fef0e42 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -53,4 +53,7 @@ class DDIMPipeline(DiffusionPipeline): # do x_t -> x_t-1 image = self.scheduler.step(model_output, t, image, eta)["prev_sample"] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + return {"sample": image} diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index e72b05cf89..a49a60dc84 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -54,4 +54,7 @@ class DDPMPipeline(DiffusionPipeline): # 3. 
set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + return {"sample": image} diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index fb85608ae6..f41ab01e15 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.utils.checkpoint -import tqdm +from tqdm.auto import tqdm from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import BaseModelOutput @@ -35,46 +35,59 @@ class LatentDiffusionPipeline(DiffusionPipeline): if torch_device is None: torch_device = "cuda" if torch.cuda.is_available() else "cpu" + batch_size = len(prompt) self.unet.to(torch_device) self.vqvae.to(torch_device) self.bert.to(torch_device) - # get unconditional embeddings for classifier free guidence + # get unconditional embeddings for classifier free guidance if guidance_scale != 1.0: - uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to( - torch_device - ) - uncond_embeddings = self.bert(uncond_input.input_ids) + uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") + uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device)) - # get text embedding - text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) - text_embedding = self.bert(text_input.input_ids) + # get prompt text embeddings + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") + text_embeddings = self.bert(text_input.input_ids.to(torch_device)) - image = torch.randn( + latents = torch.randn( (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), generator=generator, - ).to(torch_device) + ) + latents = latents.to(torch_device) self.scheduler.set_timesteps(num_inference_steps) - for t in tqdm.tqdm(self.scheduler.timesteps): - # 1. predict noise residual - pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding) + for t in tqdm(self.scheduler.timesteps): + if guidance_scale == 1.0: + # guidance_scale of 1 means no guidance + latents_input = latents + context = text_embeddings + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = torch.cat([latents] * 2) + context = torch.cat([uncond_embeddings, text_embeddings]) - if isinstance(pred_noise_t, dict): - pred_noise_t = pred_noise_t["sample"] + # predict the noise residual + noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"] + # perform guidance + if guidance_scale != 1.0: + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - # 2. 
predict previous mean of image x_t-1 and add variance depending on eta - # do x_t -> x_t-1 - image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"] + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, eta)["prev_sample"] - # scale and decode image with vae - image = 1 / 0.18215 * image - image = self.vqvae.decode(image) - image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vqvae.decode(latents) - return image + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + return {"sample": image} ################################################################################ diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 79cf799f5b..38b4eb0517 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -28,25 +28,24 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): self.unet.to(torch_device) self.vqvae.to(torch_device) - image = torch.randn( + latents = torch.randn( (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), generator=generator, - ).to(torch_device) + ) + latents = latents.to(torch_device) self.scheduler.set_timesteps(num_inference_steps) for t in tqdm(self.scheduler.timesteps): - with torch.no_grad(): - model_output = self.unet(image, t) + # predict the noise residual + noise_prediction = self.unet(latents, t)["sample"] + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"] - if isinstance(model_output, dict): - model_output = model_output["sample"] + # decode the image latents with the VAE + image = self.vqvae.decode(latents) - # 2. 
predict previous mean of image x_t-1 and add variance depending on eta - # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta)["prev_sample"] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() - # decode image with vae - with torch.no_grad(): - image = self.vqvae.decode(image) return {"sample": image} diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index e7d4eb8ab5..69d6db6619 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -57,4 +57,7 @@ class PNDMPipeline(DiffusionPipeline): image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + return {"sample": image} diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 80b0f3ef9a..342847e2d1 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -2,14 +2,15 @@ import torch from diffusers import DiffusionPipeline +from tqdm.auto import tqdm -# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names class ScoreSdeVePipeline(DiffusionPipeline): def __init__(self, model, scheduler): super().__init__() self.register_modules(model=model, scheduler=scheduler) + @torch.no_grad() def __call__(self, num_inference_steps=2000, generator=None): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") @@ -24,12 +25,11 @@ class ScoreSdeVePipeline(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps) - for i, t in enumerate(self.scheduler.timesteps): + for i, t in tqdm(enumerate(self.scheduler.timesteps)): sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) for _ in range(self.scheduler.correct_steps): - with torch.no_grad(): - model_output = self.model(sample, sigma_t) + model_output = self.model(sample, sigma_t) if isinstance(model_output, dict): model_output = model_output["sample"] @@ -45,4 +45,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): output = self.scheduler.step_pred(model_output, t, sample) sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] - return sample_mean + sample = sample.clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + + return {"sample": sample} diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index dc7f125476..9c92f25104 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -741,13 +741,11 @@ class PipelineTesterMixin(unittest.TestCase): generator = torch.manual_seed(0) image = ddpm(generator=generator)["sample"] - image_slice = image[0, -1, -3:, -3:].cpu() + image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow def test_ddim_lsun(self): @@ -761,13 +759,11 @@ class PipelineTesterMixin(unittest.TestCase): generator = 
torch.manual_seed(0) image = ddpm(generator=generator)["sample"] - image_slice = image[0, -1, -3:, -3:].cpu() + image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor( - [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow def test_ddim_cifar10(self): @@ -781,13 +777,11 @@ class PipelineTesterMixin(unittest.TestCase): generator = torch.manual_seed(0) image = ddim(generator=generator, eta=0.0)["sample"] - image_slice = image[0, -1, -3:, -3:].cpu() + image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow def test_pndm_cifar10(self): @@ -800,13 +794,11 @@ class PipelineTesterMixin(unittest.TestCase): generator = torch.manual_seed(0) image = pndm(generator=generator)["sample"] - image_slice = image[0, -1, -3:, -3:].cpu() + image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow def test_ldm_text2img(self): @@ -814,13 +806,13 @@ class PipelineTesterMixin(unittest.TestCase): prompt = "A painting of a squirrel eating a burger" generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, num_inference_steps=20) + image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20)["sample"] - image_slice = image[0, -1, -3:, -3:].cpu() + image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow def test_ldm_text2img_fast(self): @@ -828,43 +820,34 @@ class PipelineTesterMixin(unittest.TestCase): prompt = "A painting of a squirrel eating a burger" generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, num_inference_steps=1) + image = ldm([prompt], generator=generator, num_inference_steps=1)["sample"] - image_slice = image[0, -1, -3:, -3:].cpu() + image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + assert image.shape == (1, 
256, 256, 3) + expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow def test_score_sde_ve_pipeline(self): - model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024") + model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256") torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024") + scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256") sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) torch.manual_seed(0) - image = sde_ve(num_inference_steps=2) + image = sde_ve(num_inference_steps=300)["sample"] - if model.device.type == "cpu": - # patrick's cpu - expected_image_sum = 3384805888.0 - expected_image_mean = 1076.00085 + image_slice = image[0, -3:, -3:, -1] - # m1 mbp - # expected_image_sum = 3384805376.0 - # expected_image_mean = 1076.000610351562 - else: - expected_image_sum = 3382849024.0 - expected_image_mean = 1075.3788 - - assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 - assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow def test_ldm_uncond(self): @@ -873,10 +856,8 @@ class PipelineTesterMixin(unittest.TestCase): generator = torch.manual_seed(0) image = ldm(generator=generator, num_inference_steps=5)["sample"] - image_slice = image[0, -1, -3:, -3:].cpu() + image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor( - [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
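
Usage sketch of the refactored API, as a minimal illustration rather than part of the diff above. It shows the three user-facing changes in this patch: `from_pretrained` now accepts a `revision` kwarg, the LDM text-to-image pipeline exposes `guidance_scale`, and every pipeline returns a dict whose "sample" entry is a float NumPy batch in NHWC layout scaled to [0, 1]. The checkpoint id and the PIL round-trip below are assumptions for illustration; any hub repo hosting a LatentDiffusionPipeline should behave the same way.

import torch
from PIL import Image

from diffusers import DiffusionPipeline

# `revision` pins a git revision of the hub repo (added by this patch).
# The checkpoint id is an illustrative assumption, not taken from the diff.
ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256", revision="main")

generator = torch.manual_seed(0)
# guidance_scale != 1.0 enables classifier-free guidance: the pipeline batches
# the unconditional and text-conditioned inputs into a single forward pass.
images = ldm(
    ["A painting of a squirrel eating a burger"],
    generator=generator,
    guidance_scale=6.0,
    num_inference_steps=50,
)["sample"]

# Every pipeline now returns {"sample": array}: a float NumPy batch in
# NHWC layout, scaled to [0, 1], ready to convert to PIL for saving.
assert images.shape == (1, 256, 256, 3)
Image.fromarray((images[0] * 255).round().astype("uint8")).save("squirrel.png")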