diff --git a/src/diffusers/pipelines/unidiffuser/__init__.py b/src/diffusers/pipelines/unidiffuser/__init__.py new file mode 100644 index 0000000000..335f696ffc --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/__init__.py @@ -0,0 +1,20 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + ImageTextPipelineOutput, + UniDiffuserPipeline, + ) +else: + from .modeling_text_decoder import UniDiffuserTextDecoder + from .modeling_uvit import UniDiffuserModel, UTransformer2DModel + from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py new file mode 100644 index 0000000000..6a6eea44a2 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -0,0 +1,188 @@ +from typing import Optional + +import numpy as np +import torch +from torch import nn +from transformers import GPT2LMHeadModel, GPT2Tokenizer + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py +class UniDiffuserTextDecoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + tokenizer: GPT2Tokenizer, + text_decoder: GPT2LMHeadModel, + prefix_length: int, + hidden_dim: Optional[int] = None, + use_hidden_dim: bool = True, + ): + """ + Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to + generate text from the UniDiffuser image-text embedding. + + Parameters: + tokenizer ([`GPT2Tokenizer`]): + Tokenizer of class + [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for + the GPT-like text decoder model. + text_decoder ([`GPT2LMHeadModel`]): + Text decoder model of class + [GPT2LMHeadModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel) + used to generate text from the UniDiffuser text embedding. + prefix_length (`int`): + TODO + hidden_dim (`int`, *optional*): + Hidden dim of the MLP if we encode the prefix. + use_hidden_dim (`bool`, *optional*, defaults to `True`): + Whether or not to use a MLP to encode the prefix. + """ + super().__init__() + self.prefix_length = prefix_length + + eos = "<|EOS|>" + special_tokens_dict = {"eos_token": eos} + self.tokenizer = tokenizer + self.tokenizer.add_special_tokens(special_tokens_dict) + + self.transformer = text_decoder + # TODO: need to set the eos_token_id correctly + self.transformer.config.eos_token_id = self.tokenizer.eos_token_id + self.transformer.resize_token_embeddings(len(self.tokenizer)) + + self.use_hidden_dim = use_hidden_dim + self.hidden_dim = hidden_dim if hidden_dim is not None else self.transformer.config.n_embd + self.encode_prefix = nn.Linear(768, self.hidden_dim) if use_hidden_dim else nn.Identity() + self.decode_prefix = nn.Linear(self.hidden_dim, 768) if use_hidden_dim else nn.Identity() + + def forward( + self, + tokens: torch.Tensor, + prefix: torch.Tensor, + mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + """ + Args: + tokens (`torch.Tensor` of shape `(N, max_seq_len)`): + Text tokens to use for inference. + prefix (`torch.Tensor` of shape `(N, prefix_length, 768)`): + Prefix embedding to preprend to the embedded tokens. + mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*): + Attention mask for the prefix embedding. + labels (`torch.Tensor`, *optional*): + TODO + """ + embedding_text = self.transformer.transformer.wte(tokens) + hidden = self.encode_prefix(prefix) + prefix = self.decode_prefix(hidden) + embedding_cat = torch.cat((prefix, embedding_text), dim=1) + + if labels is not None: + dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device) + labels = torch.cat((dummy_token, tokens), dim=1) + out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask) + if self.use_hidden_dim: + return out, hidden + else: + return out + + def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: + return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) + + @torch.no_grad() + def generate_captions(self, features, device): + """ + Generate captions given text embedding features. Returns list[L]. + + Args: + features (`torch.Tensor` of shape `(B, L, D)`): + Text embedding features to generate captions from. + device: + Device to perform text generation on. + """ + + features = torch.split(features, 1, dim=0) + generated_captions = [] + for feature in features: + feature = self.decode_prefix(feature.to(device)) # back to the clip feature + # Only support beam search for now + generated_captions.append(self.generate_beam(embed=feature, device=device)[0]) + return generated_captions + + @torch.no_grad() + def generate_beam( + self, + prompt=None, + embed=None, + device=None, + beam_size: int = 5, + entry_length: int = 67, + temperature: float = 1.0, + stop_token: str = "<|EOS|>", + ): + # Generates text until stop_token is reached using beam search with the desired beam size. + stop_token_index = self.tokenizer.encode(stop_token)[0] + tokens = None + scores = None + seq_lengths = torch.ones(beam_size, device=device) + is_stopped = torch.zeros(beam_size, device=device, dtype=bool) + + if embed is not None: + generated = embed + else: + assert prompt is not None + tokens = torch.tensor(self.tokenizer.encode(prompt)) + tokens = tokens.unsqueeze(0).to(device) + generated = self.transformer.transformer.wte(tokens) + + for i in range(entry_length): + outputs = self.transformer(input_embeds=generated) + logits = outputs.logits + logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + logits = logits.softmax(-1).log() + + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + generated = generated.expand(beam_size, *generated.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + if tokens is None: + tokens = next_tokens + else: + tokens = tokens.expand(beam_size, *tokens.shape[1:]) + tokens = torch.cat((tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + tokens = tokens[next_tokens_source] + tokens = torch.cat((tokens, next_tokens), dim=1) + generated = generated[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + + next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1) + generated = torch.cat((generated, next_token_embed), dim=1) + is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() + if is_stopped.all(): + break + + scores = scores / seq_lengths + output_list = tokens.cpu().numpy() + output_texts = [ + self.tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + order = scores.argsort(descending=True) + output_texts = [output_texts[i] for i in order] + return output_texts diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py new file mode 100644 index 0000000000..dc53937b29 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -0,0 +1,635 @@ +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...models.attention import AdaLayerNorm, BasicTransformerBlock +from ...models.embeddings import ImagePositionalEmbeddings, PatchEmbed, TimestepEmbedding, Timesteps +from ...models.transformer_2d import Transformer2DModelOutput +from ...utils import deprecate + + +class SkipBlock(nn.Module): + def __init__(self, dim: int, num_embeds_ada_norm: Optional[int] = None): + super().__init__() + + self.skip_linear = nn.Linear(2 * dim, dim) + + # Use AdaLayerNorm for now, maybe support using other forms of LayerNorm? + # Original code uses torch.nn.LayerNorm + self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) + + def forward(self, x, skip): + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = self.norm(x) + + return x + + +# Modified from diffusers.models.transformer_2d.Transformer2DModel +# Modify the transformer block structure to be U-Net like following U-ViT +# https://github.com/baofff/U-ViT +class UTransformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared + to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion, + similar to a U-Net. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block, + # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in + # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). + self.transformer_in_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ) + for d in range(num_layers // 2) + ] + ) + + self.transformer_mid_block = BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ) + + # For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs + # before each transformer out_block. + self.transformer_out_blocks = nn.ModuleList( + [ + nn.ModuleDict( + { + "skip": SkipBlock( + inner_dim, + num_embeds_ada_norm=num_embeds_ada_norm, + ), + "block": BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ), + } + ) + for d in range(num_layers // 2) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + hidden_states_is_embedding: bool = False, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): + Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will + ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the + transformer blocks. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 1. Input + if not hidden_states_is_embedding: + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + + # In ("downsample") blocks + skips = [] + for in_block in self.transformer_in_blocks: + hidden_states = in_block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + skips.append(hidden_states) + + # Mid block + hidden_states = self.transformer_mid_block(hidden_states) + + # Out ("upsample") blocks + for out_block in self.transformer_in_blocks: + hidden_states = out_block["skip"](hidden_states, skips.pop()) + hidden_states = out_block["block"]( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class UniDiffuserModel(ModelMixin, ConfigMixin): + """ + Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a + modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the + CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details). + + Parameters: + text_dim (`int`): The hidden dimension of the CLIP text model used to embed images. + clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts. + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + text_dim: int = 768, + clip_img_dim: int = 512, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + ): + super().__init__() + + # 0. Handle dimensions + self.inner_dim = num_attention_heads * attention_head_dim + + assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size" + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.patch_size = patch_size + # Assume image is square... + self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size) + + # 1. Define input layers + # For now, only support patch input for VAE latent image input + self.vae_img_in = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + ) + self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim) + self.text_in = nn.Linear(text_dim, self.inner_dim) + + # Timestep embeddings for t_img, t_text + self.t_img_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.t_img_embed = TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + self.t_text_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.t_text_embed = TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + + # 2. Define transformer blocks + self.transformer = UTransformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=out_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + patch_size=patch_size, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + ) + + # 3. Define output layers + self.vae_img_out = nn.Linear(self.inner_dim, self.num_patches) + self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim) + self.text_out = nn.Linear(self.inner_dim, text_dim) + + def forward( + self, + img_vae: torch.FloatTensor, + img_clip: torch.FloatTensor, + text: torch.FloatTensor, + t_img: Union[torch.Tensor, float, int], + t_text: Union[torch.Tensor, float, int], + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + img_vae (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`): + Latent image representation from the VAE encoder. + img_clip (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`): + CLIP-embedded image representation (unsqueezed in the first dimension). + text (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`): + CLIP-embedded text representation. + t_img (`torch.long` or `float` or `int`): + Current denoising step for the image. + t_text (`torch.long` or `float` or `int`): + Current denoising step for the text. + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + batch_size = img_vae.shape[0] + + # 1. Input + # 1.1. Map inputs to shape (B, N, inner_dim) + vae_hidden_states = self.vae_img_in(img_vae) + clip_hidden_states = self.clip_img_in(img_clip) + text_hidden_states = self.text_in(text) + + num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1) + + # 1.2. Encode image and text timesteps + # t_img + if not torch.is_tensor(t_img): + t_img = torch.tensor([t_img], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + t_img = t_img * torch.ones(batch_size, dtype=t_img.dtype, device=t_img.device) + + t_img_token = self.t_img_proj(t_img) + # t_img_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + t_img_token = t_img_token.to(dtype=self.dtype) + t_img_token = self.t_img_embed(t_img_token) + t_img_token = t_img_token.unsqueeze(dim=1) + + # t_text + if not torch.is_tensor(t_text): + t_text = torch.tensor([t_text], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + t_text = t_text * torch.ones(batch_size, dtype=t_text.dtype, device=t_text.device) + + t_text_token = self.t_text_proj(t_text) + # t_text_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + t_text_token = t_text_token.to(dtype=self.dtype) + t_text_token = self.t_text_embed(t_text_token) + t_text_token = t_text_token.unsqueeze(dim=1) + + # 1.3. Concatenate all of the embeddings together. + hidden_states = torch.cat( + [t_img_token, t_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], dim=1 + ) + + # 2. Blocks + hidden_states = self.transformer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + class_labels=class_labels, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=return_dict, + hidden_states_is_embedding=True, + ) + + # 3. Output + # Split out the predicted noise representation. + t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split( + (1, 1, num_text_tokens, 1, num_img_tokens), dim=1 + ) + + img_vae_out = self.vae_img_out(img_vae_out) + # unpatchify + height = width = int(img_vae_out.shape[1] ** 0.5) + img_vae_out = img_vae_out.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out) + img_vae_out = img_vae_out.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + img_clip_out = self.clip_img_out(img_clip_out) + + text_out = self.text_out(text_out) + + return img_vae_out, img_clip_out, text_out diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py new file mode 100644 index 0000000000..c31a26131c --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -0,0 +1,1025 @@ +import inspect +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import einops +import numpy as np +import PIL +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModel, +) + +from ...models import AutoencoderKL +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + logging, + randn_tensor, +) +from ...utils.outputs import BaseOutput +from ..pipeline_utils import DiffusionPipeline +from .modeling_text_decoder import UniDiffuserTextDecoder +from .modeling_uvit import UniDiffuserModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# New BaseOutput child class for joint image-text output +@dataclass +class ImageTextPipelineOutput(BaseOutput): + """ + Output class for joint image-text pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + text (`str` or `List[str]`) + List of generated text strings of length `batch_size` or a list of list of strings whose outer list has + length `batch_size`. Text generated by the diffusion pipeline. + """ + + images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + text: Optional[Union[List[str], List[List[str]]]] + + +class UniDiffuserPipeline(DiffusionPipeline): + r""" + Pipeline for a bimodal image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model, which supports + unconditional text and image generation, text-conditioned image generation, image-conditioned text generation, and + joint image-text generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. This + is part of the UniDiffuser image representation, along with the CLIP vision encoding. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_decoder ([`UniDiffuserModel`]): + Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser + embedding. + image_encoder ([`CLIPVisionModel`]): + UniDiffuser uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode + images as part of its image representation, along with the VAE latent representation. + tokenizer ([`CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + image_processor ([`CLIPImageProcessor`]): + CLIP image process of class + [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor), + used to preprocess the image before CLIP encoding. + unet ([`UniDiffuserModel`]): + UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is like a + [`Transformer2DModel`] with U-Net-style skip connections. [`UniDiffuserModel`] additionally has input and + output heads on top of a base [`UTransformer2DModel`]. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_decoder: UniDiffuserTextDecoder, + image_encoder: CLIPVisionModel, + tokenizer: CLIPTokenizer, + image_processor: CLIPImageProcessor, + unet: UniDiffuserModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_decoder=text_decoder, + image_encoder=image_encoder, + tokenizer=tokenizer, + image_processor=image_processor, + unet=unet, + scheduler=scheduler, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.num_channels_latents = vae.latent_channels + self.text_encoder_seq_len = tokenizer.model_max_length + self.text_encoder_hidden_size = text_encoder.config.hidden_size + self.image_encoder_hidden_size = image_encoder.config.hidden_size + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def _infer_batch_size(self, mode, prompt, prompt_embeds, image, num_samples): + r"""Infers the batch size depending on mode.""" + if mode in ["text2img"]: + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + # Either prompt or prompt_embeds must be present for text2img. + batch_size = prompt_embeds.shape[0] + elif mode in ["img2text"]: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + else: + # Image must be available and type either PIL.Image.Image or torch.FloatTensor. + # Not currently supporting something like image_embeds. + batch_size = image.shape[0] + else: + # For unconditional (and marginal) generation, we use num_samples + batch_size = num_samples + return batch_size + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents + # Rename: prepare_image_latents() -> encode_image_vae_latents() + def encode_image_vae_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def encode_image_clip_latents( + self, + image, + resolution, + batch_size, + num_images_per_prompt, + dtype, + device, + generator=None, + ): + # Map image to CLIP embedding. + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if isinstance(generator, list): + image_latents = [ + self.image_encoder( + **self.image_processor.preprocess( + image[i : i + 1], do_center_crop=True, crop_size=resolution, return_tensors="pt" + ) + ) + for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + # TODO: figure out self.image_processor.preprocess kwargs + inputs = self.image_processor.preprocess( + image, do_center_crop=True, crop_size=resolution, return_tensors="pt" + ) + image_latents = self.image_encoder(**inputs) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + return image_latents + + # Note that the CLIP latents are not decoded for image generation. + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + # Rename: decode_latents() -> decode_image_latents() + def decode_image_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_text_latents(self, batch_size, seq_len, hidden_size, dtype, device, generator, latents=None): + # Prepare latents for the CLIP embedded prompt. + shape = (batch_size, seq_len, hidden_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + # Rename: prepare_latents() -> prepare_image_vae_latents + def prepare_image_vae_latents( + self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_clip_latents(self, batch_size, clip_img_dim, dtype, device, generator, latents=None): + # Prepare latents for the CLIP embedded image. + shape = (batch_size, 1, clip_img_dim) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _split(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W) + and (B, 1, clip_img_dim) + """ + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + + img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_hidden_size], dim=1) + + img_vae = einops.rearrange( + img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width + ) + img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size) + return img_vae, img_clip + + def _combine(self, img_vae, img_clip): + r""" + Combines a latent iamge img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1, + clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim). + """ + img_vae = einops.rearrange(img_vae, "B C H W -> B (C H W)") + img_clip = einops.rearrange(img_clip, "B L D -> B (L D)") + return torch.concat([img_vae, img_clip], dim=-1) + + def _split_joint(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim] into (img_vae, + img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is + of shape (B, text_seq_len, text_dim). + """ + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size + + img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_hidden_size, text_dim], dim=1) + + img_vae = einops.rearrange( + img_vae, "B (C H W) -> B C H W", C=self.image_encoder_hidden_size, H=latent_height, W=latent_width + ) + img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size) + text = einops.rearrange(text, "B (L D) -> B L D", L=self.text_encoder_seq_len, D=self.text_encoder_hidden_size) + return img_vae, img_clip, text + + def _combine_joint(self, img_vae, img_clip, text): + r""" + Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img, + clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B, + C * H * W + L_img * clip_img_dim + L_text * text_dim). + """ + img_vae = einops.rearrange(img_vae, "B C H W -> B (C H W)") + img_clip = einops.rearrange(img_clip, "B L D -> B (L D)") + text = einops.rearrange(text, "B L D -> B (L D)") + return torch.concat([img_vae, img_clip, text], dim=-1) + + def get_noise_pred( + self, + mode, + latents, + t, + prompt_embeds, + img_vae, + img_clip, + timesteps, + guidance_scale, + generator, + device, + height, + width, + ): + # Predicts noise using the noise prediction model for the given mode. + if mode == "joint": + # Joint text-image generation + img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, text_latents, t_img=t, t_text=t + ) + + x_out = self._combine_joint(img_vae_out, img_clip_out, text_out) + + if guidance_scale <= 1.0: + return x_out + + # Classifier-free guidance + # TODO: need to replace this with the appropriate generator logic and randn_tensor + img_vae_T = torch.randn_like(img_vae, device=device) + img_clip_T = torch.randn_like(img_clip, device=device) + text_T = torch.randn_like(prompt_embeds, device=device) + t_img_uncond = torch.ones_like(t) * timesteps + t_text_uncond = torch.ones_like(t) * timesteps + + _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=t_img_uncond, t_text=t) + img_vae_out_uncond, img_clip_out_uncond, _ = self.unet( + img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=t_text_uncond + ) + + x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond) + + return guidance_scale * x_out * (1.0 - guidance_scale) * x_out_uncond + elif mode == "text2img": + # Text-conditioned image generation + img_vae_latents, img_clip_latents = self._split(latents, height, width) + t_text = torch.zeros(t.size(0), dtype=torch.int, device=device) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text + ) + + img_out = self._combine(img_vae_out, img_clip_out) + + if guidance_scale <= 1.0: + return img_out + + # Classifier-free guidance + # TODO: need to replace this with the appropriate generator logic and randn_tensor + text_T = torch.randn_like(prompt_embeds) + t_text_uncond = torch.ones_like(t) * timesteps + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_latents, img_clip_latents, text_T, t_img=timesteps, t_text=t_text_uncond + ) + + img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond) + + return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond + elif mode == "img2text": + # Image-conditioned text generation + t_img = torch.zeros(t.size(0), dtype=torch.int, device=device) + + img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t) + + if guidance_scale <= 1.0: + return text_out + + # Classifier-free guidance + # TODO: need to replace this with the appropriate generator logic and randn_tensor + img_vae_T = torch.randn_like(img_vae) + img_clip_T = torch.randn_like(img_clip) + t_img_uncond = torch.ones_like(t) * timesteps + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_T, img_clip_T, latents, t_img=t_img_uncond, t_text=timesteps + ) + + return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond + elif mode == "text": + # Unconditional ("marginal") text generation (no CFG) + t_img = torch.ones_like(t) * timesteps + + img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t) + + return text_out + elif mode == "img": + # Unconditional ("marginal") image generation (no CFG) + img_vae_latents, img_clip_latents = self._split(latents, height, width) + t_text = torch.ones_like(t) * timesteps + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text + ) + + img_out = self._combine(img_vae_out, img_clip_out) + return img_out + + # Temporarily copied from StableDiffusionPipeline. + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + # Check inputs before running the generative process. + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + mode: str = "text2img", # text, img, text2img, img2text, joint + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + num_prompts_per_image: Optional[int] = 1, + num_samples: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_latents: Optional[torch.FloatTensor] = None, + vae_latents: Optional[torch.FloatTensor] = None, + clip_latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. Used in `text2img` mode. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used when performing image-conditioned + text generation (`img2text`). + mode (`str`): + The generation task to be performed; use `text` for unconditional ("marginal") text generation, `img` + for unconditional ("marginal") image generation, `text2img` for text-conditioned image generation, + `img2text` for image-conditioned text generation, and `joint` for joint image-text generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) + mode. + num_prompts_per_image (`int`, *optional*, defaults to 1): + The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) + mode. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + vae_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + clip_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Examples: Returns: + [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`: + [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of generated texts. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + # Copied from base StableDiffusionPipeline for now + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + + batch_size = self._infer_batch_size(mode, prompt, prompt_embeds, image, num_samples) + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + # Note that this diffusers from the formulation in the unidiffusers paper! + # check if scheduler is in sigmas space + hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt, if available; otherwise prepare text latents + + if mode in ["text2img"]: + # 3.1. Encode input prompt, if available + assert prompt is not None or prompt_embeds is not None + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + False, # don't support standard classifier-free guidance for now + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + else: + # 3.2. Prepare text image latent variables, if necessary + prompt_embeds = self.prepare_text_latents( + batch_size, + self.text_encoder_seq_len, + self.text_encoder_hidden_size, + torch.float32, # Placeholder, need to determine correct thing to do for dtype + device, + generator, + prompt_latents, + ) + + # 4. Encode image, if available; otherwise prepare image latents + if mode in ["img2text"]: + # 4.1. Encode images, if available + assert image is not None + # Encode image using VAE + image_vae = preprocess(image) + height, width = image.shape[-2:] + image_vae_latents = self.encode_image_vae_latents( + image_vae, + batch_size, + num_prompts_per_image, # not num_images_per_prompt + prompt_embeds.dtype, + device, + False, # Copied from InstructPix2Pix, don't use their version of CFG + generator, + ) + + # Encode image using CLIP + image_clip_latents = self.encode_image_clip_latents( + image, + height, # assume image is square for now... + batch_size, + num_prompts_per_image, # not num_images_per_prompt + prompt_embeds.dtype, + device, + generator, + ) + else: + # 4.2. Prepare image latent variables, if necessary + # Prepare image VAE latents + image_vae_latents = self.prepare_image_vae_latents( + batch_size * num_images_per_prompt, + self.num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + vae_latents, + ) + + # Prepare image CLIP latents + image_clip_latents = self.prepare_image_clip_latents( + batch_size * num_images_per_prompt, + self.image_encoder_hidden_size, + prompt_embeds.dtype, + device, + generator, + clip_latents, + ) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + if mode == "joint": + latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds) + elif mode in ["text2img", "img"]: + latents = self._combine(image_vae_latents, image_clip_latents) + elif mode in ["img2text", "text"]: + latents = prompt_embeds + + # 7. Check that shapes of latents and image match the UNet channels. + # TODO + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # predict the noise residual + # Also applies classifier-free guidance as described in the UniDiffuser paper + noise_pred = self.get_noise_pred( + mode, + latents, + t, + prompt_embeds, + image_vae_latents, + image_clip_latents, + timesteps, + guidance_scale, + generator, + device, + height, + width, + ) + + # TODO: do we need to worry about sigma space stuff for the scheduler? + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + gen_image = None + gen_text = None + if mode == "joint": + image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) + # Map latent VAE image back to pixel space + gen_image = self.decode_image_latents(image_vae_latents) + # Generate text using the text decoder + gen_text = self.text_decoder.generate_captions(text_latents, device=device) + elif mode in ["text2img", "img"]: + image_vae_latents, image_clip_latents = self._split(latents, height, width) + gen_image = self.decode_image_latents(image_vae_latents) + elif mode in ["img2text", "text"]: + text_latents = latents + gen_text = self.text_decoder.generate_captions(text_latents, device=device) + + # 11. Run safety checker + # TODO + + # 12. Convert to PIL + if output_type == "pil": + gen_image = self.numpy_to_pil(gen_image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (gen_image, gen_text) + + return ImageTextPipelineOutput(images=gen_image, text=gen_text) diff --git a/tests/pipelines/unidiffuser/__init__.py b/tests/pipelines/unidiffuser/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py new file mode 100644 index 0000000000..9f353f3ff1 --- /dev/null +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -0,0 +1,119 @@ +import unittest + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModel, + GPT2Config, + GPT2LMHeadModel, + GPT2Tokenizer, +) + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, +) +from diffusers.utils import slow +from diffusers.utils.testing_utils import require_torch_gpu + +from ...test_pipelines_common import PipelineTesterMixin + + +class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = UniDiffuserPipeline + params = None # TODO + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UniDiffuserModel( + sample_size=16, + num_layers=2, + patch_size=4, + attention_head_dim=8, + num_attention_heads=2, + in_channels=4, + out_channels=8, + attention_bias=True, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_zero", + norm_elementwise_affine=False, + text_dim=32, # TODO: needs to line up with CLIPTextConfig + clip_img_dim=32, # TODO: needs to line up with CLIPVisionConfig + ) + + scheduler = DDIMScheduler() + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + # TODO: get appropriate testing version for these + text_decoder_tokenizer = GPT2Tokenizer() + text_decoder_model_config = GPT2Config() + text_decoder_model = GPT2LMHeadModel(text_decoder_model_config) + text_decoder = UniDiffuserTextDecoder( + text_decoder_tokenizer, + text_decoder_model, + prefix_length=77, # TODO: fix + ) + + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig() + image_encoder = CLIPVisionModel(image_encoder_config) + # TODO: does this actually work? + image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "vae": vae, + "text_encoder": text_encoder, + "text_decoder": text_decoder, + "image_encoder": image_encoder, + "tokenizer": tokenizer, + "image_processor": image_processor, + "unet": unet, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + pass + + def test_unidiffuser_default_case(self): + pass + + +@slow +@require_torch_gpu +class UniDiffuserPipelineSlowTests(unittest.TestCase): + pass