diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 5d4c7429e4..11a32a92ae 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn
 
-from ..models.embeddings import ImageProjection, MLPProjection, Resampler
+from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -712,7 +712,7 @@ class UNet2DConditionLoadersMixin:
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]
 
-            image_projection = MLPProjection(
+            image_projection = IPAdapterFullImageProjection(
                 cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
             )
 
@@ -730,7 +730,7 @@ class UNet2DConditionLoadersMixin:
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
 
-            image_projection = Resampler(
+            image_projection = IPAdapterPlusImageProjection(
                 embed_dims=embed_dims,
                 output_dims=output_dims,
                 hidden_dims=hidden_dims,
@@ -780,7 +780,7 @@ class UNet2DConditionLoadersMixin:
             num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
 
         # Set encoder_hid_proj after loading ip_adapter weights,
-        # because `Resampler` also has `attn_processors`.
+        # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
 
         # set ip-adapter cross-attention processors & load state_dict
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 7e98f77baf..293b751cb6 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -462,7 +462,7 @@ class ImageProjection(nn.Module):
         return image_embeds
 
 
-class MLPProjection(nn.Module):
+class IPAdapterFullImageProjection(nn.Module):
     def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
         super().__init__()
         from .attention import FeedForward
@@ -621,29 +621,34 @@ class AttentionPooling(nn.Module):
         return a[:, 0, :]  # cls_token
 
 
-class FourierEmbedder(nn.Module):
-    def __init__(self, num_freqs=64, temperature=100):
-        super().__init__()
+def get_fourier_embeds_from_boundingbox(embed_dim, box):
+    """
+    Args:
+        embed_dim: int
+        box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
+    Returns:
+        [B x N x embed_dim] tensor of positional embeddings
+    """
 
-        self.num_freqs = num_freqs
-        self.temperature = temperature
+    batch_size, num_boxes = box.shape[:2]
 
-        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
-        freq_bands = freq_bands[None, None, None]
-        self.register_buffer("freq_bands", freq_bands, persistent=False)
+    emb = 100 ** (torch.arange(embed_dim) / embed_dim)
+    emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
+    emb = emb * box.unsqueeze(-1)
 
-    def __call__(self, x):
-        x = self.freq_bands * x.unsqueeze(-1)
-        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
+    emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
+    emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
+
+    return emb
 
 
-class PositionNet(nn.Module):
+class GLIGENTextBoundingboxProjection(nn.Module):
     def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
         self.out_dim = out_dim
 
-        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+        self.fourier_embedder_dim = fourier_freqs
         self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy
 
         if isinstance(out_dim, tuple):
@@ -692,7 +697,7 @@ class PositionNet(nn.Module):
         masks = masks.unsqueeze(-1)
 
         # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # B*N*4 -> B*N*C
 
         # learnable null embedding
         xyxy_null = self.null_position_feature.view(1, 1, -1)
@@ -787,7 +792,7 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states
 
 
-class Resampler(nn.Module):
+class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
     Args:
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index ddf533d3bd..623e4d88d5 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -32,10 +32,10 @@ from .attention_processor import (
 )
 from .embeddings import (
     GaussianFourierProjection,
+    GLIGENTextBoundingboxProjection,
     ImageHintTimeEmbedding,
     ImageProjection,
     ImageTimeEmbedding,
-    PositionNet,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -615,7 +615,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 positive_len = cross_attention_dim[0]
 
             feature_type = "text-only" if attention_type == "gated" else "text-image"
-            self.position_net = PositionNet(
+            self.position_net = GLIGENTextBoundingboxProjection(
                 positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
             )
 
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 112aa42323..7c9936a0bd 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -187,7 +187,7 @@ class FourierEmbedder(nn.Module):
         return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
 
 
-class PositionNet(nn.Module):
+class GLIGENTextBoundingboxProjection(nn.Module):
     def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
@@ -820,7 +820,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 positive_len = cross_attention_dim[0]
 
             feature_type = "text-only" if attention_type == "gated" else "text-image"
-            self.position_net = PositionNet(
+            self.position_net = GLIGENTextBoundingboxProjection(
                 positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
             )
 
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
index 91d7357fd3..632e696392 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
@@ -730,7 +730,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
             )
             gligen_phrases = gligen_phrases[:max_objs]
             gligen_boxes = gligen_boxes[:max_objs]
-        # prepare batched input to the PositionNet (boxes, phrases, mask)
+        # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
         # Get tokens for phrases from pre-trained CLIPTokenizer
         tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
         # For the token, we use the same pre-trained text encoder
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 35ea33328c..0e2a4765d6 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -26,7 +26,7 @@ from pytest import mark
 
 from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
-from diffusers.models.embeddings import ImageProjection, Resampler
+from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):
 
     # "image_proj" (ImageProjection layer weights)
     cross_attention_dim = model.config["cross_attention_dim"]
-    image_projection = Resampler(
+    image_projection = IPAdapterPlusImageProjection(
         embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
     )
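A minimal sketch (not part of the patch) exercising the new `get_fourier_embeds_from_boundingbox` helper that replaces the `FourierEmbedder` module above; it assumes the function is importable from `diffusers.models.embeddings` as defined in this diff, and checks the shape contract used by `GLIGENTextBoundingboxProjection` (`position_dim = fourier_freqs * 2 * 4`):

```python
# Sketch: sanity-check the stateless Fourier embedder on dummy GLIGEN boxes.
import torch

from diffusers.models.embeddings import get_fourier_embeds_from_boundingbox

batch_size, num_boxes, fourier_freqs = 2, 3, 8
boxes = torch.rand(batch_size, num_boxes, 4)  # xyxy coordinates, normalized to [0, 1]

emb = get_fourier_embeds_from_boundingbox(fourier_freqs, boxes)
# Each box coordinate (4: xyxy) gets sin and cos (2) at every frequency band.
assert emb.shape == (batch_size, num_boxes, fourier_freqs * 2 * 4)
```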