
Fix noise pred timestep, clip_tokenizer, CLIP image encoding, and other bugs.

commit 84781fbd67 (parent 0300563861)
Author: Daniel Gu
Date: 2023-04-17 19:06:19 -07:00
2 changed files with 36 additions and 49 deletions


@@ -242,7 +242,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
def set_text_mode(self):
self.mode = "text"
- def set_img_mode(self):
+ def set_image_mode(self):
self.mode = "img"
def set_text_to_image_mode(self):
@@ -276,7 +276,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
batch_size = num_samples
return batch_size
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ # self.tokenizer => self.clip_tokenizer
def _encode_prompt(
self,
prompt,
@@ -319,25 +320,25 @@ class UniDiffuserPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
- text_inputs = self.tokenizer(
+ text_inputs = self.clip_tokenizer(
prompt,
padding="max_length",
- max_length=self.tokenizer.model_max_length,
+ max_length=self.clip_tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
- removed_text = self.tokenizer.batch_decode(
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ removed_text = self.clip_tokenizer.batch_decode(
+ untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
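For reference, a standalone sketch of the truncation check this hunk performs with the renamed clip_tokenizer; the checkpoint name is only illustrative, not necessarily the one the pipeline loads:

import torch
from transformers import CLIPTokenizer

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")  # illustrative checkpoint
prompt = ["a photo of " + "a very detailed scene, " * 30]  # long enough to exceed the 77-token limit

text_inputs = clip_tokenizer(
    prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt"
)
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if untruncated_ids.shape[-1] >= text_inputs.input_ids.shape[-1] and not torch.equal(
    text_inputs.input_ids, untruncated_ids
):
    # decode only the part that truncation dropped, as the warning in the hunk does
    removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, clip_tokenizer.model_max_length - 1 : -1])
    print(f"Truncated: {removed_text}")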
@@ -380,7 +381,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_input = self.clip_tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
@@ -480,24 +481,21 @@ class UniDiffuserPipeline(DiffusionPipeline):
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
+ preprocessed_image = self.image_processor.preprocess(
+ image,
+ do_center_crop=True,
+ crop_size=resolution,
+ return_tensors="pt",
+ )
+ preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)
if isinstance(generator, list):
image_latents = [
- self.image_encoder(
- **self.image_processor.preprocess(
- image[i : i + 1], do_center_crop=True, crop_size=resolution, return_tensors="pt"
- )
- )
- for i in range(batch_size)
+ self.image_encoder(**preprocessed_image[i : i + 1]).pooler_output for i in range(batch_size)
]
image_latents = torch.cat(image_latents, dim=0)
else:
- # TODO: figure out self.image_processor.preprocess kwargs
- inputs = self.image_processor.preprocess(
- image, do_center_crop=True, crop_size=resolution, return_tensors="pt"
- )
- image_latents = self.image_encoder(**inputs)
+ image_latents = self.image_encoder(**preprocessed_image).pooler_output
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand image_latents for batch_size
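The hunk above now preprocesses once and keeps only the pooled CLIP embedding. A minimal standalone sketch of that pattern, assuming a stock CLIP vision model and processor (checkpoint name illustrative, not necessarily the pipeline's actual components):

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")  # illustrative
image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")       # illustrative

image = Image.new("RGB", (512, 512))  # placeholder input image
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    # pooler_output has shape (batch_size, hidden_size); the pipeline later
    # unsqueezes it to (batch_size, 1, hidden_size)
    image_clip_latents = image_encoder(**inputs).pooler_output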
@@ -659,7 +657,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
prompt_embeds,
img_vae,
img_clip,
- timesteps,
+ max_timestep,
guidance_scale,
generator,
device,
@@ -689,17 +687,15 @@ class UniDiffuserPipeline(DiffusionPipeline):
img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
- t_img_uncond = torch.ones_like(t) * timesteps[0]
- t_text_uncond = torch.ones_like(t) * timesteps[0]
# print(f"t_img_uncond: {t_img_uncond}")
# print(f"t_img_uncond shape: {t_img_uncond.shape}")
# print("Running unconditional U-Net call 1 for CFG...")
- _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=t_img_uncond, t_text=t)
+ _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t)
# print("Running unconditional U-Net call 2 for CFG...")
img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
- img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=t_text_uncond
+ img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
)
x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond)
@@ -708,10 +704,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
elif mode == "text2img":
# Text-conditioned image generation
img_vae_latents, img_clip_latents = self._split(latents, height, width)
- t_text = torch.zeros(t.size(0), dtype=torch.int, device=device)
img_vae_out, img_clip_out, text_out = self.unet(
- img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+ img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=0
)
img_out = self._combine(img_vae_out, img_clip_out)
@@ -721,10 +716,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
# Classifier-free guidance
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
- t_text_uncond = torch.ones_like(t) * timesteps
img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
- img_vae_latents, img_clip_latents, text_T, t_img=timesteps, t_text=t_text_uncond
+ img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
)
img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond)
@@ -732,9 +726,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
elif mode == "img2text":
# Image-conditioned text generation
- t_img = torch.zeros(t.size(0), dtype=torch.int, device=device)
- img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+ img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=0, t_text=t)
if guidance_scale <= 1.0:
return text_out
@@ -742,27 +734,23 @@ class UniDiffuserPipeline(DiffusionPipeline):
# Classifier-free guidance
img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
- t_img_uncond = torch.ones_like(t) * timesteps
img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
- img_vae_T, img_clip_T, latents, t_img=t_img_uncond, t_text=timesteps
+ img_vae_T, img_clip_T, latents, t_img=max_timestep, t_text=t
)
return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond
elif mode == "text":
# Unconditional ("marginal") text generation (no CFG)
- t_img = torch.ones_like(t) * timesteps
- img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+ img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=max_timestep, t_text=t)
return text_out
elif mode == "img":
# Unconditional ("marginal") image generation (no CFG)
img_vae_latents, img_clip_latents = self._split(latents, height, width)
- t_text = torch.ones_like(t) * timesteps
img_vae_out, img_clip_out, text_out = self.unet(
- img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+ img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=max_timestep
)
img_out = self._combine(img_vae_out, img_clip_out)
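A side note on the pattern in the hunks above: each unconditional branch replaces the conditioning modality with pure noise at the largest timestep, and the blend guidance_scale * out + (1.0 - guidance_scale) * out_uncond is algebraically the usual classifier-free guidance out_uncond + guidance_scale * (out - out_uncond). A tiny illustrative sketch with dummy tensors, not the pipeline's exact code:

import torch

guidance_scale = 7.0
img_out = torch.randn(1, 4, 16, 16)         # prediction with the real condition
img_out_uncond = torch.randn(1, 4, 16, 16)  # prediction with the condition replaced by noise at max_timestep
guided = guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
# equivalent form: img_out_uncond + guidance_scale * (img_out - img_out_uncond)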
@@ -980,7 +968,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
assert image is not None
# Encode image using VAE
image_vae = preprocess(image)
- height, width = image.shape[-2:]
+ height, width = image_vae.shape[-2:]
image_vae_latents = self.encode_image_vae_latents(
image_vae,
batch_size,
@@ -1001,6 +989,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
device,
generator,
)
+ # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
+ image_clip_latents = image_clip_latents.unsqueeze(1)
else:
# 4.2. Prepare image latent variables, if input not available
# Prepare image VAE latents
@@ -1030,6 +1020,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
# 5. Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
+ max_timestep = timesteps[0]
# print(f"Timesteps: {timesteps}")
# print(f"Timesteps shape: {timesteps.shape}")
@@ -1062,7 +1053,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
prompt_embeds,
image_vae_latents,
image_clip_latents,
- timesteps,
+ max_timestep,
guidance_scale,
generator,
device,


@@ -2,7 +2,6 @@ import random
import unittest
import numpy as np
- import pytest
import torch
from PIL import Image
from transformers import (
@@ -158,7 +157,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
}
return inputs
# @pytest.mark.xfail(reason="not finished debugging")
def test_unidiffuser_default_joint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -186,7 +184,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# TODO: need to figure out correct text output
print(text)
@pytest.mark.xfail(reason="haven't begun debugging")
def test_unidiffuser_default_text2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -208,7 +205,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
expected_slice = np.array([0.3965, 0.4568, 0.4495, 0.4590, 0.4463, 0.4690, 0.5454, 0.5093, 0.4321])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- @pytest.mark.xfail(reason="haven't begun debugging")
def test_unidiffuser_default_img2text(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -227,8 +223,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# TODO: need to figure out correct text output
print(text)
+ assert 0 == 1
- @pytest.mark.xfail(reason="haven't begun debugging")
def test_unidiffuser_default_text(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -248,8 +244,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# TODO: need to figure out correct text output
print(text)
+ assert 0 == 1
- @pytest.mark.xfail(reason="haven't begun debugging")
def test_unidiffuser_default_image(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -268,8 +264,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image = unidiffuser_pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
# TODO: get expected slice of image output
image_slice = image[0, -3:, -3:, -1]
print(image_slice.flatten())
expected_slice = np.array([0.3967, 0.4568, 0.4495, 0.4590, 0.4463, 0.4690, 0.5454, 0.5093, 0.4321])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3