mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Initial commit for an image-text UniDiffuser pipeline.

This commit is contained in:
Daniel Gu
2023-04-03 23:40:20 -07:00
parent 7139f0e874
commit afe5ba0f20
6 changed files with 1987 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
ImageTextPipelineOutput,
UniDiffuserPipeline,
)
else:
from .modeling_text_decoder import UniDiffuserTextDecoder
from .modeling_uvit import UniDiffuserModel, UTransformer2DModel
from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline
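
For illustration, a minimal sketch (not part of this diff) of what the guard above amounts to at import time: the real classes are exported only when both torch and a recent enough transformers are installed; otherwise the dummy placeholders are exported and raise an informative import error when used.

# Illustrative check mirroring the guard in this __init__.py; the helpers below exist in diffusers.utils.
from diffusers.utils import is_torch_available, is_transformers_available, is_transformers_version

if is_torch_available() and is_transformers_available() and is_transformers_version(">=", "4.25.0"):
    print("UniDiffuserPipeline and UniDiffuserModel resolve to the real implementations")
else:
    print("the dummy placeholder classes are exported and raise when instantiated")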

View File

@@ -0,0 +1,188 @@
from typing import Optional
import numpy as np
import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
tokenizer: GPT2Tokenizer,
text_decoder: GPT2LMHeadModel,
prefix_length: int,
hidden_dim: Optional[int] = None,
use_hidden_dim: bool = True,
):
"""
Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
generate text from the UniDiffuser image-text embedding.
Parameters:
tokenizer ([`GPT2Tokenizer`]):
Tokenizer of class
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for
the GPT-like text decoder model.
text_decoder ([`GPT2LMHeadModel`]):
Text decoder model of class
[GPT2LMHeadModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel)
used to generate text from the UniDiffuser text embedding.
prefix_length (`int`):
Length of the prefix embedding that is prepended to the embedded text tokens.
hidden_dim (`int`, *optional*):
Hidden dim of the MLP if we encode the prefix.
use_hidden_dim (`bool`, *optional*, defaults to `True`):
Whether or not to use an MLP to encode the prefix.
"""
super().__init__()
self.prefix_length = prefix_length
eos = "<|EOS|>"
special_tokens_dict = {"eos_token": eos}
self.tokenizer = tokenizer
self.tokenizer.add_special_tokens(special_tokens_dict)
self.transformer = text_decoder
# TODO: need to set the eos_token_id correctly
self.transformer.config.eos_token_id = self.tokenizer.eos_token_id
self.transformer.resize_token_embeddings(len(self.tokenizer))
self.use_hidden_dim = use_hidden_dim
self.hidden_dim = hidden_dim if hidden_dim is not None else self.transformer.config.n_embd
self.encode_prefix = nn.Linear(768, self.hidden_dim) if use_hidden_dim else nn.Identity()
self.decode_prefix = nn.Linear(self.hidden_dim, 768) if use_hidden_dim else nn.Identity()
def forward(
self,
tokens: torch.Tensor,
prefix: torch.Tensor,
mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
):
"""
Args:
tokens (`torch.Tensor` of shape `(N, max_seq_len)`):
Text tokens to use for inference.
prefix (`torch.Tensor` of shape `(N, prefix_length, 768)`):
Prefix embedding to prepend to the embedded tokens.
mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len)`, *optional*):
Attention mask for the prefix embedding.
labels (`torch.Tensor`, *optional*):
Labels for language modeling. If supplied, they are rebuilt internally by prepending zero-valued dummy tokens for the prefix positions to `tokens`.
"""
embedding_text = self.transformer.transformer.wte(tokens)
hidden = self.encode_prefix(prefix)
prefix = self.decode_prefix(hidden)
embedding_cat = torch.cat((prefix, embedding_text), dim=1)
if labels is not None:
dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
labels = torch.cat((dummy_token, tokens), dim=1)
out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
if self.use_hidden_dim:
return out, hidden
else:
return out
def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
@torch.no_grad()
def generate_captions(self, features, device):
"""
Generate captions given text embedding features. Returns a list of generated captions, one per batch element (the best beam for each).
Args:
features (`torch.Tensor` of shape `(B, L, D)`):
Text embedding features to generate captions from.
device:
Device to perform text generation on.
"""
features = torch.split(features, 1, dim=0)
generated_captions = []
for feature in features:
feature = self.decode_prefix(feature.to(device)) # back to the clip feature
# Only support beam search for now
generated_captions.append(self.generate_beam(embed=feature, device=device)[0])
return generated_captions
@torch.no_grad()
def generate_beam(
self,
prompt=None,
embed=None,
device=None,
beam_size: int = 5,
entry_length: int = 67,
temperature: float = 1.0,
stop_token: str = "<|EOS|>",
):
# Generates text until stop_token is reached using beam search with the desired beam size.
stop_token_index = self.tokenizer.encode(stop_token)[0]
tokens = None
scores = None
seq_lengths = torch.ones(beam_size, device=device)
is_stopped = torch.zeros(beam_size, device=device, dtype=bool)
if embed is not None:
generated = embed
else:
assert prompt is not None
tokens = torch.tensor(self.tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = self.transformer.transformer.wte(tokens)
for i in range(entry_length):
outputs = self.transformer(inputs_embeds=generated)
logits = outputs.logits
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
logits = logits.softmax(-1).log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
generated = generated.expand(beam_size, *generated.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if tokens is None:
tokens = next_tokens
else:
tokens = tokens.expand(beam_size, *tokens.shape[1:])
tokens = torch.cat((tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
tokens = tokens[next_tokens_source]
tokens = torch.cat((tokens, next_tokens), dim=1)
generated = generated[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
generated = torch.cat((generated, next_token_embed), dim=1)
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = tokens.cpu().numpy()
output_texts = [
self.tokenizer.decode(output[: int(length)], skip_special_tokens=True)
for output, length in zip(output_list, seq_lengths)
]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return output_texts
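
A minimal sketch of how generate_captions is meant to be driven, assuming a plain GPT-2 tokenizer and a small randomly initialized GPT2LMHeadModel (these specific choices are illustrative, not prescribed by this commit). The features are per-sample embeddings in the decoder's hidden space, i.e. the output side of encode_prefix; decode_prefix maps them back to the 768-dim CLIP text space before beam search.

import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

# Illustrative construction; n_embd stays at the default 768 to match the hard-coded
# encode_prefix/decode_prefix widths above.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
text_decoder = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=4))
decoder = UniDiffuserTextDecoder(tokenizer, text_decoder, prefix_length=77, hidden_dim=64)

device = "cuda" if torch.cuda.is_available() else "cpu"
decoder.to(device)

features = torch.randn(2, 77, 64, device=device)  # (B, L, hidden_dim) in the encoded prefix space
captions = decoder.generate_captions(features, device=device)
print(captions)  # list of 2 strings, the best beam for each sample (gibberish for an untrained model)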

View File

@@ -0,0 +1,635 @@
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.attention import BasicTransformerBlock
from ...models.embeddings import ImagePositionalEmbeddings, PatchEmbed, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModelOutput
from ...utils import deprecate
class SkipBlock(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.skip_linear = nn.Linear(2 * dim, dim)
# Use torch.nn.LayerNorm as in the original U-ViT code; AdaLayerNorm would additionally require a
# timestep argument in forward, so other norm variants can be supported later if needed.
self.norm = nn.LayerNorm(dim)
def forward(self, x, skip):
x = self.skip_linear(torch.cat([x, skip], dim=-1))
x = self.norm(x)
return x
# Modified from diffusers.models.transformer_2d.Transformer2DModel
# Modify the transformer block structure to be U-Net like following U-ViT
# https://github.com/baofff/U-ViT
class UTransformer2DModel(ModelMixin, ConfigMixin):
"""
Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared
to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion,
similar to a U-Net. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
transformer action. Finally, reshape to image.
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
classes of unnoised image.
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than `num_embeds_ada_norm` steps.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
" sure that either `num_vector_embeds` or `num_patches` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
elif self.is_input_patches:
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
)
# 3. Define transformers blocks
# Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block,
# and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in
# a "U"-shaped fashion (e.g. first in_block to last out_block, etc.).
self.transformer_in_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
)
for d in range(num_layers // 2)
]
)
self.transformer_mid_block = BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
)
# For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs
# before each transformer out_block.
self.transformer_out_blocks = nn.ModuleList(
[
nn.ModuleDict(
{
"skip": SkipBlock(
inner_dim,
num_embeds_ada_norm=num_embeds_ada_norm,
),
"block": BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
),
}
)
for d in range(num_layers // 2)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches:
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
class_labels=None,
cross_attention_kwargs=None,
return_dict: bool = True,
hidden_states_is_embedding: bool = False,
):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in `AdaLayerNormZero`. Used to indicate class labels
conditioning.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple.
hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
transformer blocks.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# 1. Input
if not hidden_states_is_embedding:
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
hidden_states = self.pos_embed(hidden_states)
# 2. Blocks
# In ("downsample") blocks
skips = []
for in_block in self.transformer_in_blocks:
hidden_states = in_block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
skips.append(hidden_states)
# Mid block
hidden_states = self.transformer_mid_block(hidden_states)
# Out ("upsample") blocks
for out_block in self.transformer_out_blocks:
hidden_states = out_block["skip"](hidden_states, skips.pop())
hidden_states = out_block["block"](
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
elif self.is_input_patches:
# TODO: cleanup!
conditioning = self.transformer_in_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
class UniDiffuserModel(ModelMixin, ConfigMixin):
"""
Transformer model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a
modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the
CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details).
Parameters:
text_dim (`int`): The hidden dimension of the CLIP text model used to embed prompts.
clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed images.
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than `num_embeds_ada_norm` steps.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
@register_to_config
def __init__(
self,
text_dim: int = 768,
clip_img_dim: int = 512,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
):
super().__init__()
# 0. Handle dimensions
self.inner_dim = num_attention_heads * attention_head_dim
assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size"
self.sample_size = sample_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.patch_size = patch_size
# Assume image is square...
self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size)
# 1. Define input layers
# For now, only support patch input for VAE latent image input
self.vae_img_in = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
)
self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim)
self.text_in = nn.Linear(text_dim, self.inner_dim)
# Timestep embeddings for t_img, t_text
self.t_img_proj = Timesteps(
self.inner_dim,
flip_sin_to_cos=True,
downscale_freq_shift=0,
)
self.t_img_embed = TimestepEmbedding(
self.inner_dim,
4 * self.inner_dim,
out_dim=self.inner_dim,
)
self.t_text_proj = Timesteps(
self.inner_dim,
flip_sin_to_cos=True,
downscale_freq_shift=0,
)
self.t_text_embed = TimestepEmbedding(
self.inner_dim,
4 * self.inner_dim,
out_dim=self.inner_dim,
)
# 2. Define transformer blocks
self.transformer = UTransformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
patch_size=patch_size,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
)
# 3. Define output layers
# Each VAE image token is decoded back to a flattened patch of shape (patch_size, patch_size, out_channels) for the unpatchify step in forward.
self.vae_img_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim)
self.text_out = nn.Linear(self.inner_dim, text_dim)
def forward(
self,
img_vae: torch.FloatTensor,
img_clip: torch.FloatTensor,
text: torch.FloatTensor,
t_img: Union[torch.Tensor, float, int],
t_text: Union[torch.Tensor, float, int],
encoder_hidden_states=None,
timestep=None,
class_labels=None,
cross_attention_kwargs=None,
return_dict: bool = True,
):
"""
Args:
img_vae (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`):
Latent image representation from the VAE encoder.
img_clip (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`):
CLIP-embedded image representation (unsqueezed in the first dimension).
text (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`):
CLIP-embedded text representation.
t_img (`torch.long` or `float` or `int`):
Current denoising step for the image.
t_text (`torch.long` or `float` or `int`):
Current denoising step for the text.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in `AdaLayerNormZero`. Used to indicate class labels
conditioning.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
batch_size = img_vae.shape[0]
# 1. Input
# 1.1. Map inputs to shape (B, N, inner_dim)
vae_hidden_states = self.vae_img_in(img_vae)
clip_hidden_states = self.clip_img_in(img_clip)
text_hidden_states = self.text_in(text)
num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1)
# 1.2. Encode image and text timesteps
# t_img
if not torch.is_tensor(t_img):
t_img = torch.tensor([t_img], dtype=torch.long, device=vae_hidden_states.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
t_img = t_img * torch.ones(batch_size, dtype=t_img.dtype, device=t_img.device)
t_img_token = self.t_img_proj(t_img)
# t_img_token does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
t_img_token = t_img_token.to(dtype=self.dtype)
t_img_token = self.t_img_embed(t_img_token)
t_img_token = t_img_token.unsqueeze(dim=1)
# t_text
if not torch.is_tensor(t_text):
t_text = torch.tensor([t_text], dtype=torch.long, device=vae_hidden_states.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
t_text = t_text * torch.ones(batch_size, dtype=t_text.dtype, device=t_text.device)
t_text_token = self.t_text_proj(t_text)
# t_text_token does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
t_text_token = t_text_token.to(dtype=self.dtype)
t_text_token = self.t_text_embed(t_text_token)
t_text_token = t_text_token.unsqueeze(dim=1)
# 1.3. Concatenate all of the embeddings together.
hidden_states = torch.cat(
[t_img_token, t_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], dim=1
)
# 2. Blocks
hidden_states = self.transformer(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
class_labels=class_labels,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=return_dict,
hidden_states_is_embedding=True,
)
# 3. Output
# Split out the predicted noise representation.
t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split(
(1, 1, num_text_tokens, 1, num_img_tokens), dim=1
)
img_vae_out = self.vae_img_out(img_vae_out)
# unpatchify
height = width = int(img_vae_out.shape[1] ** 0.5)
img_vae_out = img_vae_out.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out)
img_vae_out = img_vae_out.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
img_clip_out = self.clip_img_out(img_clip_out)
text_out = self.text_out(text_out)
return img_vae_out, img_clip_out, text_out
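
The joint sequence fed to the inner UTransformer2DModel concatenates, in order, the image timestep token, the text timestep token, the CLIP text embedding, the CLIP image embedding, and the VAE latent patches; the split at the end of forward undoes exactly this ordering. A small worked shape example, using the dummy configuration from the test file later in this commit (values are illustrative only):

# Shape bookkeeping for the concatenation in UniDiffuserModel.forward.
num_attention_heads, attention_head_dim = 2, 8
inner_dim = num_attention_heads * attention_head_dim     # 16
num_img_tokens = (16 // 4) ** 2                          # (sample_size // patch_size) ** 2 = 16 VAE latent patches
num_text_tokens = 77                                     # CLIP text sequence length
seq_len = 1 + 1 + num_text_tokens + 1 + num_img_tokens   # t_img + t_text + text + clip_img + vae patches
print(seq_len, inner_dim)  # 96 16 -> hidden_states entering the U-ViT has shape (batch, 96, inner_dim)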

File diff suppressed because it is too large

View File

View File

@@ -0,0 +1,119 @@
import unittest
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModel,
GPT2Config,
GPT2LMHeadModel,
GPT2Tokenizer,
)
from diffusers import (
AutoencoderKL,
DDIMScheduler,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from diffusers.utils import slow
from diffusers.utils.testing_utils import require_torch_gpu
from ...test_pipelines_common import PipelineTesterMixin
class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = UniDiffuserPipeline
params = None # TODO
def get_dummy_components(self):
torch.manual_seed(0)
unet = UniDiffuserModel(
sample_size=16,
num_layers=2,
patch_size=4,
attention_head_dim=8,
num_attention_heads=2,
in_channels=4,
out_channels=8,
attention_bias=True,
activation_fn="gelu-approximate",
num_embeds_ada_norm=1000,
norm_type="ada_norm_zero",
norm_elementwise_affine=False,
text_dim=32, # TODO: needs to line up with CLIPTextConfig
clip_img_dim=32, # TODO: needs to line up with CLIPVisionConfig
)
scheduler = DDIMScheduler()
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
# TODO: get appropriate testing version for these
text_decoder_tokenizer = GPT2Tokenizer()
text_decoder_model_config = GPT2Config()
text_decoder_model = GPT2LMHeadModel(text_decoder_model_config)
text_decoder = UniDiffuserTextDecoder(
text_decoder_tokenizer,
text_decoder_model,
prefix_length=77, # TODO: fix
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig()
image_encoder = CLIPVisionModel(image_encoder_config)
# TODO: does this actually work?
image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"vae": vae,
"text_encoder": text_encoder,
"text_decoder": text_decoder,
"image_encoder": image_encoder,
"tokenizer": tokenizer,
"image_processor": image_processor,
"unet": unet,
"scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
pass
def test_unidiffuser_default_case(self):
pass
@slow
@require_torch_gpu
class UniDiffuserPipelineSlowTests(unittest.TestCase):
pass
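
For reference, a hedged sketch of how the dummy components above are expected to be consumed. Since the diff of pipeline_unidiffuser.py is suppressed in this view, the constructor arguments of UniDiffuserPipeline are assumed to match the keys of the components dict, which is how PipelineTesterMixin-based tests in diffusers typically build the pipeline under test.

# Illustrative only; assumes UniDiffuserPipeline accepts the component names used in
# get_dummy_components() above.
def build_dummy_pipeline(components):
    pipe = UniDiffuserPipeline(**components)
    pipe.set_progress_bar_config(disable=None)
    return pipe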