Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

Fix noise pred timestep, clip_tokenizer, CLIP image encoding, and other bugs.
@@ -242,7 +242,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
    def set_text_mode(self):
        self.mode = "text"

-    def set_img_mode(self):
+    def set_image_mode(self):
        self.mode = "img"

    def set_text_to_image_mode(self):
@@ -276,7 +276,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
            batch_size = num_samples
        return batch_size

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    # self.tokenizer => self.clip_tokenizer
    def _encode_prompt(
        self,
        prompt,
@@ -319,25 +320,25 @@ class UniDiffuserPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
-            text_inputs = self.tokenizer(
+            text_inputs = self.clip_tokenizer(
                prompt,
                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
+                max_length=self.clip_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                removed_text = self.clip_tokenizer.batch_decode(
+                    untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                    f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
                )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
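Note (illustrative, not part of the diff): the attribute is renamed to self.clip_tokenizer, presumably because the pipeline also carries a second tokenizer for its text decoder, so the CLIP one has to be addressed explicitly. The sketch below runs the same truncation check in isolation; the openai/clip-vit-base-patch32 checkpoint id is an assumption for illustration, not something named in this commit.

from transformers import CLIPTokenizer

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

prompt = "a photo of " + "a very, very " * 30 + "long prompt"
text_inputs = clip_tokenizer(
    prompt,
    padding="max_length",
    max_length=clip_tokenizer.model_max_length,  # 77 for CLIP
    truncation=True,
    return_tensors="pt",
)
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if untruncated_ids.shape[-1] >= text_inputs.input_ids.shape[-1]:
    # Decode only the part that fell off the end, mirroring the warning in the diff.
    removed_text = clip_tokenizer.batch_decode(
        untruncated_ids[:, clip_tokenizer.model_max_length - 1 : -1]
    )
    print("Truncated:", removed_text)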
@@ -380,7 +381,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
+            uncond_input = self.clip_tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
@@ -480,24 +481,21 @@ class UniDiffuserPipeline(DiffusionPipeline):
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

-        image = image.to(device=device, dtype=dtype)
+        preprocessed_image = self.image_processor.preprocess(
+            image,
+            do_center_crop=True,
+            crop_size=resolution,
+            return_tensors="pt",
+        )
+        preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)

        if isinstance(generator, list):
            image_latents = [
-                self.image_encoder(
-                    **self.image_processor.preprocess(
-                        image[i : i + 1], do_center_crop=True, crop_size=resolution, return_tensors="pt"
-                    )
-                )
-                for i in range(batch_size)
+                self.image_encoder(**preprocessed_image[i : i + 1]).pooler_output for i in range(batch_size)
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
-            # TODO: figure out self.image_processor.preprocess kwargs
-            inputs = self.image_processor.preprocess(
-                image, do_center_crop=True, crop_size=resolution, return_tensors="pt"
-            )
-            image_latents = self.image_encoder(**inputs)
+            image_latents = self.image_encoder(**preprocessed_image).pooler_output

        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand image_latents for batch_size
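Note (illustrative, not part of the diff): the substantive fix above is that the CLIP image encoder now receives preprocessed pixel values and its pooler_output is used as the per-image embedding, instead of the raw model output object. A minimal sketch of that pattern follows; openai/clip-vit-base-patch32 is an assumed stand-in checkpoint and the preprocessing kwargs are left at their defaults.

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")

image = Image.new("RGB", (256, 256))  # stand-in for a real input image
inputs = image_processor(images=image, return_tensors="pt")  # BatchFeature with pixel_values

with torch.no_grad():
    outputs = image_encoder(**inputs)

# The fix keys on pooler_output: one (batch_size, hidden_size) embedding per image,
# rather than the full model output object returned by the encoder.
image_latents = outputs.pooler_output
print(image_latents.shape)  # torch.Size([1, 768]) for this checkpoint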
@@ -659,7 +657,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        prompt_embeds,
        img_vae,
        img_clip,
-        timesteps,
+        max_timestep,
        guidance_scale,
        generator,
        device,
@@ -689,17 +687,15 @@ class UniDiffuserPipeline(DiffusionPipeline):
            img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
            img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
            text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
-            t_img_uncond = torch.ones_like(t) * timesteps[0]
-            t_text_uncond = torch.ones_like(t) * timesteps[0]

            # print(f"t_img_uncond: {t_img_uncond}")
            # print(f"t_img_uncond shape: {t_img_uncond.shape}")

            # print("Running unconditional U-Net call 1 for CFG...")
-            _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=t_img_uncond, t_text=t)
+            _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t)
            # print("Running unconditional U-Net call 2 for CFG...")
            img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
-                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=t_text_uncond
+                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
            )

            x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond)
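Note (illustrative, not part of the diff): the timestep convention being fixed here is that a modality fed to the U-Net as pure noise for the unconditional branch gets the maximum timestep, while a clean conditioning modality gets timestep 0 (as in the text2img/img2text branches below); the conditional and unconditional outputs are then mixed with the usual guidance rule. A small self-contained sketch of that rule with made-up tensors:

import torch

def combine_with_guidance(cond_out, uncond_out, guidance_scale):
    # Same mixing rule the pipeline uses: scale * cond + (1 - scale) * uncond.
    return guidance_scale * cond_out + (1.0 - guidance_scale) * uncond_out

timesteps = torch.arange(999, -1, -50)  # descending inference timesteps (illustrative)
max_timestep = timesteps[0]             # marks a modality treated as pure noise
t_clean = torch.tensor(0)               # marks a modality treated as clean conditioning

cond_out, uncond_out = torch.randn(2, 1, 16)
out = combine_with_guidance(cond_out, uncond_out, guidance_scale=7.0)
print(out.shape)  # torch.Size([1, 16])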
@@ -708,10 +704,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
        elif mode == "text2img":
            # Text-conditioned image generation
            img_vae_latents, img_clip_latents = self._split(latents, height, width)
-            t_text = torch.zeros(t.size(0), dtype=torch.int, device=device)

            img_vae_out, img_clip_out, text_out = self.unet(
-                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=0
            )

            img_out = self._combine(img_vae_out, img_clip_out)
@@ -721,10 +716,9 @@ class UniDiffuserPipeline(DiffusionPipeline):

            # Classifier-free guidance
            text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
-            t_text_uncond = torch.ones_like(t) * timesteps

            img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
-                img_vae_latents, img_clip_latents, text_T, t_img=timesteps, t_text=t_text_uncond
+                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
            )

            img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond)
@@ -732,9 +726,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
            return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
        elif mode == "img2text":
            # Image-conditioned text generation
-            t_img = torch.zeros(t.size(0), dtype=torch.int, device=device)
-
-            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=0, t_text=t)

            if guidance_scale <= 1.0:
                return text_out
@@ -742,27 +734,23 @@ class UniDiffuserPipeline(DiffusionPipeline):
            # Classifier-free guidance
            img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
            img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
-            t_img_uncond = torch.ones_like(t) * timesteps

            img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
-                img_vae_T, img_clip_T, latents, t_img=t_img_uncond, t_text=timesteps
+                img_vae_T, img_clip_T, latents, t_img=max_timestep, t_text=t
            )

            return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond
        elif mode == "text":
            # Unconditional ("marginal") text generation (no CFG)
-            t_img = torch.ones_like(t) * timesteps
-
-            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=max_timestep, t_text=t)

            return text_out
        elif mode == "img":
            # Unconditional ("marginal") image generation (no CFG)
            img_vae_latents, img_clip_latents = self._split(latents, height, width)
-            t_text = torch.ones_like(t) * timesteps

            img_vae_out, img_clip_out, text_out = self.unet(
-                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=max_timestep
            )

            img_out = self._combine(img_vae_out, img_clip_out)
@@ -980,7 +968,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
            assert image is not None
            # Encode image using VAE
            image_vae = preprocess(image)
-            height, width = image.shape[-2:]
+            height, width = image_vae.shape[-2:]
            image_vae_latents = self.encode_image_vae_latents(
                image_vae,
                batch_size,
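Note (illustrative, not part of the diff): reading height and width from image_vae rather than image matters because the raw input can be a PIL.Image (which has no .shape) and preprocessing may resize it, so the preprocessed tensor is the reliable source of spatial dimensions. A small illustration with a stand-in preprocess helper (the pipeline's real helper differs):

import numpy as np
import torch
from PIL import Image

def preprocess(image):
    # Stand-in helper: round sides down to a multiple of 8, return a (1, 3, H, W) tensor.
    w, h = image.size
    w, h = w - w % 8, h - h % 8
    image = image.resize((w, h))
    array = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0)

image = Image.new("RGB", (130, 258))  # PIL input: no .shape, odd dimensions
image_vae = preprocess(image)
height, width = image_vae.shape[-2:]
print(height, width)  # 256 128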
@@ -1001,6 +989,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
                device,
                generator,
            )
+            # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
+            image_clip_latents = image_clip_latents.unsqueeze(1)
        else:
            # 4.2. Prepare image latent variables, if input not available
            # Prepare image VAE latents
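Note (illustrative, not part of the diff): the new unsqueeze gives the CLIP image embedding a length-1 sequence dimension so it matches the (batch, sequence, hidden) layout of the other latents. Shape check with made-up sizes:

import torch

batch_size, clip_hidden_size = 2, 512  # illustrative sizes, not taken from this commit
image_clip_latents = torch.randn(batch_size, clip_hidden_size)
image_clip_latents = image_clip_latents.unsqueeze(1)
print(image_clip_latents.shape)  # torch.Size([2, 1, 512])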
@@ -1030,6 +1020,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        # 5. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
+        max_timestep = timesteps[0]
        # print(f"Timesteps: {timesteps}")
        # print(f"Timesteps shape: {timesteps.shape}")

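Note (illustrative, not part of the diff): max_timestep = timesteps[0] works because diffusers schedulers return their timesteps in descending order after set_timesteps, so the first entry is the largest timestep of the run. Quick check with DDPMScheduler as an assumed example scheduler:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=20)

timesteps = scheduler.timesteps
max_timestep = timesteps[0]

assert torch.all(timesteps[:-1] >= timesteps[1:])  # descending order
print(int(max_timestep))  # largest timestep used in this run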
@@ -1062,7 +1053,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
                prompt_embeds,
                image_vae_latents,
                image_clip_latents,
-                timesteps,
+                max_timestep,
                guidance_scale,
                generator,
                device,

@@ -2,7 +2,6 @@ import random
import unittest

import numpy as np
-import pytest
import torch
from PIL import Image
from transformers import (
@@ -158,7 +157,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        }
        return inputs

-    # @pytest.mark.xfail(reason="not finished debugging")
    def test_unidiffuser_default_joint(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
@@ -186,7 +184,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        # TODO: need to figure out correct text output
        print(text)

-    @pytest.mark.xfail(reason="haven't begun debugging")
    def test_unidiffuser_default_text2img(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
@@ -208,7 +205,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        expected_slice = np.array([0.3965, 0.4568, 0.4495, 0.4590, 0.4463, 0.4690, 0.5454, 0.5093, 0.4321])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    @pytest.mark.xfail(reason="haven't begun debugging")
    def test_unidiffuser_default_img2text(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
@@ -227,8 +223,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

        # TODO: need to figure out correct text output
        print(text)
+        assert 0 == 1

-    @pytest.mark.xfail(reason="haven't begun debugging")
    def test_unidiffuser_default_text(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
@@ -248,8 +244,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

        # TODO: need to figure out correct text output
        print(text)
+        assert 0 == 1

-    @pytest.mark.xfail(reason="haven't begun debugging")
    def test_unidiffuser_default_image(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
@@ -268,8 +264,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        image = unidiffuser_pipe(**inputs).images
        assert image.shape == (1, 32, 32, 3)

-        # TODO: get expected slice of image output
        image_slice = image[0, -3:, -3:, -1]
-        print(image_slice.flatten())
+        expected_slice = np.array([0.3967, 0.4568, 0.4495, 0.4590, 0.4463, 0.4690, 0.5454, 0.5093, 0.4321])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

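Note (illustrative, not part of the diff): the pattern in these tests is to seed a CPU torch.Generator so the pipeline output is reproducible, then compare a 3x3 corner slice of the image against hard-coded values. A self-contained sketch of that comparison, with a random tensor standing in for the pipeline output:

import numpy as np
import torch

generator = torch.Generator(device="cpu").manual_seed(0)
image = torch.rand(1, 32, 32, 3, generator=generator).numpy()  # stand-in for pipeline output

image_slice = image[0, -3:, -3:, -1]
expected_slice = image_slice.copy()  # a real test hard-codes these nine values
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
print(image_slice.flatten())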