From 0140e33564a177eca32cd8273a69da7d09e064b7 Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Thu, 6 Apr 2023 01:45:16 -0700
Subject: [PATCH] Refactor pipeline based on review (change __init__ design for
 UniDiffuserTextDecoder, add text_tokenizer to __init__ for UniDiffuserPipeline)
 and add mode inference and mode setting functions to UniDiffuserPipeline.

---
 .../unidiffuser/modeling_text_decoder.py      |  87 ++++++++-------
 .../pipelines/unidiffuser/modeling_uvit.py    |  12 +--
 .../unidiffuser/pipeline_unidiffuser.py       | 102 +++++++++++++---
 3 files changed, 141 insertions(+), 60 deletions(-)

diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
index 6a6eea44a2..3782e9e5ff 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
@@ -3,60 +3,64 @@ from typing import Optional
 import numpy as np
 import torch
 from torch import nn
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
+from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
+from transformers.modeling_utils import ModuleUtilsMixin

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin


 # Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
-class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
+class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
     @register_to_config
     def __init__(
         self,
-        tokenizer: GPT2Tokenizer,
-        text_decoder: GPT2LMHeadModel,
         prefix_length: int,
-        hidden_dim: Optional[int] = None,
-        use_hidden_dim: bool = True,
+        prefix_hidden_dim: Optional[int] = None,
+        n_positions: int = 1024,  # Start of GPT2 config args
+        n_embd: int = 768,
+        n_layer: int = 12,
+        n_head: int = 12,
+        n_inner: Optional[int] = None,
+        activation_function: str = "gelu_new",
+        resid_pdrop: float = 0.1,
+        embd_pdrop: float = 0.1,
+        attn_pdrop: float = 0.1,
+        layer_norm_epsilon: float = 1e-5,
+        initializer_range: float = 0.02,
     ):
         """
-        Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
+        Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
         generate text from the UniDiffuser image-text embedding.

         Parameters:
-            tokenizer ([`GPT2Tokenizer`]):
-                Tokenizer of class
-                [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for
-                the GPT-like text decoder model.
-            text_decoder ([`GPT2LMHeadModel`]):
-                Text decoder model of class
-                [GPT2LMHeadModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel)
-                used to generate text from the UniDiffuser text embedding.
             prefix_length (`int`):
-                TODO
+                Max number of prefix tokens that will be supplied to the model.
-            hidden_dim (`int`, *optional*):
-                Hidden dim of the MLP if we encode the prefix.
-            use_hidden_dim (`bool`, *optional*, defaults to `True`):
-                Whether or not to use a MLP to encode the prefix.
+            prefix_hidden_dim (`int`, *optional*):
+                Hidden dim of the MLP used to encode the prefix; if `None`, the prefix is used without projection.
+            n_positions (`int`, *optional*, defaults to 1024):
+                The maximum sequence length that this model might ever be used with.
+            n_embd (`int`, *optional*, defaults to 768):
+                Dimensionality of the embeddings and hidden states.
+            n_layer (`int`, *optional*, defaults to 12):
+                Number of hidden layers in the GPT-2 transformer.
+            n_head (`int`, *optional*, defaults to 12):
+                Number of attention heads for each attention layer.
+            n_inner (`int`, *optional*):
+                Dimensionality of the inner feed-forward layers; if `None`, it is set to 4 times `n_embd`.
+            activation_function (`str`, *optional*, defaults to `"gelu_new"`):
+                Activation function used in the GPT-2 transformer.
+            resid_pdrop (`float`, *optional*, defaults to 0.1):
+                The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+            embd_pdrop (`float`, *optional*, defaults to 0.1):
+                The dropout ratio for the embeddings.
+            attn_pdrop (`float`, *optional*, defaults to 0.1):
+                The dropout ratio for the attention.
+            layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+                The epsilon to use in the layer normalization layers.
+            initializer_range (`float`, *optional*, defaults to 0.02):
+                The standard deviation of the truncated normal initializer for initializing all weight matrices.
         """
         super().__init__()
+        self.prefix_length = prefix_length
+        self.prefix_hidden_dim = prefix_hidden_dim
+        self.encode_prefix = nn.Identity() if prefix_hidden_dim is None else nn.Linear(n_embd, prefix_hidden_dim)
+        self.decode_prefix = nn.Identity() if prefix_hidden_dim is None else nn.Linear(prefix_hidden_dim, n_embd)

-        eos = "<|EOS|>"
-        special_tokens_dict = {"eos_token": eos}
-        self.tokenizer = tokenizer
-        self.tokenizer.add_special_tokens(special_tokens_dict)
-
-        self.transformer = text_decoder
-        # TODO: need to set the eos_token_id correctly
-        self.transformer.config.eos_token_id = self.tokenizer.eos_token_id
-        self.transformer.resize_token_embeddings(len(self.tokenizer))
-
-        self.use_hidden_dim = use_hidden_dim
-        self.hidden_dim = hidden_dim if hidden_dim is not None else self.transformer.config.n_embd
-        self.encode_prefix = nn.Linear(768, self.hidden_dim) if use_hidden_dim else nn.Identity()
-        self.decode_prefix = nn.Linear(self.hidden_dim, 768) if use_hidden_dim else nn.Identity()
+        gpt_config = GPT2Config(
+            n_positions=n_positions,
+            n_embd=n_embd,
+            n_layer=n_layer,
+            n_head=n_head,
+            n_inner=n_inner,
+            activation_function=activation_function,
+            resid_pdrop=resid_pdrop,
+            embd_pdrop=embd_pdrop,
+            attn_pdrop=attn_pdrop,
+            layer_norm_epsilon=layer_norm_epsilon,
+            initializer_range=initializer_range,
+        )
+        self.transformer = GPT2LMHeadModel(gpt_config)

     def forward(
         self,
@@ -85,7 +89,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
         dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
         labels = torch.cat((dummy_token, tokens), dim=1)
         out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
-        if self.use_hidden_dim:
+        if self.prefix_hidden_dim is not None:
             return out, hidden
         else:
             return out
@@ -94,11 +98,15 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

     @torch.no_grad()
-    def generate_captions(self, features, device):
+    def generate_captions(self, tokenizer, features, device):
         """
-        Generate captions given text embedding features. Returns list[L].
+        Generate captions given text embedding features. Returns a list of strings, one generated caption per batch
+        element of `features`.

         Args:
+            tokenizer (`GPT2Tokenizer`):
+                Tokenizer of class
+                [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
+                for tokenizing input to the text decoder model.
             features (`torch.Tensor` of shape `(B, L, D)`):
                 Text embedding features to generate captions from.
             device:
@@ -110,12 +118,13 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
         for feature in features:
             feature = self.decode_prefix(feature.to(device))  # back to the clip feature
             # Only support beam search for now
-            generated_captions.append(self.generate_beam(embed=feature, device=device)[0])
+            generated_captions.append(self.generate_beam(tokenizer, embed=feature, device=device)[0])
         return generated_captions

     @torch.no_grad()
     def generate_beam(
         self,
+        tokenizer,
         prompt=None,
         embed=None,
         device=None,
@@ -124,8 +133,14 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
         temperature: float = 1.0,
         stop_token: str = "<|EOS|>",
     ):
+        """
+        Generates text using the given tokenizer and text prompt or token embedding via beam search.
+
+        Args:
+            tokenizer (`GPT2Tokenizer`):
+                Tokenizer used to encode the prompt and decode the generated tokens.
+            prompt (`str`, *optional*):
+                Text prompt to condition the generation on; required if `embed` is not supplied.
+            embed (`torch.Tensor`, *optional*):
+                Prefix embedding to condition the generation on; takes precedence over `prompt` when both are
+                supplied.
+            device:
+                Device to perform the generation on.
+            beam_size (`int`, *optional*):
+                Number of beams to use during beam search.
+            entry_length (`int`, *optional*):
+                Maximum number of tokens to generate.
+            temperature (`float`, *optional*, defaults to `1.0`):
+                Temperature applied to the logits when selecting the next token.
+            stop_token (`str`, *optional*, defaults to `"<|EOS|>"`):
+                Token at which generation is stopped.
+        """
         # Generates text until stop_token is reached using beam search with the desired beam size.
-        stop_token_index = self.tokenizer.encode(stop_token)[0]
+        # TODO: get the stop token index directly from tokenizer rather than manually specifying the EOS token?
+        stop_token_index = tokenizer.encode(stop_token)[0]
         tokens = None
         scores = None
         seq_lengths = torch.ones(beam_size, device=device)
@@ -135,7 +150,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
             generated = embed
         else:
             assert prompt is not None
-            tokens = torch.tensor(self.tokenizer.encode(prompt))
+            tokens = torch.tensor(tokenizer.encode(prompt))
             tokens = tokens.unsqueeze(0).to(device)
             generated = self.transformer.transformer.wte(tokens)
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
index dc53937b29..977c534ab4 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
@@ -102,20 +102,12 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
+        # TODO: clean up input cases?
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
         self.is_input_vectorized = num_vector_embeds is not None
         self.is_input_patches = in_channels is not None and patch_size is not None

-        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
-            deprecation_message = (
-                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
-                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
-                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
-                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
-                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
-            )
-            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
-            norm_type = "ada_norm"
+        # NOTE: the `norm_type`/`num_embeds_ada_norm` deprecation warning from Transformer2DModel is intentionally
+        # not carried over here.
         if self.is_input_continuous and self.is_input_vectorized:
             raise ValueError(
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index c31a26131c..5e65824d94 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -11,6 +11,7 @@ from transformers import (
     CLIPTextModel,
     CLIPTokenizer,
     CLIPVisionModel,
+    GPT2Tokenizer,
 )

 from ...models import AutoencoderKL
@@ -115,10 +116,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
-        text_decoder: UniDiffuserTextDecoder,
         image_encoder: CLIPVisionModel,
-        tokenizer: CLIPTokenizer,
         image_processor: CLIPImageProcessor,
+        clip_tokenizer: CLIPTokenizer,
+        text_decoder: UniDiffuserTextDecoder,
+        text_tokenizer: GPT2Tokenizer,
         unet: UniDiffuserModel,
         scheduler: KarrasDiffusionSchedulers,
     ):
@@ -127,10 +129,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
-            text_decoder=text_decoder,
             image_encoder=image_encoder,
-            tokenizer=tokenizer,
             image_processor=image_processor,
+            clip_tokenizer=clip_tokenizer,
+            text_decoder=text_decoder,
+            text_tokenizer=text_tokenizer,
             unet=unet,
             scheduler=scheduler,
         )
@@ -138,10 +141,12 @@ class UniDiffuserPipeline(DiffusionPipeline):

         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.num_channels_latents = vae.latent_channels
-        self.text_encoder_seq_len = tokenizer.model_max_length
+        self.text_encoder_seq_len = clip_tokenizer.model_max_length
         self.text_encoder_hidden_size = text_encoder.config.hidden_size
         self.image_encoder_hidden_size = image_encoder.config.hidden_size

+        self.mode = None
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -178,6 +183,76 @@ class UniDiffuserPipeline(DiffusionPipeline):
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
+
+    def _infer_mode(self, prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents):
+        r"""Infer the mode from the inputs to `__call__`."""
+        prompt_available = (prompt is not None) or (prompt_embeds is not None)
+        image_available = image is not None
+        input_available = prompt_available or image_available
+
+        prompt_latents_available = prompt_latents is not None
+        vae_latents_available = vae_latents is not None
+        clip_latents_available = clip_latents is not None
+        image_latents_available = vae_latents_available and clip_latents_available
+        all_latents_available = prompt_latents_available and image_latents_available
+
+        if self.mode is not None:
+            # Preferentially use the mode set by the user
+            mode = self.mode
+        elif prompt_available:
+            mode = "text2img"
+        elif image_available:
+            mode = "img2text"
+        else:
+            # Neither prompt nor image supplied, infer based on availability of latents
+            if all_latents_available:
+                mode = "joint"
+            elif prompt_latents_available:
+                mode = "text"
+            elif image_latents_available:
+                mode = "img"
+            else:
+                # No inputs or latents available
+                mode = "img"
+
+        # Give warnings for ambiguous cases
+        if self.mode is None and prompt_available and image_available:
+            logger.warning(
+                f"You have supplied both a text prompt and an image to the pipeline and mode has not been set"
+                f" manually, defaulting to mode '{mode}'."
+ ) + + if self.mode is None and not input_available: + if vae_latents_available != clip_latents_available: + # Exactly one of vae_latents and clip_latents is supplied + logger.warning( + f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none" + f" are expected to be supplied. Defaulting to mode '{mode}'." + ) + elif not prompt_latents_available and not vae_latents_available and not clip_latents_available: + # No inputs or latents supplied + logger.warning( + f"No inputs or latents have been supplied, and mode has not been manually set," + f" defaulting to mode '{mode}'." + ) + + return mode + + # Functions to manually set the mode + def set_text_mode(self): + self.mode = "text" + + def set_img_mode(self): + self.mode = "img" + + def set_text_to_image_mode(self): + self.mode = "text2img" + + def set_image_to_text_mode(self): + self.mode = "img2text" + + def set_joint_mode(self): + self.mode = "joint" def _infer_batch_size(self, mode, prompt, prompt_embeds, image, num_samples): r"""Infers the batch size depending on mode.""" @@ -741,7 +816,6 @@ class UniDiffuserPipeline(DiffusionPipeline): self, prompt: Optional[Union[str, List[str]]] = None, image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, - mode: str = "text2img", # text, img, text2img, img2text, joint height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -772,10 +846,6 @@ class UniDiffuserPipeline(DiffusionPipeline): image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used when performing image-conditioned text generation (`img2text`). - mode (`str`): - The generation task to be performed; use `text` for unconditional ("marginal") text generation, `img` - for unconditional ("marginal") image generation, `text2img` for text-conditioned image generation, - `img2text` for image-conditioned text generation, and `joint` for joint image-text generation. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -855,15 +925,19 @@ class UniDiffuserPipeline(DiffusionPipeline): # 2. Define call parameters + # Recalculate mode for each call to the pipeline. + mode = self._infer_mode(prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents) batch_size = self._infer_batch_size(mode, prompt, prompt_embeds, image, num_samples) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. - # Note that this diffusers from the formulation in the unidiffusers paper! + # Note that this differs from the formulation in the unidiffusers paper! + # do_classifier_free_guidance = guidance_scale > 1.0 + # check if scheduler is in sigmas space - hasattr(self.scheduler, "sigmas") + # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") # 3. 
         # 3. Encode input prompt, if available; otherwise prepare text latents
@@ -1000,13 +1074,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
             # Map latent VAE image back to pixel space
             gen_image = self.decode_image_latents(image_vae_latents)
             # Generate text using the text decoder
-            gen_text = self.text_decoder.generate_captions(text_latents, device=device)
+            gen_text = self.text_decoder.generate_captions(self.text_tokenizer, text_latents, device=device)
         elif mode in ["text2img", "img"]:
             image_vae_latents, image_clip_latents = self._split(latents, height, width)
             gen_image = self.decode_image_latents(image_vae_latents)
         elif mode in ["img2text", "text"]:
             text_latents = latents
-            gen_text = self.text_decoder.generate_captions(text_latents, device=device)
+            gen_text = self.text_decoder.generate_captions(self.text_tokenizer, text_latents, device=device)

         # 11. Run safety checker
         # TODO
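
For reference, a minimal sketch (not part of the patch) of how the refactored UniDiffuserTextDecoder is constructed and used after this change: the decoder now builds its own GPT2LMHeadModel from the registered config args instead of receiving a pre-built model and tokenizer, and generate_captions takes the tokenizer as an explicit argument. The prefix_length/prefix_hidden_dim values below are illustrative placeholders, and the import path simply mirrors the file location in this patch.

import torch

from diffusers.pipelines.unidiffuser.modeling_text_decoder import UniDiffuserTextDecoder

text_decoder = UniDiffuserTextDecoder(
    prefix_length=77,      # placeholder: max number of prefix tokens fed to the decoder
    prefix_hidden_dim=64,  # placeholder: project the prefix through a 64-dim bottleneck
    n_embd=768,            # GPT-2 hidden size (the default)
)

# encode_prefix/decode_prefix map between the GPT-2 embedding space and the bottleneck;
# with prefix_hidden_dim=None both are nn.Identity().
prefix = torch.randn(2, 77, 768)              # (batch, prefix_length, n_embd)
hidden = text_decoder.encode_prefix(prefix)   # -> (2, 77, 64)
decoded = text_decoder.decode_prefix(hidden)  # -> (2, 77, 768)

# Caption generation now takes the tokenizer explicitly, e.g.:
#   captions = text_decoder.generate_captions(text_tokenizer, features, device=device)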
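
Likewise, a sketch (not part of the patch) of the new mode API on UniDiffuserPipeline: mode is no longer a __call__ argument; it is re-inferred from the supplied inputs on every call and can be pinned with the set_*_mode() helpers. The checkpoint path below is a placeholder (no published checkpoint is assumed), and the top-level import assumes the pipeline is exported from diffusers.

from diffusers import UniDiffuserPipeline

pipe = UniDiffuserPipeline.from_pretrained("path/to/unidiffuser-checkpoint")  # placeholder path
pipe = pipe.to("cuda")

# With no mode set manually, supplying a prompt infers "text2img" and supplying an
# image infers "img2text"; with neither, the mode is inferred from any supplied latents.
sample = pipe(prompt="an astronaut riding a horse", num_inference_steps=50)

# A manually set mode takes precedence over inference:
pipe.set_joint_mode()
sample = pipe(num_inference_steps=50)  # joint image-text generation
pipe.mode = None                       # back to automatic mode inference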