fix clip norm
@@ -13,12 +13,13 @@
 # limitations under the License.

 import inspect
-import PIL
 from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.utils.checkpoint
+
+import PIL
 from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel

 from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
@@ -29,8 +30,8 @@ from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler

 class VersatileMixedModel:
     """
-    A context managet that swaps the transformer modules between the image and text unet during inference,
-    depending on the latent type and condition type.
+    A context manager that swaps the transformer modules between the image and text unets during inference, depending on
+    the latent type and condition type.
     """

     def __init__(self, image_unet, text_unet, latent_type, condition_type):
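For readers skimming the hunk: the docstring above describes a context manager that grafts one UNet's transformer blocks onto another for the duration of a call. A minimal sketch of that pattern, assuming two torch modules with matching child names (the name-matching rule and function name here are illustrative, not this file's actual implementation):

import contextlib

@contextlib.contextmanager
def swap_transformer_modules(image_unet, text_unet):
    # Hypothetical sketch: temporarily move the text unet's transformer
    # children onto the image unet, restoring the originals on exit.
    originals = []
    for name, module in text_unet.named_children():
        if "transformer" in name:
            originals.append((name, getattr(image_unet, name)))
            setattr(image_unet, name, module)
    try:
        yield image_unet
    finally:
        for name, module in originals:
            setattr(image_unet, name, module)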
@@ -126,6 +127,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
         """
+
        def normalize_embeddings(encoder_output):
            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
            embeds_pooled = encoder_output.text_embeds
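The text-side helper projects every token's hidden state into CLIP's joint embedding space, then scales by the norm of the pooled (EOS-token) embedding. A standalone sketch of the same steps, assuming a CLIPTextModelWithProjection checkpoint such as openai/clip-vit-large-patch14 (the pipeline's own text_encoder plays this role):

import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

output = text_encoder(**tokenizer(["a photo of a car"], return_tensors="pt"))

# Project all token states into the joint space, then scale by the norm of
# the pooled text embedding, mirroring normalize_embeddings above.
embeds = text_encoder.text_projection(output.last_hidden_state)
embeds = embeds / torch.norm(output.text_embeds.unsqueeze(1), dim=-1, keepdim=True)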
@@ -161,17 +163,20 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             do_classifier_free_guidance (`bool`):
                 whether to use classifier free guidance or not
         """
+
        def normalize_embeddings(encoder_output):
-            embeds = self.image_encoder.visual_projection(encoder_output.last_hidden_state)
-            embeds_pooled = encoder_output.image_embeds
-            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+            embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
+            embeds = self.image_encoder.visual_projection(embeds)
+            embeds_pooled = embeds[:, 0:1]
+            embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
            return embeds

        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if do_classifier_free_guidance:
-            dummy_images = torch.zeros((batch_size, 3, 224, 224)).to(self.device)
-            uncond_embeddings = self.image_encoder(dummy_images)
+            dummy_images = [np.zeros((512, 512, 3))] * batch_size
+            dummy_images = self.image_processor(images=dummy_images, return_tensors="pt")
+            uncond_embeddings = self.image_encoder(dummy_images.pixel_values.to(self.device))
            uncond_embeddings = normalize_embeddings(uncond_embeddings)

        # get prompt text embeddings
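This hunk is the substance of the commit. The old code projected the raw hidden states and divided by the norm of encoder_output.image_embeds; CLIP, however, applies post_layernorm before visual_projection when it builds image_embeds (and only to the CLS token), so patch tokens were being normalized inconsistently with the pooled path. The fix runs every token through post_layernorm and the projection, then scales by the norm of the projected CLS token. A standalone sketch of the corrected path, assuming CLIPVisionModelWithProjection and the openai/clip-vit-large-patch14 checkpoint:

import numpy as np
import torch
from transformers import CLIPProcessor, CLIPVisionModelWithProjection

processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# Stand-in for a real image, mirroring the dummy_images path above.
pixel_values = processor(
    images=[np.zeros((512, 512, 3), dtype=np.uint8)], return_tensors="pt"
).pixel_values
output = image_encoder(pixel_values)

# 1. LayerNorm *all* token states (CLIP itself applies post_layernorm only
#    to the CLS token when producing its pooled image_embeds).
embeds = image_encoder.vision_model.post_layernorm(output.last_hidden_state)
# 2. Project into the joint text-image embedding space.
embeds = image_encoder.visual_projection(embeds)
# 3. Scale every token by the norm of the projected CLS token.
embeds = embeds / torch.norm(embeds[:, 0:1], dim=-1, keepdim=True)

The second change in the hunk follows from the first: the unconditional branch now feeds black placeholder images through the CLIPProcessor rather than handing raw zero tensors to the encoder, so the unconditional embeddings see the same resize and normalization as real image prompts.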
@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from diffusers import VersatileDiffusionPipeline
-from diffusers.utils.testing_utils import require_torch, slow, torch_device, load_image
+from diffusers.utils.testing_utils import load_image, require_torch, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin

@@ -57,12 +57,15 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)

         image_prompt = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+            "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
         )
         generator = torch.Generator(device=torch_device).manual_seed(0)
         image = pipe(
-            image_prompt=image_prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
+            image_prompt=image_prompt,
+            generator=generator,
+            guidance_scale=7.5,
+            num_inference_steps=50,
+            output_type="numpy",
         ).images

         image_slice = image[0, -3:, -3:, -1]
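For reference, the updated test body corresponds to a usage pattern like the sketch below. The checkpoint id "shi-labs/versatile-diffusion" is an assumption (the test's from_pretrained call is not visible in this hunk), and the image_prompt keyword mirrors the test as written at this commit:

import torch
from diffusers import VersatileDiffusionPipeline
from diffusers.utils.testing_utils import load_image

# Hypothetical checkpoint id; substitute the one the test actually loads.
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion")
pipe = pipe.to("cuda")

image_prompt = load_image(
    "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
)
# Seeded generator for a reproducible sample, as in the integration test.
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    image_prompt=image_prompt,
    generator=generator,
    guidance_scale=7.5,
    num_inference_steps=50,
).images[0]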