mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[refactor embeddings] gligen + ip-adapter (#6244)
* refactor ip-adapter-imageproj, gligen --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
This commit is contained in:
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import ImageProjection, MLPProjection, Resampler
|
||||
from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -712,7 +712,7 @@ class UNet2DConditionLoadersMixin:
|
||||
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
|
||||
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
|
||||
|
||||
image_projection = MLPProjection(
|
||||
image_projection = IPAdapterFullImageProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
|
||||
)
|
||||
|
||||
@@ -730,7 +730,7 @@ class UNet2DConditionLoadersMixin:
|
||||
hidden_dims = state_dict["latents"].shape[2]
|
||||
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
|
||||
|
||||
image_projection = Resampler(
|
||||
image_projection = IPAdapterPlusImageProjection(
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims,
|
||||
hidden_dims=hidden_dims,
|
||||
@@ -780,7 +780,7 @@ class UNet2DConditionLoadersMixin:
|
||||
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
|
||||
|
||||
# Set encoder_hid_proj after loading ip_adapter weights,
|
||||
# because `Resampler` also has `attn_processors`.
|
||||
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
|
||||
@@ -462,7 +462,7 @@ class ImageProjection(nn.Module):
|
||||
return image_embeds
|
||||
|
||||
|
||||
class MLPProjection(nn.Module):
|
||||
class IPAdapterFullImageProjection(nn.Module):
|
||||
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
@@ -621,29 +621,34 @@ class AttentionPooling(nn.Module):
|
||||
return a[:, 0, :] # cls_token
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
def __init__(self, num_freqs=64, temperature=100):
|
||||
super().__init__()
|
||||
def get_fourier_embeds_from_boundingbox(embed_dim, box):
|
||||
"""
|
||||
Args:
|
||||
embed_dim: int
|
||||
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
|
||||
Returns:
|
||||
[B x N x embed_dim] tensor of positional embeddings
|
||||
"""
|
||||
|
||||
self.num_freqs = num_freqs
|
||||
self.temperature = temperature
|
||||
batch_size, num_boxes = box.shape[:2]
|
||||
|
||||
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
||||
freq_bands = freq_bands[None, None, None]
|
||||
self.register_buffer("freq_bands", freq_bands, persistent=False)
|
||||
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
|
||||
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
|
||||
emb = emb * box.unsqueeze(-1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.freq_bands * x.unsqueeze(-1)
|
||||
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
|
||||
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
|
||||
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
class PositionNet(nn.Module):
|
||||
class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
|
||||
super().__init__()
|
||||
self.positive_len = positive_len
|
||||
self.out_dim = out_dim
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
||||
self.fourier_embedder_dim = fourier_freqs
|
||||
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
|
||||
|
||||
if isinstance(out_dim, tuple):
|
||||
@@ -692,7 +697,7 @@ class PositionNet(nn.Module):
|
||||
masks = masks.unsqueeze(-1)
|
||||
|
||||
# embedding position (it may includes padding as placeholder)
|
||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
|
||||
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
|
||||
|
||||
# learnable null embedding
|
||||
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||
@@ -787,7 +792,7 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Resampler(nn.Module):
|
||||
class IPAdapterPlusImageProjection(nn.Module):
|
||||
"""Resampler of IP-Adapter Plus.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -32,10 +32,10 @@ from .attention_processor import (
|
||||
)
|
||||
from .embeddings import (
|
||||
GaussianFourierProjection,
|
||||
GLIGENTextBoundingboxProjection,
|
||||
ImageHintTimeEmbedding,
|
||||
ImageProjection,
|
||||
ImageTimeEmbedding,
|
||||
PositionNet,
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
TextTimeEmbedding,
|
||||
@@ -615,7 +615,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
positive_len = cross_attention_dim[0]
|
||||
|
||||
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
||||
self.position_net = PositionNet(
|
||||
self.position_net = GLIGENTextBoundingboxProjection(
|
||||
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
||||
)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class FourierEmbedder(nn.Module):
|
||||
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
|
||||
|
||||
|
||||
class PositionNet(nn.Module):
|
||||
class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
|
||||
super().__init__()
|
||||
self.positive_len = positive_len
|
||||
@@ -820,7 +820,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
positive_len = cross_attention_dim[0]
|
||||
|
||||
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
||||
self.position_net = PositionNet(
|
||||
self.position_net = GLIGENTextBoundingboxProjection(
|
||||
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
||||
)
|
||||
|
||||
|
||||
@@ -730,7 +730,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
)
|
||||
gligen_phrases = gligen_phrases[:max_objs]
|
||||
gligen_boxes = gligen_boxes[:max_objs]
|
||||
# prepare batched input to the PositionNet (boxes, phrases, mask)
|
||||
# prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
|
||||
# Get tokens for phrases from pre-trained CLIPTokenizer
|
||||
tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
|
||||
# For the token, we use the same pre-trained text encoder
|
||||
|
||||
@@ -26,7 +26,7 @@ from pytest import mark
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
|
||||
from diffusers.models.embeddings import ImageProjection, Resampler
|
||||
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
cross_attention_dim = model.config["cross_attention_dim"]
|
||||
image_projection = Resampler(
|
||||
image_projection = IPAdapterPlusImageProjection(
|
||||
embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user