
Add/edit docstrings for added classes and public pipeline methods. Also do some light refactoring.

This commit is contained in:
Daniel Gu
2023-05-10 19:14:20 -07:00
parent 5728328545
commit abd6fca81e
3 changed files with 254 additions and 95 deletions


@@ -12,6 +12,55 @@ from ...models import ModelMixin
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
"""
Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
generate text from the UniDiffuser image-text embedding.
Parameters:
prefix_length (`int`):
Max number of prefix tokens that will be supplied to the model.
prefix_inner_dim (`int`):
The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
CLIP text encoder.
prefix_hidden_dim (`int`, *optional*):
Hidden dim of the MLP if we encode the prefix.
vocab_size (`int`, *optional*, defaults to 50257):
Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
n_positions (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
n_embd (`int`, *optional*, defaults to 768):
Dimensionality of the embeddings and hidden states.
n_layer (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
n_head (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
n_inner (`int`, *optional*, defaults to None):
Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
activation_function (`str`, *optional*, defaults to `"gelu"`):
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
resid_pdrop (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (`float`, *optional*, defaults to 0.1):
The dropout ratio for the embeddings.
attn_pdrop (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon to use in the layer normalization layers.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_attn_weights (`bool`, *optional*, defaults to `True`):
Scale attention weights by dividing by sqrt(hidden_size).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
Whether to additionally scale attention weights by `1 / layer_idx + 1`.
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
dot-product/softmax to float() when training with mixed precision.
"""
@register_to_config
def __init__(
self,
@@ -35,17 +84,6 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
scale_attn_by_inverse_layer_idx: bool = False,
reorder_and_upcast_attn: bool = False,
):
"""
Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
generate text from the UniDiffuser image-text embedding.
Parameters:
prefix_length (`int`):
Max number of prefix tokens that will be supplied to the model.
prefix_hidden_dim (`int`, *optional*):
Hidden dim of the MLP if we encode the prefix.
TODO: add GPT2 config args
"""
super().__init__()
self.prefix_length = prefix_length
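For orientation, a minimal instantiation sketch of the decoder documented above. The module path and the concrete sizes (77-token prefix, 768-dim CLIP text hidden size, 64-dim prefix MLP) are illustrative assumptions, not values taken from this commit:

```python
# Hypothetical sketch -- import path and sizes are assumptions for illustration only.
from diffusers.pipelines.unidiffuser.modeling_text_decoder import UniDiffuserTextDecoder

text_decoder = UniDiffuserTextDecoder(
    prefix_length=77,       # max number of prefix tokens supplied to the decoder
    prefix_inner_dim=768,   # hidden size of the incoming CLIP text embeddings
    prefix_hidden_dim=64,   # MLP bottleneck used to encode the prefix
    n_embd=768,             # GPT-2 hidden size (default noted in the docstring above)
    n_layer=12,
    n_head=12,
)
```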
@@ -104,7 +142,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*):
Attention mask for the prefix embedding.
labels (`torch.Tensor`, *optional*):
Labels to use for language modeling.
"""
embedding_text = self.transformer.transformer.wte(tokens)
hidden = self.encode_prefix(prefix)
@@ -131,12 +169,15 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
Args:
tokenizer (`GPT2Tokenizer`):
Tokenizer of class
[`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
for tokenizing input to the text decoder model.
features (`torch.Tensor` of shape `(B, L, D)`):
Text embedding features to generate captions from.
device:
Device to perform text generation on.
Returns:
`List[str]`: A list of strings generated from the decoder model.
"""
features = torch.split(features, 1, dim=0)
@@ -157,12 +198,33 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
beam_size: int = 5,
entry_length: int = 67,
temperature: float = 1.0,
stop_token: str = "<|EOS|>",
):
"""
Generates text using the given tokenizer and text prompt or token embedding via beam search. This
implementation is based on the beam search implementation from the [original UniDiffuser
code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).
Args:
tokenizer (`GPT2Tokenizer`):
Tokenizer of class
[`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
for tokenizing input to the text decoder model.
prompt (`str`, *optional*):
A raw text prompt to use as the prefix for beam search. One of `prompt` and `embed` must be supplied.
embed (`torch.Tensor` of shape `(B, L, D)`, *optional*):
An embedded representation to directly pass to the transformer as a prefix for beam search. One of
`prompt` and `embed` must be supplied.
device:
The device to perform beam search on.
beam_size (`int`, *optional*, defaults to `5`):
The number of best states to store during beam search.
entry_length (`int`, *optional*, defaults to `67`):
The number of iterations to run beam search.
temperature (`float`, *optional*, defaults to 1.0):
The temperature to use when performing the softmax over logits from the decoding model.
Returns:
`List[str]`: A list of strings generated from the decoder model via beam search.
"""
# Generates text until stop_token is reached using beam search with the desired beam size.
stop_token_index = tokenizer.eos_token_id
@@ -219,10 +281,6 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
scores = scores / seq_lengths
output_list = tokens.cpu().numpy()
# print(f"Output list: {output_list}")
# print(f"Output list length: {len(output_list)}")
# print(f"Seq lengths: {seq_lengths}")
# print(f"Seq lengths length: {len(seq_lengths)}")
output_texts = [
tokenizer.decode(output[: int(length)], skip_special_tokens=True)
for output, length in zip(output_list, seq_lengths)
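The length normalization visible above (`scores = scores / seq_lengths`) can be illustrated with a tiny standalone sketch using made-up numbers:

```python
import torch

# Toy illustration (made-up numbers) of the length normalization used above:
# finished beams are ranked by cumulative log-probability divided by the number
# of generated tokens, so longer candidates are not unfairly penalized.
scores = torch.tensor([-4.2, -3.9, -6.1])     # summed log-probs of three finished beams
seq_lengths = torch.tensor([8.0, 6.0, 12.0])  # tokens generated per beam
normalized = scores / seq_lengths
order = normalized.argsort(descending=True)   # best beam first
print(order)                                  # tensor([2, 0, 1])
```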


@@ -147,15 +147,28 @@ class UTransformerBlock(nn.Module):
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward.
num_embeds_ada_norm (:obj: `int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:obj: `bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the query and key to float32 when performing the attention calculation.
norm_elementwise_affine (`bool`, *optional*):
Whether to use learnable per-element affine parameters during layer normalization.
norm_type (`str`, defaults to `"layer_norm"`):
The layer norm implementation to use.
pre_layer_norm (`bool`, *optional*):
Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g.
`pre_layer_norm = True`.
final_dropout (`bool`, *optional*):
Whether to use a final Dropout layer after the feedforward network.
"""
def __init__(
@@ -323,7 +336,8 @@ class UTransformerBlock(nn.Module):
class UniDiffuserBlock(nn.Module):
r"""
A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the
LayerNorms on the residual backbone of the block.
LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser
implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104).
Parameters:
dim (`int`): The number of channels in the input and output.
@@ -331,15 +345,28 @@ class UniDiffuserBlock(nn.Module):
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward.
num_embeds_ada_norm (:obj: `int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:obj: `bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the query and key to float32 when performing the attention calculation.
norm_elementwise_affine (`bool`, *optional*):
Whether to use learnable per-element affine parameters during layer normalization.
norm_type (`str`, defaults to `"layer_norm"`):
The layer norm implementation to use.
pre_layer_norm (`bool`, *optional*):
Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
(`pre_layer_norm = False`).
final_dropout (`bool`, *optional*):
Whether to use a final Dropout layer after the feedforward network.
"""
def __init__(
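The `pre_layer_norm` flag described in the two docstrings above distinguishes where the LayerNorm sits relative to the residual connection. A schematic toy sublayer (not the diffusers implementation) makes the difference concrete:

```python
import torch
from torch import nn

# Schematic toy sublayer showing where the LayerNorm sits for the two
# configurations described above; this is not the diffusers implementation.
class ToySublayer(nn.Module):
    def __init__(self, dim: int, pre_layer_norm: bool):
        super().__init__()
        self.pre_layer_norm = pre_layer_norm
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pre_layer_norm:
            # pre-LayerNorm: normalize the sublayer input; the residual stream
            # itself is never normalized (BasicTransformerBlock-style).
            h = self.norm(x)
            return x + self.attn(h, h, h)[0]
        # post-LayerNorm on the residual backbone: normalize after the residual
        # add, the configuration the original UniDiffuser uses (pre_layer_norm=False).
        return self.norm(x + self.attn(x, x, x)[0])


x = torch.randn(2, 16, 64)
print(ToySublayer(64, pre_layer_norm=False)(x).shape)  # torch.Size([2, 16, 64])
```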
@@ -510,39 +537,58 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
"""
Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared
to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion,
similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`]
layer and then reshaped to (b, t, d).
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input.
out_channels (`int`, *optional*):
The number of output channels; if `None`, defaults to `in_channels`.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
norm_num_groups (`int`, *optional*, defaults to `32`):
The number of groups to use when performing Group Normalization.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
patch_size (`int`, *optional*, defaults to 2):
The patch size to use in the patch embedding.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more steps than `num_embeds_ada_norm`.
use_linear_projection (int, *optional*): TODO: Not used
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used in each
transformer block.
upcast_attention (`bool`, *optional*):
Whether to upcast the query and key to float32 when performing the attention calculation.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
block_type (`str`, *optional*, defaults to `"unidiffuser"`):
The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
behavior in `diffusers`).
pre_layer_norm (`bool`, *optional*):
Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
(`pre_layer_norm = False`).
norm_elementwise_affine (`bool`, *optional*):
Whether to use learnable per-element affine parameters during layer normalization.
use_patch_pos_embed (`bool`, *optional*):
Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
final_dropout (`bool`, *optional*):
Whether to use a final Dropout layer after the feedforward network.
"""
@register_to_config
@@ -559,7 +605,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
patch_size: Optional[int] = 2,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
@@ -709,12 +755,16 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
conditioning.
cross_attention_kwargs (*optional*):
Keyword arguments to supply to the cross attention layers, if used.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
transformer blocks.
unpatchify (`bool`, *optional*, defaults to `True`):
Whether to unpatchify the transformer output.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
@@ -799,23 +849,56 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input.
out_channels (`int`, *optional*):
The number of output channels; if `None`, defaults to `in_channels`.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
norm_num_groups (`int`, *optional*, defaults to `32`):
The number of groups to use when performing Group Normalization.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
patch_size (`int`, *optional*, defaults to 2):
The patch size to use in the patch embedding.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more steps than `num_embeds_ada_norm`.
use_linear_projection (int, *optional*): TODO: Not used
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used in each
transformer block.
upcast_attention (`bool`, *optional*):
Whether to upcast the query and key to float32 when performing the attention calculation.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
block_type (`str`, *optional*, defaults to `"unidiffuser"`):
The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
behavior in `diffusers`).
pre_layer_norm (`bool`, *optional*):
Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
(`pre_layer_norm = False`).
norm_elementwise_affine (`bool`, *optional*):
Whether to use learnable per-element affine parameters during layer normalization.
use_patch_pos_embed (`bool`, *optional*):
Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
ff_final_dropout (`bool`, *optional*):
Whether to use a final Dropout layer after the feedforward network.
use_data_type_embedding (`bool`, *optional*):
Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1
is continue-trained from UniDiffuser-v0 on non-publicly-available data and accepts a `data_type`
argument, which can either be `1` to use the weights trained on non-publicly-available data or `0`
otherwise. This argument is subsequently embedded by the data type embedding, if used.
"""
@register_to_config
@@ -984,9 +1067,9 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
Current denoising step for the image.
t_text (`torch.long` or `float` or `int`):
Current denoising step for the text.
data_type (`torch.int` or `float` or `int`, *optional*, defaults to `1`):
Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data,
or `0` otherwise.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
@@ -995,13 +1078,14 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
conditioning.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (*optional*):
Keyword arguments to supply to the cross attention layers, if used.
Returns:
`tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is the VAE
image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text
embedding.
"""
batch_size = img_vae.shape[0]


@@ -65,7 +65,7 @@ class ImageTextPipelineOutput(BaseOutput):
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array representing the denoised images of the diffusion pipeline.
text (`List[str]` or `List[List[str]]`)
List of generated text strings of length `batch_size` or a list of lists of strings whose outer list has
length `batch_size`. Text generated by the diffusion pipeline.
"""
@@ -89,29 +89,33 @@ class UniDiffuserPipeline(DiffusionPipeline):
is part of the UniDiffuser image representation, along with the CLIP vision encoding.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text
prompts.
image_encoder ([`CLIPVisionModel`]):
UniDiffuser uses the vision portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode
images as part of its image representation, along with the VAE latent representation.
image_processor ([`CLIPImageProcessor`]):
CLIP image processor of class
[`CLIPImageProcessor`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
used to preprocess the image before CLIP encoding it with `image_encoder`.
clip_tokenizer ([`CLIPTokenizer`]):
Tokenizer of class
[`CLIPTokenizer`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which
is used to tokenize a prompt before encoding it with `text_encoder`.
text_decoder ([`UniDiffuserTextDecoder`]):
Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser
embedding.
text_tokenizer ([`GPT2Tokenizer`]):
Tokenizer of class
[`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which
is used along with the `text_decoder` to decode text for text generation.
unet ([`UniDiffuserModel`]):
UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a
[`Transformer2DModel`] with U-Net-style skip connections between transformer layers.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The
original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler.
"""
def __init__(
@@ -258,7 +262,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
return extra_step_kwargs
def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents):
r"""Infer the mode from the inputs to `__call__`."""
r"""
Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set
mode will be used.
"""
prompt_available = (prompt is not None) or (prompt_embeds is not None)
image_available = image is not None
input_available = prompt_available or image_available
@@ -314,21 +321,27 @@ class UniDiffuserPipeline(DiffusionPipeline):
# Functions to manually set the mode
def set_text_mode(self):
r"""Manually set the generation mode to unconditional ("marginal") text generation."""
self.mode = "text"
def set_image_mode(self):
r"""Manually set the generation mode to unconditional ("marginal") image generation."""
self.mode = "img"
def set_text_to_image_mode(self):
r"""Manually set the generation mode to text-conditioned image generation."""
self.mode = "text2img"
def set_image_to_text_mode(self):
r"""Manually set the generation mode to image-conditioned text generation."""
self.mode = "img2text"
def set_joint_mode(self):
r"""Manually set the generation mode to unconditional joint image-text generation."""
self.mode = "joint"
def reset_mode(self):
r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs."""
self.mode = None
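A hedged usage sketch of the mode setters documented above; the checkpoint id and image URL below are placeholders, not part of this commit:

```python
import torch
from diffusers import UniDiffuserPipeline
from diffusers.utils import load_image

# Assumed checkpoint id and image URL -- swap in your own.
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16).to("cuda")
image = load_image("https://example.com/astronaut.png")  # any RGB image

pipe.set_image_to_text_mode()                 # force image-conditioned text generation (img2text)
caption = pipe(image=image, num_inference_steps=20).text[0]

pipe.reset_mode()                             # infer the task from the inputs again
result = pipe(prompt=caption)                 # now inferred as text2img
```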
def _infer_batch_size(
@@ -344,7 +357,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
vae_latents,
clip_latents,
):
r"""Infers the batch size depending on mode."""
r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`."""
if num_images_per_prompt is None:
num_images_per_prompt = 1
if num_prompts_per_image is None:
@@ -688,7 +701,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
return latents
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
# Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument.
def prepare_image_vae_latents(
self,
batch_size,
@@ -817,7 +830,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
height,
width,
):
r"""
Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary.
"""
if mode == "joint":
# Joint text-image generation
img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width)
@@ -1044,7 +1059,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
width: Optional[int] = None,
data_type: Optional[int] = 1,
num_inference_steps: int = 50,
guidance_scale: float = 8.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
num_prompts_per_image: Optional[int] = 1,
@@ -1066,14 +1081,14 @@ class UniDiffuserPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead. Required for text-conditioned image generation (`text2img`) mode.
image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch. Required for image-conditioned text generation
(`img2text`) mode.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
data_type (`int`, *optional*, defaults to 1):
The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type
@@ -1086,11 +1101,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. Note that the original [UniDiffuser
paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of guidance scale `w'`, which
satisfies `w = w' + 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). Used in text-conditioned image generation (`text2img` mode).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and
`img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
@@ -1125,17 +1142,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. Used in text-conditioned
image generation (`text2img`) mode.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument. Used in text-conditioned image generation (`text2img`) mode.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
@@ -1144,7 +1161,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
called at every step.
Examples:

Returns:
[`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`:
[`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images, and the second element is a list
of generated texts.
"""