mirror of https://github.com/huggingface/diffusers.git

Refactor pipeline outputs, return LDM guidance_scale (#110)

Author: Anton Lozhkov
Date: 2022-07-20 17:28:06 +02:00
Committed by: GitHub
Parent: 919e27d357
Commit: 8b4371f70f

9 changed files with 105 additions and 98 deletions


@@ -145,7 +145,7 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+        # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

         # z to block_in
         self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)


@@ -120,6 +120,7 @@ class DiffusionPipeline(ConfigMixin):
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -131,6 +132,7 @@ class DiffusionPipeline(ConfigMixin):
                 proxies=proxies,
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
+                revision=revision,
             )
         else:
             cached_folder = pretrained_model_name_or_path
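
The new `revision` kwarg is forwarded to `snapshot_download`, so a pipeline can be pinned to a specific branch, tag, or commit hash on the Hub. A minimal usage sketch (the model id and revision here are illustrative, not part of this diff):

    from diffusers import DiffusionPipeline

    # Any ref that exists on the Hub repo works: a branch name, a tag, or a commit hash.
    pipeline = DiffusionPipeline.from_pretrained(
        "CompVis/ldm-text2im-large-256",  # illustrative model id
        revision="main",
    )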


@@ -53,4 +53,7 @@ class DDIMPipeline(DiffusionPipeline):
             # do x_t -> x_t-1
             image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]

+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
         return {"sample": image}


@@ -54,4 +54,7 @@ class DDPMPipeline(DiffusionPipeline):
             # 3. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image

+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
         return {"sample": image}
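
DDIM and DDPM (and the other pipelines below) now share the same postprocessing: denormalize from the model's [-1, 1] range into [0, 1], then return a NHWC numpy array under the "sample" key. A minimal sketch of consuming that output (the checkpoint id is illustrative):

    import torch
    from PIL import Image
    from diffusers import DDPMPipeline

    ddpm = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")  # illustrative checkpoint

    generator = torch.manual_seed(0)
    image = ddpm(generator=generator)["sample"]  # numpy, shape (batch, height, width, channels), values in [0, 1]

    # scale to uint8 and wrap each batch element as a PIL image
    pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in image]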


@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
-import tqdm
+from tqdm.auto import tqdm
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
@@ -35,46 +35,59 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"

         batch_size = len(prompt)

         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

-        # get unconditional embeddings for classifier free guidence
+        # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
-                torch_device
-            )
-            uncond_embeddings = self.bert(uncond_input.input_ids)
+            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))

-        # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-        text_embedding = self.bert(text_input.input_ids)
+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_embeddings = self.bert(text_input.input_ids.to(torch_device))

-        image = torch.randn(
+        latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        ).to(torch_device)
+        )
+        latents = latents.to(torch_device)

         self.scheduler.set_timesteps(num_inference_steps)

-        for t in tqdm.tqdm(self.scheduler.timesteps):
-            # 1. predict noise residual
-            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+        for t in tqdm(self.scheduler.timesteps):
+            if guidance_scale == 1.0:
+                # guidance_scale of 1 means no guidance
+                latents_input = latents
+                context = text_embeddings
+            else:
+                # For classifier free guidance, we need to do two forward passes.
+                # Here we concatenate the unconditional and text embeddings into a single batch
+                # to avoid doing two forward passes
+                latents_input = torch.cat([latents] * 2)
+                context = torch.cat([uncond_embeddings, text_embeddings])

-            if isinstance(pred_noise_t, dict):
-                pred_noise_t = pred_noise_t["sample"]
+            # predict the noise residual
+            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]

+            # perform guidance
+            if guidance_scale != 1.0:
+                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

-            # 2. predict previous mean of image x_t-1 and add variance depending on eta
-            # do x_t -> x_t-1
-            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, eta)["prev_sample"]

-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-        image = self.vqvae.decode(image)
-
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-
-        return image
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vqvae.decode(latents)

+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        return {"sample": image}

 ################################################################################
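
The batched branch above stacks the unconditional and text-conditioned inputs so a single UNet forward pass yields both noise predictions, which are then recombined with the standard classifier-free guidance formula: noise = noise_uncond + guidance_scale * (noise_text - noise_uncond). A self-contained sketch of just the combine step (toy tensors, not the real UNet output):

    import torch

    def apply_guidance(noise_pred, guidance_scale):
        # noise_pred holds [unconditional, text-conditioned] stacked along the batch dim
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        # guidance_scale == 1.0 reduces to the text-conditioned prediction;
        # larger values push the sample harder toward the prompt
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    noise_pred = torch.randn(2, 4, 32, 32)  # toy stand-in for the batched UNet output
    guided = apply_guidance(noise_pred, guidance_scale=6.0)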


@@ -28,25 +28,24 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)

-        image = torch.randn(
+        latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        ).to(torch_device)
+        )
+        latents = latents.to(torch_device)

         self.scheduler.set_timesteps(num_inference_steps)

         for t in tqdm(self.scheduler.timesteps):
-            with torch.no_grad():
-                model_output = self.unet(image, t)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
-
-            # 2. predict previous mean of image x_t-1 and add variance depending on eta
-            # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
-
-        # decode image with vae
-        with torch.no_grad():
-            image = self.vqvae.decode(image)
+            # predict the noise residual
+            noise_prediction = self.unet(latents, t)["sample"]
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
+
+        # decode the image latents with the VAE
+        image = self.vqvae.decode(latents)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()

         return {"sample": image}


@@ -57,4 +57,7 @@ class PNDMPipeline(DiffusionPipeline):
             image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]

+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
         return {"sample": image}


@@ -2,14 +2,15 @@
 import torch

 from diffusers import DiffusionPipeline
+from tqdm.auto import tqdm


 # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
 class ScoreSdeVePipeline(DiffusionPipeline):
     def __init__(self, model, scheduler):
         super().__init__()
         self.register_modules(model=model, scheduler=scheduler)

+    @torch.no_grad()
     def __call__(self, num_inference_steps=2000, generator=None):
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -24,12 +25,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)

-        for i, t in enumerate(self.scheduler.timesteps):
+        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)

             for _ in range(self.scheduler.correct_steps):
-                with torch.no_grad():
-                    model_output = self.model(sample, sigma_t)
+                model_output = self.model(sample, sigma_t)

                 if isinstance(model_output, dict):
                     model_output = model_output["sample"]
@@ -45,4 +45,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             output = self.scheduler.step_pred(model_output, t, sample)
             sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

-        return sample_mean
+        sample = sample.clamp(0, 1)
+        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+
+        return {"sample": sample}
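
Replacing the inner `with torch.no_grad():` blocks with a single `@torch.no_grad()` decorator on `__call__` disables autograd for the whole sampling loop instead of one call at a time. A minimal sketch of the pattern (toy class, not the real pipeline):

    import torch

    class ToySampler:
        @torch.no_grad()  # no gradient graph is recorded anywhere inside __call__
        def __call__(self, x):
            return x * 2  # stand-in for the model and scheduler calls

    x = torch.ones(3, requires_grad=True)
    y = ToySampler()(x)
    assert not y.requires_grad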


@@ -741,13 +741,11 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ddim_lsun(self):
@@ -761,13 +759,11 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor(
-            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ddim_cifar10(self):
@@ -781,13 +777,11 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = ddim(generator=generator, eta=0.0)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_pndm_cifar10(self):
@@ -800,13 +794,11 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = pndm(generator=generator)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ldm_text2img(self):
@@ -814,13 +806,13 @@ class PipelineTesterMixin(unittest.TestCase):
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=20)
+        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ldm_text2img_fast(self):
@@ -828,43 +820,34 @@ class PipelineTesterMixin(unittest.TestCase):
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=1)
+        image = ldm([prompt], generator=generator, num_inference_steps=1)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024")
+        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")

         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024")
+        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")

         sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

         torch.manual_seed(0)
-        image = sde_ve(num_inference_steps=2)
+        image = sde_ve(num_inference_steps=300)["sample"]

-        if model.device.type == "cpu":
-            # patrick's cpu
-            expected_image_sum = 3384805888.0
-            expected_image_mean = 1076.00085
-
-            # m1 mbp
-            # expected_image_sum = 3384805376.0
-            # expected_image_mean = 1076.000610351562
-        else:
-            expected_image_sum = 3382849024.0
-            expected_image_mean = 1075.3788
-
-        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
-        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ldm_uncond(self):
@@ -873,10 +856,8 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=5)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor(
-            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
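
Since pipeline outputs are now NHWC numpy arrays, the tests index the bottom-right three rows and columns of the last channel as image[0, -3:, -3:, -1] (previously image[0, -1, -3:, -3:] on a NCHW tensor). A tiny sketch showing both slices pick the same nine pixels, just in different layouts:

    import numpy as np
    import torch

    chw = torch.arange(3 * 4 * 4).reshape(1, 3, 4, 4).float()  # NCHW tensor
    hwc = chw.permute(0, 2, 3, 1).numpy()                      # NHWC numpy array

    old_slice = chw[0, -1, -3:, -3:].numpy()  # last channel, bottom-right 3x3 (NCHW)
    new_slice = hwc[0, -3:, -3:, -1]          # same pixels in the NHWC layout

    assert np.array_equal(old_slice, new_slice)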