mirror of https://github.com/huggingface/diffusers.git
Refactor pipeline outputs, return LDM guidance_scale (#110)
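In practice, the refactor standardizes what callers get back: every pipeline `__call__` now returns a dict whose `"sample"` entry is a NumPy array in `(batch, height, width, channels)` layout, and the LDM text2img pipeline accepts a `guidance_scale` argument. A minimal usage sketch (the checkpoint id is illustrative, not part of this commit):

```python
import torch
from diffusers import DDPMPipeline

# Illustrative checkpoint id; any compatible DDPM checkpoint works.
ddpm = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

generator = torch.manual_seed(0)
# After this commit, pipelines return a dict; "sample" holds a NumPy
# array of shape (batch, height, width, 3) with values in [0, 1].
image = ddpm(generator=generator)["sample"]
print(image.shape)  # (1, 32, 32, 3)
```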
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+        # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

         # z to block_in
         self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
@@ -120,6 +120,7 @@ class DiffusionPipeline(ConfigMixin):
         proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -131,6 +132,7 @@ class DiffusionPipeline(ConfigMixin):
                 proxies=proxies,
+                local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 revision=revision,
             )
         else:
             cached_folder = pretrained_model_name_or_path
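The new `local_files_only` kwarg is popped from `**kwargs` and forwarded to the snapshot download, so fully offline loading can be requested at the call site. A sketch, assuming the checkpoint is already in the local cache (the checkpoint id is illustrative):

```python
from diffusers import DiffusionPipeline

# With local_files_only=True the loader resolves everything from the
# local cache and never hits the Hub; it raises if files are missing.
ldm = DiffusionPipeline.from_pretrained(
    "CompVis/ldm-text2im-large-256",  # illustrative checkpoint id
    local_files_only=True,
)
```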
@@ -53,4 +53,7 @@ class DDIMPipeline(DiffusionPipeline):
             # do x_t -> x_t-1
             image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]

+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
         return {"sample": image}
@@ -54,4 +54,7 @@ class DDPMPipeline(DiffusionPipeline):
             # 3. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image

+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
         return {"sample": image}
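Both DDIM and DDPM now finish with the same postprocessing: map the model's `[-1, 1]` output range to `[0, 1]`, then convert to channels-last NumPy. The commit stops at NumPy; a small follow-on sketch (PIL is an assumption here, not used by the commit) for turning the returned array into images:

```python
import numpy as np
from PIL import Image

def to_pil(sample: np.ndarray) -> list:
    """Convert a (batch, H, W, 3) float array in [0, 1] to PIL images."""
    arr = (sample * 255).round().astype("uint8")
    return [Image.fromarray(img) for img in arr]
```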
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint

-import tqdm
+from tqdm.auto import tqdm
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
@@ -35,46 +35,59 @@ class LatentDiffusionPipeline(DiffusionPipeline):

         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
         batch_size = len(prompt)

         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

-        # get unconditional embeddings for classifier free guidence
+        # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
-                torch_device
-            )
-            uncond_embeddings = self.bert(uncond_input.input_ids)
+            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))

-        # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-        text_embedding = self.bert(text_input.input_ids)
+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_embeddings = self.bert(text_input.input_ids.to(torch_device))

-        image = torch.randn(
+        latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        ).to(torch_device)
+        )
+        latents = latents.to(torch_device)

         self.scheduler.set_timesteps(num_inference_steps)

-        for t in tqdm.tqdm(self.scheduler.timesteps):
-            # 1. predict noise residual
-            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+        for t in tqdm(self.scheduler.timesteps):
+            if guidance_scale == 1.0:
+                # guidance_scale of 1 means no guidance
+                latents_input = latents
+                context = text_embeddings
+            else:
+                # For classifier free guidance, we need to do two forward passes.
+                # Here we concatenate the unconditional and text embeddings into a single batch
+                # to avoid doing two forward passes
+                latents_input = torch.cat([latents] * 2)
+                context = torch.cat([uncond_embeddings, text_embeddings])

-            if isinstance(pred_noise_t, dict):
-                pred_noise_t = pred_noise_t["sample"]
+            # predict the noise residual
+            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+            # perform guidance
+            if guidance_scale != 1.0:
+                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

-            # 2. predict previous mean of image x_t-1 and add variance depending on eta
-            # do x_t -> x_t-1
-            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, eta)["prev_sample"]

-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-        image = self.vqvae.decode(image)
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vqvae.decode(latents)

-        return image
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        return {"sample": image}


 ################################################################################
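The rewritten loop batches the unconditional and prompt-conditioned inputs so classifier-free guidance costs one U-Net forward pass per step instead of two, then recombines the two halves of the output. A self-contained sketch of just the recombination arithmetic on dummy tensors (shapes are illustrative):

```python
import torch

guidance_scale = 6.0  # the value the text2img test below passes

# Stand-in for one U-Net call on the doubled batch: the first half was
# conditioned on the empty prompt, the second half on the real prompt.
noise_pred = torch.randn(2, 3, 64, 64)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# Push the prediction away from the unconditional direction, toward the
# text-conditioned one, by a factor of guidance_scale.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert noise_pred.shape == (1, 3, 64, 64)
```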
@@ -28,25 +28,24 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)

-        image = torch.randn(
+        latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        ).to(torch_device)
+        )
+        latents = latents.to(torch_device)

         self.scheduler.set_timesteps(num_inference_steps)

         for t in tqdm(self.scheduler.timesteps):
-            with torch.no_grad():
-                model_output = self.unet(image, t)
+            # predict the noise residual
+            noise_prediction = self.unet(latents, t)["sample"]
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]

-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
-
-            # 2. predict previous mean of image x_t-1 and add variance depending on eta
-            # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
+        # decode the image latents with the VAE
+        image = self.vqvae.decode(latents)

-        # decode image with vae
-        with torch.no_grad():
-            image = self.vqvae.decode(image)
-        return image
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        return {"sample": image}
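The per-step `with torch.no_grad():` blocks disappear from this pipeline as well; the cleaner pattern, which the `ScoreSdeVePipeline` below already uses, is a single `@torch.no_grad()` on `__call__`. A minimal sketch of that pattern (the class and update rule are toy stand-ins, not from this commit):

```python
import torch

class ToySampler:
    @torch.no_grad()  # autograd stays off for the entire sampling loop
    def __call__(self, model, latents, timesteps):
        for t in timesteps:
            latents = latents - 0.1 * model(latents, t)  # toy update
        return latents
```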
@@ -57,4 +57,7 @@ class PNDMPipeline(DiffusionPipeline):

             image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]

+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
         return {"sample": image}
@@ -2,14 +2,15 @@
 import torch

 from diffusers import DiffusionPipeline
+from tqdm.auto import tqdm


 # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
 class ScoreSdeVePipeline(DiffusionPipeline):
     def __init__(self, model, scheduler):
         super().__init__()
         self.register_modules(model=model, scheduler=scheduler)

     @torch.no_grad()
     def __call__(self, num_inference_steps=2000, generator=None):
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -24,12 +25,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)

-        for i, t in enumerate(self.scheduler.timesteps):
+        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)

             for _ in range(self.scheduler.correct_steps):
-                with torch.no_grad():
-                    model_output = self.model(sample, sigma_t)
+                model_output = self.model(sample, sigma_t)

                 if isinstance(model_output, dict):
                     model_output = model_output["sample"]
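One side effect of wrapping `enumerate(...)` directly: tqdm cannot take `len()` of an `enumerate` object, so the bar renders without a total or ETA. Passing `total=` explicitly restores it; a sketch (a side note, not part of this commit):

```python
from tqdm.auto import tqdm

timesteps = list(range(2000))
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
    pass  # predictor/corrector work goes here
```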
@@ -45,4 +45,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             output = self.scheduler.step_pred(model_output, t, sample)
             sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

-        return sample_mean
+        sample = sample.clamp(0, 1)
+        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+
+        return {"sample": sample}
@@ -741,13 +741,11 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ddim_lsun(self):
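The test updates track the new layout: `image[0, -3:, -3:, -1]` takes the bottom-right 3x3 patch of the last channel of the first NHWC image, exactly where `image[0, -1, -3:, -3:]` pointed in the old NCHW tensor. A quick equivalence check on dummy data:

```python
import numpy as np
import torch

chw = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(1, 3, 32, 32)
hwc = chw.permute(0, 2, 3, 1).numpy()

old_slice = chw[0, -1, -3:, -3:].numpy()  # NCHW indexing
new_slice = hwc[0, -3:, -3:, -1]          # NHWC indexing
assert np.allclose(old_slice, new_slice)
```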
@@ -761,13 +759,11 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor(
-            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ddim_cifar10(self):
@@ -781,13 +777,11 @@ class PipelineTesterMixin(unittest.TestCase):
        generator = torch.manual_seed(0)
        image = ddim(generator=generator, eta=0.0)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_pndm_cifar10(self):
@@ -800,13 +794,11 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = pndm(generator=generator)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ldm_text2img(self):
@@ -814,13 +806,13 @@ class PipelineTesterMixin(unittest.TestCase):

         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=20)
+        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ldm_text2img_fast(self):
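With `guidance_scale` now exposed on the LDM pipeline, the test above exercises it end to end. A sketch of the same call outside the test harness (the checkpoint id is illustrative, not from this commit):

```python
import torch
from diffusers import DiffusionPipeline

ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
# guidance_scale > 1.0 strengthens prompt adherence; 1.0 disables
# guidance and runs a single forward pass per step.
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20)["sample"]
```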
@@ -828,43 +820,34 @@ class PipelineTesterMixin(unittest.TestCase):

         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=1)
+        image = ldm([prompt], generator=generator, num_inference_steps=1)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024")
+        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")

         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024")
+        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")

         sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

         torch.manual_seed(0)
-        image = sde_ve(num_inference_steps=2)
+        image = sde_ve(num_inference_steps=300)["sample"]

-        if model.device.type == "cpu":
-            # patrick's cpu
-            expected_image_sum = 3384805888.0
-            expected_image_mean = 1076.00085
-
-            # m1 mbp
-            # expected_image_sum = 3384805376.0
-            # expected_image_mean = 1076.000610351562
-        else:
-            expected_image_sum = 3382849024.0
-            expected_image_mean = 1075.3788
-
-        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
-        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     @slow
     def test_ldm_uncond(self):
@@ -873,10 +856,8 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=5)["sample"]

-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor(
-            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2