mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix clip norm

Author: anton-l
Date: 2022-11-16 19:59:48 +01:00
Parent: 9a8114a8d6
Commit: b17475e6f0
2 changed files with 20 additions and 12 deletions


@@ -13,12 +13,13 @@
 # limitations under the License.
 import inspect
+import PIL
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.utils.checkpoint
-import PIL

 from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel

 from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
@@ -29,8 +30,8 @@ from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 class VersatileMixedModel:
     """
-    A context manager that swaps the transformer modules between the image and text unet during inference,
-    depending on the latent type and condition type.
+    A context manager that swaps the transformer modules between the image and text unet during inference, depending on
+    the latent type and condition type.
     """

     def __init__(self, image_unet, text_unet, latent_type, condition_type):
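For readers unfamiliar with the pattern the docstring describes, here is a minimal standalone sketch of swap-on-enter / restore-on-exit module swapping. The names (swapped_transformer, transformer_blocks) are hypothetical, chosen for illustration; the real class selects modules based on the latent and condition types.

from contextlib import contextmanager
from types import SimpleNamespace

@contextmanager
def swapped_transformer(image_unet, text_unet, use_text_modules):
    # Temporarily route inference through the donor unet's modules,
    # then restore the originals no matter how the block exits.
    donor = text_unet if use_text_modules else image_unet
    original = image_unet.transformer_blocks
    image_unet.transformer_blocks = donor.transformer_blocks
    try:
        yield image_unet
    finally:
        image_unet.transformer_blocks = original  # always restore on exit

# Toy usage with stand-in objects:
image_unet = SimpleNamespace(transformer_blocks="image-blocks")
text_unet = SimpleNamespace(transformer_blocks="text-blocks")
with swapped_transformer(image_unet, text_unet, use_text_modules=True) as unet:
    assert unet.transformer_blocks == "text-blocks"
assert image_unet.transformer_blocks == "image-blocks"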
@@ -126,6 +127,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
         """
+
         def normalize_embeddings(encoder_output):
             embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
             embeds_pooled = encoder_output.text_embeds
@@ -161,17 +163,20 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
         """
+
         def normalize_embeddings(encoder_output):
-            embeds = self.image_encoder.visual_projection(encoder_output.last_hidden_state)
-            embeds_pooled = encoder_output.image_embeds
-            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+            embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
+            embeds = self.image_encoder.visual_projection(embeds)
+            embeds_pooled = embeds[:, 0:1]
+            embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
             return embeds

         batch_size = len(prompt) if isinstance(prompt, list) else 1

         if do_classifier_free_guidance:
-            dummy_images = torch.zeros((batch_size, 3, 224, 224)).to(self.device)
-            uncond_embeddings = self.image_encoder(dummy_images)
+            dummy_images = [np.zeros((512, 512, 3))] * batch_size
+            dummy_images = self.image_processor(images=dummy_images, return_tensors="pt")
+            uncond_embeddings = self.image_encoder(dummy_images.pixel_values.to(self.device))
             uncond_embeddings = normalize_embeddings(uncond_embeddings)

         # get prompt text embeddings
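The normalization fix mirrors how CLIP computes its pooled image embedding: the vision tower's final LayerNorm is applied before the projection, and the whole token sequence is then scaled by the norm of the projected class token, rather than by the norm of a pooled embedding that went through a different path. A minimal standalone sketch of the new arithmetic; the shapes are arbitrary, and post_layernorm / visual_projection are stand-ins for the CLIP submodules:

import torch

batch, seq_len, hidden, proj_dim = 2, 257, 1024, 768
last_hidden_state = torch.randn(batch, seq_len, hidden)
post_layernorm = torch.nn.LayerNorm(hidden)
visual_projection = torch.nn.Linear(hidden, proj_dim, bias=False)

embeds = visual_projection(post_layernorm(last_hidden_state))  # project all tokens
embeds_pooled = embeds[:, 0:1]  # projected class token = CLIP's pooled embedding
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)  # scale by its norm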

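The second change routes the unconditional "blank" images through the pipeline's CLIP image processor instead of feeding raw zero tensors to the encoder. This matters because CLIP normalizes pixels with its dataset mean and std, so an all-zeros image does not encode to the tensor that torch.zeros produces. A sketch of the effect, using OpenAI CLIP's published normalization constants:

import numpy as np
import torch

# After resizing, the processor normalizes with CLIP's mean/std, so an
# all-zeros RGB image becomes a distinctly non-zero pixel tensor.
clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

zero_image = torch.from_numpy(np.zeros((3, 224, 224), dtype=np.float32))
pixel_values = (zero_image - clip_mean) / clip_std
print(pixel_values[:, 0, 0])  # ~[-1.79, -1.75, -1.48], not zeros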

@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from diffusers import VersatileDiffusionPipeline
-from diffusers.utils.testing_utils import require_torch, slow, torch_device, load_image
+from diffusers.utils.testing_utils import load_image, require_torch, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -57,12 +57,15 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)

         image_prompt = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
         )
         generator = torch.Generator(device=torch_device).manual_seed(0)
         image = pipe(
-            image_prompt=image_prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
+            image_prompt=image_prompt,
+            generator=generator,
+            guidance_scale=7.5,
+            num_inference_steps=50,
+            output_type="numpy",
         ).images

         image_slice = image[0, -3:, -3:, -1]
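The hunk ends at the slice extraction; such integration tests typically conclude by comparing the slice against a pinned reference. A sketch of the usual pattern, where the numbers below are placeholders rather than the real reference values:

# Placeholder values for illustration only; the actual test pins the slice
# produced by the fixed pipeline (image_slice has shape (3, 3) here).
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2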