diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 2b57f21ae5..142727ef2d 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -242,7 +242,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
     def set_text_mode(self):
         self.mode = "text"
 
-    def set_img_mode(self):
+    def set_image_mode(self):
         self.mode = "img"
 
     def set_text_to_image_mode(self):
@@ -276,7 +276,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
             batch_size = num_samples
         return batch_size
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    # self.tokenizer => self.clip_tokenizer
     def _encode_prompt(
         self,
         prompt,
@@ -319,25 +320,25 @@ class UniDiffuserPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
-            text_inputs = self.tokenizer(
+            text_inputs = self.clip_tokenizer(
                 prompt,
                 padding="max_length",
-                max_length=self.tokenizer.model_max_length,
+                max_length=self.clip_tokenizer.model_max_length,
                 truncation=True,
                 return_tensors="pt",
             )
             text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                 text_input_ids, untruncated_ids
             ):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                removed_text = self.clip_tokenizer.batch_decode(
+                    untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
                 )
                 logger.warning(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                    f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
                 )
 
         if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
@@ -380,7 +381,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 uncond_tokens = negative_prompt
 
             max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
+            uncond_input = self.clip_tokenizer(
                 uncond_tokens,
                 padding="max_length",
                 max_length=max_length,
@@ -480,24 +481,21 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
-        image = image.to(device=device, dtype=dtype)
+        preprocessed_image = self.image_processor.preprocess(
+            image,
+            do_center_crop=True,
+            crop_size=resolution,
+            return_tensors="pt",
+        )
+        preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)
 
         if isinstance(generator, list):
             image_latents = [
-                self.image_encoder(
-                    **self.image_processor.preprocess(
-                        image[i : i + 1], do_center_crop=True, crop_size=resolution, return_tensors="pt"
-                    )
-                )
-                for i in range(batch_size)
+                self.image_encoder(**preprocessed_image[i : i + 1]).pooler_output for i in range(batch_size)
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            # TODO: figure out self.image_processor.preprocess kwargs
-            inputs = self.image_processor.preprocess(
-                image, do_center_crop=True, crop_size=resolution, return_tensors="pt"
-            )
-            image_latents = self.image_encoder(**inputs)
+            image_latents = self.image_encoder(**preprocessed_image).pooler_output
 
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand image_latents for batch_size
@@ -659,7 +657,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         prompt_embeds,
         img_vae,
         img_clip,
-        timesteps,
+        max_timestep,
         guidance_scale,
         generator,
         device,
@@ -689,17 +687,15 @@ class UniDiffuserPipeline(DiffusionPipeline):
             img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
             img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
             text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
-            t_img_uncond = torch.ones_like(t) * timesteps[0]
-            t_text_uncond = torch.ones_like(t) * timesteps[0]
             # print(f"t_img_uncond: {t_img_uncond}")
             # print(f"t_img_uncond shape: {t_img_uncond.shape}")
 
             # print("Running unconditional U-Net call 1 for CFG...")
-            _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=t_img_uncond, t_text=t)
+            _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t)
 
             # print("Running unconditional U-Net call 2 for CFG...")
             img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
-                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=t_text_uncond
+                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
             )
 
             x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond)
@@ -708,10 +704,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
         elif mode == "text2img":
             # Text-conditioned image generation
             img_vae_latents, img_clip_latents = self._split(latents, height, width)
-            t_text = torch.zeros(t.size(0), dtype=torch.int, device=device)
 
             img_vae_out, img_clip_out, text_out = self.unet(
-                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=0
            )
 
             img_out = self._combine(img_vae_out, img_clip_out)
@@ -721,10 +716,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
 
             # Classifier-free guidance
             text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
-            t_text_uncond = torch.ones_like(t) * timesteps
 
             img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
-                img_vae_latents, img_clip_latents, text_T, t_img=timesteps, t_text=t_text_uncond
+                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
             )
 
             img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond)
@@ -732,9 +726,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
             return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
         elif mode == "img2text":
             # Image-conditioned text generation
-            t_img = torch.zeros(t.size(0), dtype=torch.int, device=device)
-
-            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=0, t_text=t)
 
             if guidance_scale <= 1.0:
                 return text_out
@@ -742,27 +734,23 @@ class UniDiffuserPipeline(DiffusionPipeline):
             # Classifier-free guidance
             img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
             img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
-            t_img_uncond = torch.ones_like(t) * timesteps
 
             img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
-                img_vae_T, img_clip_T, latents, t_img=t_img_uncond, t_text=timesteps
+                img_vae_T, img_clip_T, latents, t_img=max_timestep, t_text=t
             )
 
             return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond
         elif mode == "text":
             # Unconditional ("marginal") text generation (no CFG)
-            t_img = torch.ones_like(t) * timesteps
-
-            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=max_timestep, t_text=t)
 
             return text_out
         elif mode == "img":
             # Unconditional ("marginal") image generation (no CFG)
             img_vae_latents, img_clip_latents = self._split(latents, height, width)
-            t_text = torch.ones_like(t) * timesteps
 
             img_vae_out, img_clip_out, text_out = self.unet(
-                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=max_timestep
             )
 
             img_out = self._combine(img_vae_out, img_clip_out)
@@ -980,7 +968,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
             assert image is not None
             # Encode image using VAE
             image_vae = preprocess(image)
-            height, width = image.shape[-2:]
+            height, width = image_vae.shape[-2:]
             image_vae_latents = self.encode_image_vae_latents(
                 image_vae,
                 batch_size,
@@ -1001,6 +989,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 device,
                 generator,
             )
+            # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
+            image_clip_latents = image_clip_latents.unsqueeze(1)
         else:
             # 4.2. Prepare image latent variables, if input not available
             # Prepare image VAE latents
@@ -1030,6 +1020,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         # 5. Set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+        max_timestep = timesteps[0]
 
         # print(f"Timesteps: {timesteps}")
         # print(f"Timesteps shape: {timesteps.shape}")
@@ -1062,7 +1053,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 prompt_embeds,
                 image_vae_latents,
                 image_clip_latents,
-                timesteps,
+                max_timestep,
                 guidance_scale,
                 generator,
                 device,
diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
index 3629bb42ac..ef7a1556c9 100644
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ b/tests/pipelines/unidiffuser/test_unidiffuser.py
@@ -2,7 +2,6 @@ import random
 import unittest
 
 import numpy as np
-import pytest
 import torch
 from PIL import Image
 from transformers import (
@@ -158,7 +157,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         }
         return inputs
 
-    # @pytest.mark.xfail(reason="not finished debugging")
     def test_unidiffuser_default_joint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -186,7 +184,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # TODO: need to figure out correct text output
         print(text)
 
-    @pytest.mark.xfail(reason="haven't begun debugging")
     def test_unidiffuser_default_text2img(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -208,7 +205,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array([0.3965, 0.4568, 0.4495, 0.4590, 0.4463, 0.4690, 0.5454, 0.5093, 0.4321])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
-    @pytest.mark.xfail(reason="haven't begun debugging")
     def test_unidiffuser_default_img2text(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -227,8 +223,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # TODO: need to figure out correct text output
         print(text)
+        assert 0 == 1
 
-    @pytest.mark.xfail(reason="haven't begun debugging")
     def test_unidiffuser_default_text(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -248,8 +244,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # TODO: need to figure out correct text output
         print(text)
+        assert 0 == 1
 
-    @pytest.mark.xfail(reason="haven't begun debugging")
     def test_unidiffuser_default_image(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -268,8 +264,8 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         image = unidiffuser_pipe(**inputs).images
         assert image.shape == (1, 32, 32, 3)
 
-        # TODO: get expected slice of image output
         image_slice = image[0, -3:, -3:, -1]
+        print(image_slice.flatten())
         expected_slice = np.array([0.3967, 0.4568, 0.4495, 0.4590, 0.4463, 0.4690, 0.5454, 0.5093, 0.4321])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
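
Not part of the diff: a minimal usage sketch of the renamed mode setters, assuming UniDiffuserPipeline is importable from diffusers and that a local checkpoint is available (the checkpoint path is a placeholder, and the call arguments mirror the fast tests above rather than a finalized API).

import torch
from diffusers import UniDiffuserPipeline

# Placeholder path; substitute a real UniDiffuser checkpoint directory (assumption).
pipe = UniDiffuserPipeline.from_pretrained("path/to/unidiffuser-checkpoint")
generator = torch.Generator(device="cpu").manual_seed(0)

# Unconditional image generation via the renamed set_image_mode() (was set_img_mode()).
pipe.set_image_mode()
image = pipe(num_inference_steps=20, generator=generator).images[0]

# Text-conditioned image generation; the prompt is encoded through self.clip_tokenizer per this diff.
pipe.set_text_to_image_mode()
image = pipe(prompt="a photo of a cat", num_inference_steps=20, generator=generator).images[0]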