From b17475e6f017af65890ea9cfcead83f5a0753f1a Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 16 Nov 2022 19:59:48 +0100
Subject: [PATCH] fix clip norm

---
 .../pipeline_versatile_diffusion.py | 21 +++++++++++++--------
 .../test_versatile_diffusion.py     | 11 +++++++----
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
index a0f40d0390..b5f3148a0b 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 import inspect
-import PIL
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.utils.checkpoint
 
+import PIL
 from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
 
 from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
@@ -29,8 +30,8 @@ from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 
 class VersatileMixedModel:
     """
-    A context managet that swaps the transformer modules between the image and text unet during inference,
-    depending on the latent type and condition type.
+    A context manager that swaps the transformer modules between the image and text unet during inference, depending on
+    the latent type and condition type.
     """
 
     def __init__(self, image_unet, text_unet, latent_type, condition_type):
@@ -126,6 +127,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
         """
+
         def normalize_embeddings(encoder_output):
             embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
             embeds_pooled = encoder_output.text_embeds
@@ -161,17 +163,20 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
         """
+
        def normalize_embeddings(encoder_output):
-            embeds = self.image_encoder.visual_projection(encoder_output.last_hidden_state)
-            embeds_pooled = encoder_output.image_embeds
-            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+            embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
+            embeds = self.image_encoder.visual_projection(embeds)
+            embeds_pooled = embeds[:, 0:1]
+            embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
             return embeds
 
         batch_size = len(prompt) if isinstance(prompt, list) else 1
 
         if do_classifier_free_guidance:
-            dummy_images = torch.zeros((batch_size, 3, 224, 224)).to(self.device)
-            uncond_embeddings = self.image_encoder(dummy_images)
+            dummy_images = [np.zeros((512, 512, 3))] * batch_size
+            dummy_images = self.image_processor(images=dummy_images, return_tensors="pt")
+            uncond_embeddings = self.image_encoder(dummy_images.pixel_values.to(self.device))
             uncond_embeddings = normalize_embeddings(uncond_embeddings)
 
         # get prompt text embeddings
diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py
index 5c37ebabee..4a34264952 100644
--- a/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py
+++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion.py
@@ -19,7 +19,7 @@
 import numpy as np
 import torch
 
 from diffusers import VersatileDiffusionPipeline
-from diffusers.utils.testing_utils import require_torch, slow, torch_device, load_image
+from diffusers.utils.testing_utils import load_image, require_torch, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -57,12 +57,15 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
 
         image_prompt = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
         )
         generator = torch.Generator(device=torch_device).manual_seed(0)
         image = pipe(
-            image_prompt=image_prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
+            image_prompt=image_prompt,
+            generator=generator,
+            guidance_scale=7.5,
+            num_inference_steps=50,
+            output_type="numpy",
         ).images
 
         image_slice = image[0, -3:, -3:, -1]
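--
A minimal standalone sketch of the new image-embedding normalization, for
anyone who wants to sanity-check it outside the pipeline. It assumes the
public "openai/clip-vit-large-patch14" checkpoint and uses transformers'
CLIPVisionModelWithProjection as a stand-in for the pipeline's image_encoder
(the patched code reaches into .vision_model.post_layernorm and
.visual_projection, which that class exposes); the checkpoint and class wired
into the released VersatileDiffusion weights may differ.

    import numpy as np
    import torch
    from transformers import CLIPProcessor, CLIPVisionModelWithProjection

    # Assumed checkpoint, for illustration only.
    model_id = "openai/clip-vit-large-patch14"
    processor = CLIPProcessor.from_pretrained(model_id)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_id)

    def normalize_embeddings(encoder_output):
        # Apply the final LayerNorm and the CLIP projection to all tokens,
        # then scale each sample by the norm of its projected CLS token,
        # mirroring the "+" lines in the normalize_embeddings hunk above.
        embeds = image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
        embeds = image_encoder.visual_projection(embeds)
        embeds_pooled = embeds[:, 0:1]
        return embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)

    # Black dummy images stand in for the unconditional branch, as in the patch.
    batch_size = 2
    dummy_images = [np.zeros((512, 512, 3), dtype=np.uint8)] * batch_size
    inputs = processor(images=dummy_images, return_tensors="pt")
    with torch.no_grad():
        uncond_embeddings = normalize_embeddings(image_encoder(inputs.pixel_values))
    print(uncond_embeddings.shape)  # e.g. torch.Size([2, 257, 768]) for ViT-L/14

The substance of the fix: post_layernorm is now applied to the full
hidden-state sequence before projection, and the scaling reference is the
projected CLS token, so every token in a sample is divided by the same
per-sample norm, consistent with how CLIP computes its pooled image embedding.
The old code projected the un-normalized hidden states and divided them by the
norm of encoder_output.image_embeds, mixing two inconsistent embedding spaces.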