From abd6fca81ea7be9882b1fea7a52ba3890c8a1946 Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Wed, 10 May 2023 19:14:20 -0700
Subject: [PATCH] Add/edit docstrings for added classes and public pipeline
 methods. Also do some light refactoring.

---
 .../unidiffuser/modeling_text_decoder.py      | 100 ++++++++---
 .../pipelines/unidiffuser/modeling_uvit.py    | 158 ++++++++++++++----
 .../unidiffuser/pipeline_unidiffuser.py       |  91 ++++++----
 3 files changed, 254 insertions(+), 95 deletions(-)

diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
index ac65c46d6d..6da23abd54 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
@@ -12,6 +12,55 @@ from ...models import ModelMixin

 # Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
 class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
+    """
+    Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
+    generate text from the UniDiffuser image-text embedding.
+
+    Parameters:
+        prefix_length (`int`):
+            Max number of prefix tokens that will be supplied to the model.
+        prefix_inner_dim (`int`):
+            The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
+            CLIP text encoder.
+        prefix_hidden_dim (`int`, *optional*):
+            Hidden dim of the MLP if we encode the prefix.
+        vocab_size (`int`, *optional*, defaults to 50257):
+            Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
+        n_positions (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        n_embd (`int`, *optional*, defaults to 768):
+            Dimensionality of the embeddings and hidden states.
+        n_layer (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        n_head (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        n_inner (`int`, *optional*, defaults to `None`):
+            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+        activation_function (`str`, *optional*, defaults to `"gelu"`):
+            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
+        resid_pdrop (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        embd_pdrop (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the embeddings.
+        attn_pdrop (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+            The epsilon to use in the layer normalization layers.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        scale_attn_weights (`bool`, *optional*, defaults to `True`):
+            Scale attention weights by dividing by `sqrt(hidden_size)`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
+            Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
+        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
+            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
+            dot-product/softmax to float32 when training with mixed precision.
+    """
+
     @register_to_config
     def __init__(
         self,
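
> Annotation (not part of the patch): a minimal construction sketch for the class documented above. The argument values are illustrative assumptions, not the config of any released UniDiffuser checkpoint.

```python
# Minimal sketch: constructing the text decoder standalone (values assumed).
import torch
from diffusers.pipelines.unidiffuser.modeling_text_decoder import UniDiffuserTextDecoder

text_decoder = UniDiffuserTextDecoder(
    prefix_length=77,       # max number of prefix tokens supplied to the model
    prefix_inner_dim=768,   # hidden size of incoming embeddings (e.g. a CLIP text hidden dim)
    prefix_hidden_dim=512,  # inner dim of the MLP used to encode the prefix
)

# The prefix is an embedding of shape (N, prefix_length, prefix_inner_dim).
prefix = torch.randn(1, 77, 768)
hidden = text_decoder.encode_prefix(prefix)
print(hidden.shape)  # expected with this config: torch.Size([1, 77, 512])
```
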
@@ -35,17 +84,6 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
         scale_attn_by_inverse_layer_idx: bool = False,
         reorder_and_upcast_attn: bool = False,
     ):
-        """
-        Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
-        generate text from the UniDiffuser image-text embedding.
-
-        Parameters:
-            prefix_length (`int`):
-                Max number of prefix tokens that will be supplied to the model.
-            prefix_hidden_dim (`int`, *optional*):
-                Hidden dim of the MLP if we encode the prefix.
-            TODO: add GPT2 config args
-        """
         super().__init__()

         self.prefix_length = prefix_length
@@ -104,7 +142,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
             mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*):
                 Attention mask for the prefix embedding.
             labels (`torch.Tensor`, *optional*):
-                TODO
+                Labels to use for language modeling.
         """
         embedding_text = self.transformer.transformer.wte(tokens)
         hidden = self.encode_prefix(prefix)
@@ -131,12 +169,15 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
         Args:
             tokenizer (`GPT2Tokenizer`):
                 Tokenizer of class
-                [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for
-                tokenizing input to the text decoder model.
+                [`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
+                for tokenizing input to the text decoder model.
             features (`torch.Tensor` of shape `(B, L, D)`):
                 Text embedding features to generate captions from.
             device:
                 Device to perform text generation on.
+
+        Returns:
+            `List[str]`: A list of strings generated from the decoder model.
         """

         features = torch.split(features, 1, dim=0)
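
> Annotation (not part of the patch): a hedged usage sketch for `generate_captions`, continuing the construction sketch above. The call pattern follows the docstring; the feature tensor is a dummy.

```python
# Generating captions from embedded features; generate_captions runs beam
# search under the hood via generate_beam. Illustrative values only.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# One feature row per caption to generate: shape (B, L, D) per the docstring.
features = torch.randn(2, 77, text_decoder.config.prefix_inner_dim)

captions = text_decoder.generate_captions(tokenizer, features, device="cpu")
print(captions)  # a list of 2 generated strings (nonsense for untrained weights)
```
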
@@ -157,12 +198,33 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
         beam_size: int = 5,
         entry_length: int = 67,
         temperature: float = 1.0,
-        stop_token: str = "<|EOS|>",
     ):
         """
-        Generates text using the given tokenizer and text prompt or token embedding via beam search.
+        Generates text using the given tokenizer and text prompt or token embedding via beam search. This
+        implementation is based on the beam search implementation from the [original UniDiffuser
+        code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).

-        TODO: args
+        Args:
+            tokenizer (`GPT2Tokenizer`):
+                Tokenizer of class
+                [`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
+                for tokenizing input to the text decoder model.
+            prompt (`str`, *optional*):
+                A raw text prompt to use as the prefix for beam search. One of `prompt` and `embed` must be supplied.
+            embed (`torch.Tensor` of shape `(B, L, D)`, *optional*):
+                An embedded representation to directly pass to the transformer as a prefix for beam search. One of
+                `prompt` and `embed` must be supplied.
+            device:
+                The device to perform beam search on.
+            beam_size (`int`, *optional*, defaults to `5`):
+                The number of best states to store during beam search.
+            entry_length (`int`, *optional*, defaults to `67`):
+                The number of iterations to run beam search.
+            temperature (`float`, *optional*, defaults to `1.0`):
+                The temperature to use when performing the softmax over logits from the decoding model.
+
+        Returns:
+            `List[str]`: A list of strings generated from the decoder model via beam search.
         """
         # Generates text until stop_token is reached using beam search with the desired beam size.
         stop_token_index = tokenizer.eos_token_id
@@ -219,10 +281,6 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
         scores = scores / seq_lengths
         output_list = tokens.cpu().numpy()
-        # print(f"Output list: {output_list}")
-        # print(f"Output list length: {len(output_list)}")
-        # print(f"Seq lengths: {seq_lengths}")
-        # print(f"Seq lengths length: {len(seq_lengths)}")
         output_texts = [
             tokenizer.decode(output[: int(length)], skip_special_tokens=True)
             for output, length in zip(output_list, seq_lengths)
""" def __init__( @@ -323,7 +336,8 @@ class UTransformerBlock(nn.Module): class UniDiffuserBlock(nn.Module): r""" A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the - LayerNorms on the residual backbone of the block. + LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser + implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104). Parameters: dim (`int`): The number of channels in the input and output. @@ -331,15 +345,28 @@ class UniDiffuserBlock(nn.Module): attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used. double_self_attention (`bool`, *optional*): Whether to use two self-attention layers. In this case no cross attention layers are used. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. """ def __init__( @@ -510,39 +537,58 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): """ Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion, - similar to a U-Net. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs. - - When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard - transformer action. Finally, reshape to image. - - When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional - embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict - classes of unnoised image. 
@@ -510,39 +537,58 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
     """
     Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared
     to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion,
-    similar to a U-Net. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.
-
-    When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
-    transformer action. Finally, reshape to image.
-
-    When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
-    embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
-    classes of unnoised image.
-
-    Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
-    image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
+    similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`]
+    layer and then reshaped to (b, t, d).

     Parameters:
         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
         in_channels (`int`, *optional*):
-            Pass if the input is continuous. The number of channels in the input and output.
+            Pass if the input is continuous. The number of channels in the input.
+        out_channels (`int`, *optional*):
+            The number of output channels; if `None`, defaults to `in_channels`.
         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        norm_num_groups (`int`, *optional*, defaults to `32`):
+            The number of groups to use when performing Group Normalization.
         cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+        attention_bias (`bool`, *optional*):
+            Configure if the TransformerBlocks' attention should contain a bias parameter.
         sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
             Note that this is fixed at training time as it is used for learning a number of position embeddings. See
             `ImagePositionalEmbeddings`.
         num_vector_embeds (`int`, *optional*):
             Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
             Includes the class for the masked latent pixel.
+        patch_size (`int`, *optional*, defaults to 2):
+            The patch size to use in the patch embedding.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         num_embeds_ada_norm ( `int`, *optional*):
             Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during
             training. Note that this is fixed at training time as it is used to learn a number of embeddings that are
             added to the hidden states. During inference, you can denoise for up to but not more than steps than
             `num_embeds_ada_norm`.
-        attention_bias (`bool`, *optional*):
-            Configure if the TransformerBlocks' attention should contain a bias parameter.
+        use_linear_projection (`bool`, *optional*, defaults to `False`):
+            Currently unused in this model.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used in each
+            transformer block.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the query and key to float32 when performing the attention calculation.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
+        block_type (`str`, *optional*, defaults to `"unidiffuser"`):
+            The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
+            backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
+            behavior in `diffusers`).
+        pre_layer_norm (`bool`, *optional*):
+            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
+            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
+            (`pre_layer_norm = False`).
+        norm_elementwise_affine (`bool`, *optional*):
+            Whether to use learnable per-element affine parameters during layer normalization.
+        use_patch_pos_embed (`bool`, *optional*):
+            Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
+        final_dropout (`bool`, *optional*):
+            Whether to use a final Dropout layer after the feedforward network.
     """

     @register_to_config
     def __init__(
@@ -559,7 +605,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
         attention_bias: bool = False,
         sample_size: Optional[int] = None,
         num_vector_embeds: Optional[int] = None,
-        patch_size: Optional[int] = None,
+        patch_size: Optional[int] = 2,
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
         use_linear_projection: bool = False,
@@ -709,12 +755,16 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                 Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
                 conditioning.
+            cross_attention_kwargs (*optional*):
+                Keyword arguments to supply to the cross attention layers, if used.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
             hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
                 Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
                 ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
                 transformer blocks.
+            unpatchify (`bool`, *optional*, defaults to `True`):
+                Whether to unpatchify the transformer output.

         Returns:
             [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
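
> Annotation (not part of the patch): a hedged sketch of what "unpatchify" means here. This is an assumed, standard U-ViT/DiT-style implementation for intuition, not the code in this file: tokens of shape `(B, num_tokens, patch_size**2 * C)` are folded back into an image-like `(B, C, H, W)` tensor.

```python
# Minimal unpatchify sketch (assumed implementation, for intuition only).
import torch

def unpatchify_sketch(x: torch.Tensor, patch_size: int, out_channels: int) -> torch.Tensor:
    B, num_tokens, _ = x.shape
    h = w = int(num_tokens ** 0.5)  # assumes a square token grid
    x = x.reshape(B, h, w, patch_size, patch_size, out_channels)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(B, out_channels, h * patch_size, w * patch_size)

tokens = torch.randn(2, 1024, 2 * 2 * 4)      # 32x32 token grid, patch 2, 4 channels
print(unpatchify_sketch(tokens, 2, 4).shape)  # torch.Size([2, 4, 64, 64])
```
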
@@ -799,23 +849,56 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
         in_channels (`int`, *optional*):
-            Pass if the input is continuous. The number of channels in the input and output.
+            Pass if the input is continuous. The number of channels in the input.
+        out_channels (`int`, *optional*):
+            The number of output channels; if `None`, defaults to `in_channels`.
         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        norm_num_groups (`int`, *optional*, defaults to `32`):
+            The number of groups to use when performing Group Normalization.
         cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+        attention_bias (`bool`, *optional*):
+            Configure if the TransformerBlocks' attention should contain a bias parameter.
         sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
             Note that this is fixed at training time as it is used for learning a number of position embeddings. See
             `ImagePositionalEmbeddings`.
         num_vector_embeds (`int`, *optional*):
             Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
             Includes the class for the masked latent pixel.
+        patch_size (`int`, *optional*, defaults to 2):
+            The patch size to use in the patch embedding.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         num_embeds_ada_norm ( `int`, *optional*):
             Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during
             training. Note that this is fixed at training time as it is used to learn a number of embeddings that are
             added to the hidden states. During inference, you can denoise for up to but not more than steps than
             `num_embeds_ada_norm`.
-        attention_bias (`bool`, *optional*):
-            Configure if the TransformerBlocks' attention should contain a bias parameter.
+        use_linear_projection (`bool`, *optional*, defaults to `False`):
+            Currently unused in this model.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used in each
+            transformer block.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the query and key to float32 when performing the attention calculation.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
+        block_type (`str`, *optional*, defaults to `"unidiffuser"`):
+            The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
+            backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
+            behavior in `diffusers`).
+        pre_layer_norm (`bool`, *optional*):
+            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
+            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
+            (`pre_layer_norm = False`).
+        norm_elementwise_affine (`bool`, *optional*):
+            Whether to use learnable per-element affine parameters during layer normalization.
+        use_patch_pos_embed (`bool`, *optional*):
+            Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
+        ff_final_dropout (`bool`, *optional*):
+            Whether to use a final Dropout layer after the feedforward network.
+        use_data_type_embedding (`bool`, *optional*):
+            Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1
+            is continue-trained from UniDiffuser-v0 on non-publicly-available data and accepts a `data_type` argument,
+            which can either be `1` to use the weights trained on non-publicly-available data or `0` otherwise. This
+            argument is subsequently embedded by the data type embedding, if used.
     """

     @register_to_config
@@ -984,9 +1067,9 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
                 Current denoising step for the image.
             t_text (`torch.long` or `float` or `int`):
                 Current denoising step for the text.
-            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                hidden_states
+            data_type (`torch.int` or `float` or `int`, *optional*, defaults to `1`):
+                Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data,
+                or `0` otherwise.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
@@ -995,13 +1078,14 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                 Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
                 conditioning.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+            cross_attention_kwargs (*optional*):
+                Keyword arguments to supply to the cross attention layers, if used.
+
         Returns:
-            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
-            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
-            returning a tuple, the first element is the sample tensor.
+            `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is the VAE
+            image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text
+            embedding.
         """
         batch_size = img_vae.shape[0]
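
> Annotation (not part of the patch): the new `Returns` contract, made concrete with a stand-in model. All dimensions below are assumptions for illustration, not checkpoint values.

```python
# Illustrative only: the tuple contract of UniDiffuserModel.forward.
import torch

B, C, H, W = 1, 4, 64, 64
clip_img_dim, text_len, text_dim = 512, 77, 64

img_vae = torch.randn(B, C, H, W)           # VAE image latents
img_clip = torch.randn(B, 1, clip_img_dim)  # CLIP image embedding
text = torch.randn(B, text_len, text_dim)   # CLIP text embedding

def fake_unet(img_vae, img_clip, text):
    # Stand-in for the forward pass: one noise prediction per modality,
    # each matching its input's shape.
    return torch.randn_like(img_vae), torch.randn_like(img_clip), torch.randn_like(text)

vae_noise_pred, clip_noise_pred, text_noise_pred = fake_unet(img_vae, img_clip, text)
print(vae_noise_pred.shape, clip_noise_pred.shape, text_noise_pred.shape)
```
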
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index a3206cd331..928eabb222 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -65,7 +65,7 @@ class ImageTextPipelineOutput(BaseOutput):
         images (`List[PIL.Image.Image]` or `np.ndarray`)
             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-        text (`str` or `List[str]`)
+        text (`List[str]` or `List[List[str]]`)
             List of generated text strings of length `batch_size` or a list of list of strings whose outer list has
             length `batch_size`. Text generated by the diffusion pipeline.
     """
@@ -89,29 +89,33 @@ class UniDiffuserPipeline(DiffusionPipeline):
             is part of the UniDiffuser image representation, along with the CLIP vision encoding.
         text_encoder ([`CLIPTextModel`]):
             Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of
-            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
-            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
-        text_decoder ([`UniDiffuserModel`]):
-            Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser
-            embedding.
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text
+            prompts.
         image_encoder ([`CLIPVisionModel`]):
             UniDiffuser uses the vision portion of
             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode
             images as part of its image representation, along with the VAE latent representation.
-        tokenizer ([`CLIPTokenizer`]):
-            Tokenizer of class
-            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         image_processor ([`CLIPImageProcessor`]):
-            CLIP image process of class
-            [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
-            used to preprocess the image before CLIP encoding.
+            CLIP image processor of class
+            [`CLIPImageProcessor`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
+            used to preprocess the image before CLIP encoding it with `image_encoder`.
+        clip_tokenizer ([`CLIPTokenizer`]):
+            Tokenizer of class
+            [`CLIPTokenizer`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which
+            is used to tokenize a prompt before encoding it with `text_encoder`.
+        text_decoder ([`UniDiffuserTextDecoder`]):
+            Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser
+            embedding.
+        text_tokenizer ([`GPT2Tokenizer`]):
+            Tokenizer of class
+            [`GPT2Tokenizer`](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which
+            is used along with the `text_decoder` to decode text for text generation.
         unet ([`UniDiffuserModel`]):
-            UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is like a
-            [`Transformer2DModel`] with U-Net-style skip connections. [`UniDiffuserModel`] additionally has input and
-            output heads on top of a base [`UTransformer2DModel`].
+            UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a
+            [`Transformer2DModel`] with U-Net-style skip connections between transformer layers.
         scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
-            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+            A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The
+            original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler.
     """

     def __init__(
@@ -258,7 +262,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
         return extra_step_kwargs

     def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents):
-        r"""Infer the mode from the inputs to `__call__`."""
+        r"""
+        Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set
+        mode will be used.
+        """
         prompt_available = (prompt is not None) or (prompt_embeds is not None)
         image_available = image is not None
         input_available = prompt_available or image_available
@@ -314,21 +321,27 @@ class UniDiffuserPipeline(DiffusionPipeline):

     # Functions to manually set the mode
     def set_text_mode(self):
+        r"""Manually set the generation mode to unconditional ("marginal") text generation."""
         self.mode = "text"

     def set_image_mode(self):
+        r"""Manually set the generation mode to unconditional ("marginal") image generation."""
         self.mode = "img"

     def set_text_to_image_mode(self):
+        r"""Manually set the generation mode to text-conditioned image generation."""
         self.mode = "text2img"

     def set_image_to_text_mode(self):
+        r"""Manually set the generation mode to image-conditioned text generation."""
         self.mode = "img2text"

     def set_joint_mode(self):
+        r"""Manually set the generation mode to unconditional joint image-text generation."""
        self.mode = "joint"

     def reset_mode(self):
+        r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs."""
         self.mode = None

     def _infer_batch_size(
         self,
         num_images_per_prompt,
         num_prompts_per_image,
         prompt,
         prompt_embeds,
         image,
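
> Annotation (not part of the patch): a quick hedged sketch of the mode API documented above. It assumes `pipe` is an already-loaded `UniDiffuserPipeline` and `init_image` a `PIL.Image` (see the loading sketch at the end of this patch).

```python
# Illustrative: forcing a task instead of letting the pipeline infer it.
pipe.set_image_to_text_mode()        # next call performs image -> text
result = pipe(image=init_image, num_inference_steps=20)
print(result.text)

pipe.reset_mode()                    # resume inferring the mode from the inputs
```
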
@@ -344,7 +357,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         vae_latents,
         clip_latents,
     ):
-        r"""Infers the batch size depending on mode."""
+        r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`."""
         if num_images_per_prompt is None:
             num_images_per_prompt = 1
         if num_prompts_per_image is None:
@@ -688,7 +701,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         return latents

     # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
-    # Rename: prepare_latents -> prepare_image_vae_latents
+    # Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument.
     def prepare_image_vae_latents(
         self,
         batch_size,
@@ -817,7 +830,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
         height,
         width,
     ):
-        # Predicts noise using the noise prediction model for the given mode.
+        r"""
+        Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary.
+        """
         if mode == "joint":
             # Joint text-image generation
             img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width)
@@ -1044,7 +1059,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         width: Optional[int] = None,
         data_type: Optional[int] = 1,
         num_inference_steps: int = 50,
-        guidance_scale: float = 7.5,
+        guidance_scale: float = 8.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         num_prompts_per_image: Optional[int] = 1,
@@ -1066,14 +1081,14 @@ class UniDiffuserPipeline(DiffusionPipeline):

         Args:
             prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead. Used in `text2img` mode.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, that will be used when performing image-conditioned
-                text generation (`img2text`).
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead. Required for text-conditioned image generation (`text2img`) mode.
+            image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+                `Image`, or tensor representing an image batch. Required for image-conditioned text generation
+                (`img2text`) mode.
+            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
             data_type (`int`, *optional*, defaults to 1):
                 The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type
@@ -1086,11 +1101,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
+                usually at the expense of lower image quality. Note that the original [UniDiffuser
+                paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of guidance scale `w'`, which
+                satisfies `w = w' + 1`.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
-                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`). Used in text-conditioned image generation (`text2img`) mode.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and
                 `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
@@ -1125,17 +1142,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 tensor will be generated by sampling using the supplied random `generator`.
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
+                provided, text embeddings will be generated from `prompt` input argument. Used in text-conditioned
+                image generation (`text2img`) mode.
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
+                argument. Used in text-conditioned image generation (`text2img`) mode.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
-                plain tuple.
+                Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
@@ -1144,7 +1161,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 called at every step.

         Examples:

         Returns:
             [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`:
-            [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
+            [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is a list with the generated images, and the second element is a list
             of generated texts.
         """
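
> Annotation (not part of the patch): a hedged end-to-end sketch of the documented `__call__` API. The hub id `thu-ml/unidiffuser-v1` and step counts are assumptions for illustration; the call pattern follows the docstrings in this patch.

```python
# End-to-end usage sketch (assumes a CUDA device and the "thu-ml/unidiffuser-v1"
# hub checkpoint; adjust for your environment).
import torch
from diffusers import UniDiffuserPipeline

pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Joint (unconditional) image-text generation: no prompt or image is supplied,
# so the pipeline infers "joint" mode.
sample = pipe(num_inference_steps=20, guidance_scale=8.0)
sample.images[0].save("joint_sample.png")
print(sample.text[0])

# Text-conditioned image generation ("text2img" mode, inferred from `prompt`).
image = pipe(prompt="an elephant under the sea", num_inference_steps=20).images[0]

# Image-conditioned text generation ("img2text" mode, inferred from `image`).
caption = pipe(image=image, num_inference_steps=20).text[0]
print(caption)
```
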