mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
cleanup LDM
This commit is contained in:
@@ -863,6 +863,7 @@ class LatentDiffusion(DiffusionPipeline):
|
||||
super().__init__()
|
||||
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
|
||||
@@ -873,6 +874,7 @@ class LatentDiffusion(DiffusionPipeline):
|
||||
self.vqvae.to(torch_device)
|
||||
self.bert.to(torch_device)
|
||||
|
||||
# get unconditional embeddings for classifier free guidence
|
||||
if guidance_scale != 1.0:
|
||||
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids)[0]
|
||||
@@ -901,19 +903,23 @@ class LatentDiffusion(DiffusionPipeline):
|
||||
# - pred_image_direction -> "direction pointingc to x_t"
|
||||
# - pred_prev_image -> "x_t-1"
|
||||
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
|
||||
# 1. predict noise residual
|
||||
# guidance_scale of 1 means no guidance
|
||||
if guidance_scale == 1.0:
|
||||
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
context = text_embedding
|
||||
image_in = image
|
||||
else:
|
||||
image_in = torch.cat([image] * 2)
|
||||
context = text_embedding
|
||||
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
else:
|
||||
# for classifier free guidance, we need to do two forward passes
|
||||
# here we concanate embedding and unconditioned embedding in a single batch
|
||||
# to avoid doing two forward passes
|
||||
image_in = torch.cat([image] * 2)
|
||||
context = torch.cat([uncond_embeddings, text_embedding])
|
||||
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
|
||||
# 1. predict noise residual
|
||||
pred_noise_t = self.unet(image_in, timesteps, context=context)
|
||||
|
||||
with torch.no_grad():
|
||||
pred_noise_t = self.unet(image_in, timesteps, context=context)
|
||||
|
||||
# perform guidance
|
||||
if guidance_scale != 1.0:
|
||||
pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
|
||||
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
|
||||
@@ -933,18 +939,15 @@ class LatentDiffusion(DiffusionPipeline):
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
|
||||
|
||||
# Second: Clip "predicted x_0"
|
||||
# pred_original_image = torch.clamp(pred_original_image, -1, 1)
|
||||
|
||||
# Third: Compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# Second: Compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
|
||||
std_dev_t = eta * std_dev_t
|
||||
|
||||
# Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Third: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
|
||||
|
||||
# Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Forth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
|
||||
|
||||
# 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
|
||||
@@ -958,9 +961,9 @@ class LatentDiffusion(DiffusionPipeline):
|
||||
# 6. Set current image to prev_image: x_t -> x_t-1
|
||||
image = prev_image
|
||||
|
||||
# scale and decode image with vae
|
||||
image = 1 / 0.18215 * image
|
||||
image = self.vqvae.decode(image)
|
||||
image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
|
||||
image = 255. * image
|
||||
|
||||
return image
|
||||
|
||||
Reference in New Issue
Block a user