Mirror of https://github.com/huggingface/diffusers.git
Add/edit docstrings for added classes and public pipeline methods. Also do some light refactoring.
@@ -12,6 +12,55 @@ from ...models import ModelMixin

# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """
    Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
    generate text from the UniDiffuser image-text embedding.

    Parameters:
        prefix_length (`int`):
            Max number of prefix tokens that will be supplied to the model.
        prefix_inner_dim (`int`):
            The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
            CLIP text encoder.
        prefix_hidden_dim (`int`, *optional*):
            Hidden dim of the MLP if we encode the prefix.
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*, defaults to None):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
        activation_function (`str`, *optional*, defaults to `"gelu"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / layer_idx + 1`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.
    """

    @register_to_config
    def __init__(
        self,
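To make the configuration surface above concrete, here is a hedged construction sketch; the argument names and defaults are taken from the docstring, while the import path and the specific values (e.g. `prefix_length=77`) are assumptions for illustration only.

    # Hedged sketch: builds the text decoder with the GPT-2-style config documented above.
    # The import path is assumed from the UniDiffuser pipeline layout; adjust as needed.
    from diffusers.pipelines.unidiffuser.modeling_text_decoder import UniDiffuserTextDecoder

    text_decoder = UniDiffuserTextDecoder(
        prefix_length=77,       # max number of prefix tokens supplied to the model
        prefix_inner_dim=768,   # hidden size of the incoming prefix (CLIP text encoder dim)
        prefix_hidden_dim=768,  # hidden dim of the MLP used to encode the prefix
        vocab_size=50257,       # GPT-2 vocabulary size
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
    )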
@@ -35,17 +84,6 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
    ):
        """
        Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
        generate text from the UniDiffuser image-text embedding.

        Parameters:
            prefix_length (`int`):
                Max number of prefix tokens that will be supplied to the model.
            prefix_hidden_dim (`int`, *optional*):
                Hidden dim of the MLP if we encode the prefix.
            TODO: add GPT2 config args
        """
        super().__init__()

        self.prefix_length = prefix_length
@@ -104,7 +142,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
            mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*):
                Attention mask for the prefix embedding.
            labels (`torch.Tensor`, *optional*):
                TODO
                Labels to use for language modeling.
        """
        embedding_text = self.transformer.transformer.wte(tokens)
        hidden = self.encode_prefix(prefix)
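The two lines of the method body visible here suggest the overall flow of `forward`: the prefix is encoded and concatenated with the GPT-2 token embeddings before the language model runs. A schematic sketch of that flow follows; `decode_prefix` and the exact GPT-2 call are assumptions, only `wte(tokens)` and `encode_prefix(prefix)` appear in this hunk.

    import torch

    def text_decoder_forward_sketch(decoder, tokens, prefix, mask=None, labels=None):
        # Schematic only, not the verbatim implementation.
        embedding_text = decoder.transformer.transformer.wte(tokens)       # (N, max_seq_len, n_embd)
        hidden = decoder.encode_prefix(prefix)                             # compress the prefix embedding
        prefix_embeds = decoder.decode_prefix(hidden)                      # assumed helper: back to (N, prefix_length, n_embd)
        embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)  # prepend the prefix to the token sequence
        return decoder.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)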
@@ -131,12 +169,15 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
        Args:
            tokenizer (`GPT2Tokenizer`):
                Tokenizer of class
                [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for
                tokenizing input to the text decoder model.
                [`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
                for tokenizing input to the text decoder model.
            features (`torch.Tensor` of shape `(B, L, D)`):
                Text embedding features to generate captions from.
            device:
                Device to perform text generation on.

        Returns:
            `List[str]`: A list of strings generated from the decoder model.
        """

        features = torch.split(features, 1, dim=0)
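The `torch.split` call shows that captions are generated one feature row at a time. A hedged usage sketch of `generate_captions` (argument order as documented in this hunk; `text_decoder` and `text_tokenizer` are placeholder objects):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    features = torch.randn(2, 77, 768, device=device)  # placeholder (B, L, D) text embedding features
    captions = text_decoder.generate_captions(text_tokenizer, features, device=device)
    print(captions)  # a list of strings, one caption per row of `features`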
@@ -157,12 +198,33 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
        beam_size: int = 5,
        entry_length: int = 67,
        temperature: float = 1.0,
        stop_token: str = "<|EOS|>",
    ):
        """
        Generates text using the given tokenizer and text prompt or token embedding via beam search.
        Generates text using the given tokenizer and text prompt or token embedding via beam search. This
        implementation is based on the beam search implementation from the [original UniDiffuser
        code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).

        TODO: args
        Args:
            tokenizer (`GPT2Tokenizer`):
                Tokenizer of class
                [`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
                for tokenizing input to the text decoder model.
            prompt (`str`, *optional*):
                A raw text prompt to use as the prefix for beam search. One of `prompt` and `embed` must be supplied.
            embed (`torch.Tensor` of shape `(B, L, D)`, *optional*):
                An embedded representation to directly pass to the transformer as a prefix for beam search. One of
                `prompt` and `embed` must be supplied.
            device:
                The device to perform beam search on.
            beam_size (`int`, *optional*, defaults to `5`):
                The number of best states to store during beam search.
            entry_length (`int`, *optional*, defaults to `67`):
                The number of iterations to run beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The temperature to use when performing the softmax over logits from the decoding model.

        Returns:
            `List[str]`: A list of strings generated from the decoder model via beam search.
        """
        # Generates text until stop_token is reached using beam search with the desired beam size.
        stop_token_index = tokenizer.eos_token_id
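A hedged usage sketch of `generate_beam` itself; `text_decoder`, `text_tokenizer`, and `prefix_embed` are placeholder objects, and whether the tokenizer is passed positionally and how the returned candidates are ordered are assumptions, while the remaining argument names come from the Args list above.

    generated = text_decoder.generate_beam(
        text_tokenizer,
        embed=prefix_embed,          # (1, prefix_length, n_embd) prefix embedding; alternatively prompt="A photo of"
        device=prefix_embed.device,
        beam_size=5,                 # number of beam states kept per step
        entry_length=67,             # maximum number of decoding iterations
        temperature=1.0,
    )
    print(generated[0])              # assumed to be the highest-scoring candidate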
@@ -219,10 +281,6 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):

        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        # print(f"Output list: {output_list}")
        # print(f"Output list length: {len(output_list)}")
        # print(f"Seq lengths: {seq_lengths}")
        # print(f"Seq lengths length: {len(seq_lengths)}")
        output_texts = [
            tokenizer.decode(output[: int(length)], skip_special_tokens=True)
            for output, length in zip(output_list, seq_lengths)
@@ -147,15 +147,28 @@ class UTransformerBlock(nn.Module):
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (:obj: `int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:obj: `bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
        obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
        obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float32 when performing the attention calculation.
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The layer norm implementation to use.
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g.
            `pre_layer_norm = True`.
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    """

    def __init__(
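To make the `pre_layer_norm` flag concrete, here is a schematic sketch of the two orderings for a single attention branch (illustrative only; the real block also has cross-attention and feed-forward branches, and the exact placement should be checked against the implementation):

    import torch
    import torch.nn as nn

    class BranchNormSketch(nn.Module):
        # Schematic: pre_layer_norm=True normalizes the branch input (BasicTransformerBlock style),
        # pre_layer_norm=False normalizes the branch output before the residual addition.
        def __init__(self, dim: int, num_heads: int = 8, pre_layer_norm: bool = True):
            super().__init__()
            self.pre_layer_norm = pre_layer_norm
            self.norm = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.pre_layer_norm:
                h = self.norm(x)
                return x + self.attn(h, h, h)[0]
            h = self.attn(x, x, x)[0]
            return x + self.norm(h)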
@@ -323,7 +336,8 @@ class UTransformerBlock(nn.Module):
class UniDiffuserBlock(nn.Module):
    r"""
    A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the
    LayerNorms on the residual backbone of the block.
    LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser
    implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104).

    Parameters:
        dim (`int`): The number of channels in the input and output.
@@ -331,15 +345,28 @@ class UniDiffuserBlock(nn.Module):
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (:obj: `int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:obj: `bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
        obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
        obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float() when performing the attention calculation.
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The layer norm implementation to use.
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
            (`pre_layer_norm = False`).
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    """

    def __init__(
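By contrast, "LayerNorms on the residual backbone" means the norm acts on the summed residual stream itself rather than inside the branch. A minimal sketch of the post-LayerNorm case described above (illustrative, not the `UniDiffuserBlock` code):

    import torch
    import torch.nn as nn

    class ResidualBackboneNormSketch(nn.Module):
        # Schematic: post-LayerNorm on the residual backbone, i.e. x = norm(x + branch(x)).
        def __init__(self, dim: int, num_heads: int = 8):
            super().__init__()
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.norm1(x + self.attn(x, x, x)[0])  # norm applied after the residual addition
            x = self.norm2(x + self.ff(x))
            return x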
@@ -510,39 +537,58 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
    """
    Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared
    to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion,
    similar to a U-Net. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.

    When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
    transformer action. Finally, reshape to image.

    When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
    embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
    classes of unnoised image.

    Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
    image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
    similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`]
    layer and then reshaped to (b, t, d).

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
            Pass if the input is continuous. The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of output channels; if `None`, defaults to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups to use when performing Group Normalization.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*, defaults to 2):
            The patch size to use in the patch embedding.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        use_linear_projection (int, *optional*): TODO: Not used
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used in each
            transformer block.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float() when performing the attention calculation.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
        block_type (`str`, *optional*, defaults to `"unidiffuser"`):
            The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
            backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
            behavior in `diffusers`).
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
            (`pre_layer_norm = False`).
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        use_patch_pos_embed (`bool`, *optional*):
            Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    """

    @register_to_config
@@ -559,7 +605,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        patch_size: Optional[int] = 2,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
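The "U"-shaped skip connections mentioned in the class docstring pair the outputs of the first half of the transformer blocks with the inputs of the second half. An illustrative sketch of that wiring (not the actual `UTransformer2DModel` code; the concat-then-project skip handling and the use of a generic encoder layer are assumptions):

    import torch
    import torch.nn as nn

    class UShapedBlocksSketch(nn.Module):
        # Illustrative: "in" blocks store activations that are later merged back into the
        # matching "out" blocks via long skip connections, U-Net style but over token sequences.
        def __init__(self, dim: int = 512, depth: int = 6, num_heads: int = 8):
            super().__init__()
            half = depth // 2
            self.in_blocks = nn.ModuleList(
                nn.TransformerEncoderLayer(dim, num_heads, batch_first=True) for _ in range(half)
            )
            self.mid_block = nn.TransformerEncoderLayer(dim, num_heads, batch_first=True)
            self.out_blocks = nn.ModuleList(
                nn.TransformerEncoderLayer(dim, num_heads, batch_first=True) for _ in range(half)
            )
            self.skip_projs = nn.ModuleList(nn.Linear(2 * dim, dim) for _ in range(half))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            skips = []
            for blk in self.in_blocks:
                x = blk(x)
                skips.append(x)
            x = self.mid_block(x)
            for blk, proj in zip(self.out_blocks, self.skip_projs):
                x = proj(torch.cat([x, skips.pop()], dim=-1))  # long skip: concat, then project back to dim
                x = blk(x)
            return x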
@@ -709,12 +755,16 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
                conditioning.
            cross_attention_kwargs (*optional*):
                Keyword arguments to supply to the cross attention layers, if used.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
            hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
                Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
                ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
                transformer blocks.
            unpatchify (`bool`, *optional*, defaults to `True`):
                Whether to unpatchify the transformer output.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
@@ -799,23 +849,56 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
            Pass if the input is continuous. The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of output channels; if `None`, defaults to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups to use when performing Group Normalization.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*, defaults to 2):
            The patch size to use in the patch embedding.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        use_linear_projection (int, *optional*): TODO: Not used
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used in each
            transformer block.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float32 when performing the attention calculation.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
        block_type (`str`, *optional*, defaults to `"unidiffuser"`):
            The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
            backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
            behavior in `diffusers`).
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
            (`pre_layer_norm = False`).
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        use_patch_pos_embed (`bool`, *optional*):
            Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
        ff_final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
        use_data_type_embedding (`bool`, *optional*):
            Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1
            is continue-trained from UniDiffuser-v0 on non-publicly-available data and accepts a `data_type`
            argument, which can either be `1` to use the weights trained on non-publicly-available data or `0`
            otherwise. This argument is subsequently embedded by the data type embedding, if used.
    """

    @register_to_config
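A minimal sketch of what such a data type embedding can look like; how the real model names this layer and combines the embedding with the hidden states is not shown in this diff, so the prepend-as-token choice below is an assumption.

    import torch
    import torch.nn as nn

    class DataTypeEmbeddingSketch(nn.Module):
        # Illustrative: embeds the 0/1 `data_type` flag so UniDiffuser-v1 can distinguish
        # the two training data regimes.
        def __init__(self, inner_dim: int):
            super().__init__()
            self.data_type_token_embedding = nn.Embedding(2, inner_dim)

        def forward(self, hidden_states: torch.Tensor, data_type: torch.LongTensor) -> torch.Tensor:
            # hidden_states: (batch, seq_len, inner_dim); data_type: (batch,) with values in {0, 1}
            data_type_token = self.data_type_token_embedding(data_type).unsqueeze(1)
            return torch.cat([data_type_token, hidden_states], dim=1)  # prepended as an extra token (assumption)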
@@ -984,9 +1067,9 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
                Current denoising step for the image.
            t_text (`torch.long` or `float` or `int`):
                Current denoising step for the text.
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`):
                Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data,
                or `0` otherwise.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
@@ -995,13 +1078,14 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
                conditioning.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
            cross_attention_kwargs (*optional*):
                Keyword arguments to supply to the cross attention layers, if used.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
            `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is the VAE
            image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text
            embedding.
        """
        batch_size = img_vae.shape[0]
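For orientation, a hedged sketch of driving this joint forward pass; only `img_vae`, `t_text`, and `data_type` appear verbatim in this diff, so `img_clip`, `text`, and `t_img` are assumed names for the remaining inputs, and all tensors here are placeholders.

    # Hedged sketch: one denoising step of the joint model returns three predictions,
    # one per modality stream, matching the tuple described in the Returns section above.
    vae_out, clip_out, text_out = unet(
        img_vae=img_vae_latents,   # VAE image latents
        img_clip=img_clip_embeds,  # CLIP image embedding (assumed name)
        text=text_latents,         # CLIP text latents (assumed name)
        t_img=t,                   # image timestep (assumed name)
        t_text=t,                  # text timestep
        data_type=1,               # UniDiffuser-v1 data type flag
    )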
@@ -65,7 +65,7 @@ class ImageTextPipelineOutput(BaseOutput):
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
        text (`str` or `List[str]`)
        text (`List[str]` or `List[List[str]]`)
            List of generated text strings of length `batch_size` or a list of lists of strings whose outer list has
            length `batch_size`. Text generated by the diffusion pipeline.
    """
@@ -89,29 +89,33 @@ class UniDiffuserPipeline(DiffusionPipeline):
            is part of the UniDiffuser image representation, along with the CLIP vision encoding.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_decoder ([`UniDiffuserModel`]):
            Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser
            embedding.
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text
            prompts.
        image_encoder ([`CLIPVisionModel`]):
            UniDiffuser uses the vision portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode
            images as part of its image representation, along with the VAE latent representation.
        tokenizer ([`CLIPTokenizer`]):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        image_processor ([`CLIPImageProcessor`]):
            CLIP image process of class
            [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
            used to preprocess the image before CLIP encoding.
            CLIP image processor of class
            [`CLIPImageProcessor`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
            used to preprocess the image before CLIP encoding it with `image_encoder`.
        clip_tokenizer ([`CLIPTokenizer`]):
            Tokenizer of class
            [`CLIPTokenizer`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which
            is used to tokenize a prompt before encoding it with `text_encoder`.
        text_decoder ([`UniDiffuserTextDecoder`]):
            Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser
            embedding.
        text_tokenizer ([`GPT2Tokenizer`]):
            Tokenizer of class
            [`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which
            is used along with the `text_decoder` to decode text for text generation.
        unet ([`UniDiffuserModel`]):
            UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is like a
            [`Transformer2DModel`] with U-Net-style skip connections. [`UniDiffuserModel`] additionally has input and
            output heads on top of a base [`UTransformer2DModel`].
            UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a
            [`Transformer2DModel`] with U-Net-style skip connections between transformer layers.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
            A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The
            original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler.
    """

    def __init__(
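For context, a hedged end-to-end usage sketch of the pipeline; the `thu-ml/unidiffuser-v1` checkpoint name and the exact call pattern are assumptions based on the public UniDiffuser release rather than on this diff.

    import torch
    from diffusers import UniDiffuserPipeline

    pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    # Text-conditioned image generation (text2img); the mode is inferred from the inputs.
    sample = pipe(prompt="an elephant under the sea", num_inference_steps=20, guidance_scale=8.0)
    sample.images[0].save("unidiffuser_text2img.png")

    # Unconditional joint image-text generation.
    pipe.set_joint_mode()
    sample = pipe(num_inference_steps=20, guidance_scale=8.0)
    image, text = sample.images[0], sample.text[0]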
@@ -258,7 +262,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
        return extra_step_kwargs

    def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents):
        r"""Infer the mode from the inputs to `__call__`."""
        r"""
        Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set
        mode will be used.
        """
        prompt_available = (prompt is not None) or (prompt_embeds is not None)
        image_available = image is not None
        input_available = prompt_available or image_available
@@ -314,21 +321,27 @@ class UniDiffuserPipeline(DiffusionPipeline):

    # Functions to manually set the mode
    def set_text_mode(self):
        r"""Manually set the generation mode to unconditional ("marginal") text generation."""
        self.mode = "text"

    def set_image_mode(self):
        r"""Manually set the generation mode to unconditional ("marginal") image generation."""
        self.mode = "img"

    def set_text_to_image_mode(self):
        r"""Manually set the generation mode to text-conditioned image generation."""
        self.mode = "text2img"

    def set_image_to_text_mode(self):
        r"""Manually set the generation mode to image-conditioned text generation."""
        self.mode = "img2text"

    def set_joint_mode(self):
        r"""Manually set the generation mode to unconditional joint image-text generation."""
        self.mode = "joint"

    def reset_mode(self):
        r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs."""
        self.mode = None

    def _infer_batch_size(
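A short usage sketch of these setters (hedged; `pipe` and `init_image` are placeholders):

    # Force image-conditioned text generation even when other inputs are present,
    # then return to automatic mode inference.
    pipe.set_image_to_text_mode()
    caption = pipe(image=init_image, num_inference_steps=20).text[0]
    pipe.reset_mode()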
@@ -344,7 +357,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        vae_latents,
        clip_latents,
    ):
        r"""Infers the batch size depending on mode."""
        r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`."""
        if num_images_per_prompt is None:
            num_images_per_prompt = 1
        if num_prompts_per_image is None:
@@ -688,7 +701,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        return latents

    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    # Rename: prepare_latents -> prepare_image_vae_latents
    # Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument.
    def prepare_image_vae_latents(
        self,
        batch_size,
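The `prepare_latents` pattern referenced in the comment, adapted as a hedged sketch of what a `prepare_image_vae_latents`-style helper does; the shapes and the `num_prompts_per_image` multiplier are assumptions consistent with the surrounding docstrings.

    import torch

    def prepare_image_vae_latents_sketch(
        batch_size, num_prompts_per_image, num_channels_latents, height, width,
        vae_scale_factor=8, dtype=torch.float32, device="cpu", generator=None, latents=None,
    ):
        # Standard Stable Diffusion-style latent preparation, duplicated per prompt-per-image.
        shape = (
            batch_size * num_prompts_per_image,
            num_channels_latents,
            height // vae_scale_factor,
            width // vae_scale_factor,
        )
        if latents is None:
            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)
        # The pipeline would additionally scale by the scheduler's init_noise_sigma here.
        return latents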
@@ -817,7 +830,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
        height,
        width,
    ):
        # Predicts noise using the noise prediction model for the given mode.
        r"""
        Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary.
        """
        if mode == "joint":
            # Joint text-image generation
            img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width)
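The classifier-free guidance step referenced in the new docstring combines a conditional and an unconditional prediction; a minimal sketch of the standard combination (illustrative, not the pipeline's exact code):

    import torch

    def apply_classifier_free_guidance(noise_pred_cond: torch.Tensor,
                                       noise_pred_uncond: torch.Tensor,
                                       guidance_scale: float) -> torch.Tensor:
        # guidance_scale > 1 pushes the prediction toward the conditional branch.
        return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)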
@@ -1044,7 +1059,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        width: Optional[int] = None,
        data_type: Optional[int] = 1,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        guidance_scale: float = 8.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        num_prompts_per_image: Optional[int] = 1,
@@ -1066,14 +1081,14 @@ class UniDiffuserPipeline(DiffusionPipeline):

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead. Used in `text2img` mode.
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used when performing image-conditioned
                text generation (`img2text`).
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead. Required for text-conditioned image generation (`text2img`) mode.
            image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
                `Image`, or tensor representing an image batch. Required for image-conditioned text generation
                (`img2text`) mode.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            data_type (`int`, *optional*, defaults to 1):
                The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type
@@ -1086,11 +1101,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
                usually at the expense of lower image quality. Note that the original [UniDiffuser
                paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of guidance scale `w'`, which
                satisfies `w = w' + 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`). Used in text-conditioned image generation (`text2img` mode).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and
                `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
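The `w = w' + 1` relation means a guidance scale quoted from the UniDiffuser paper must be shifted by one before being passed to this pipeline; a tiny illustrative helper:

    def paper_to_pipeline_guidance_scale(w_prime: float) -> float:
        # UniDiffuser paper convention w' -> diffusers convention w, with w = w' + 1.
        return w_prime + 1.0

    guidance_scale = paper_to_pipeline_guidance_scale(7.0)  # 8.0, matching the new default above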
@@ -1125,17 +1142,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
                provided, text embeddings will be generated from `prompt` input argument. Used in text-conditioned
                image generation (`text2img`) mode.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
                argument. Used in text-conditioned image generation (`text2img`) mode.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
                Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
@@ -1144,7 +1161,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
                called at every step.

        Examples:

        Returns:
            [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`:
            [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
            [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is a list with the generated images, and the second element is a list
            of generated texts.
        """