
Refactor pipeline based on review (change __init__ design for UniDiffuserTextDecoder, add text_tokenizer to __init__ for UniDiffuserPipeline) and add mode inference and mode setting functions to UniDiffuserPipeline.

Daniel Gu
2023-04-06 01:45:16 -07:00
parent afe5ba0f20
commit 0140e33564
3 changed files with 141 additions and 60 deletions

View File

@@ -3,60 +3,64 @@ from typing import Optional
import numpy as np
import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers.modeling_utils import ModuleUtilsMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
@register_to_config
def __init__(
self,
tokenizer: GPT2Tokenizer,
text_decoder: GPT2LMHeadModel,
prefix_length: int,
hidden_dim: Optional[int] = None,
use_hidden_dim: bool = True,
prefix_hidden_dim: Optional[int] = None,
n_positions: int = 1024, # Start of GPT2 config args
n_embd: int = 768,
n_layer: int = 12,
n_head: int = 12,
n_inner: Optional[int] = None,
activation_function: str = "gelu_new",
resid_pdrop: float = 0.1,
embd_pdrop: float = 0.1,
attn_pdrop: float = 0.1,
layer_norm_epsilon: float = 1e-5,
initializer_range: float = 0.02,
):
"""
Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
generate text from the UniDiffuser image-text embedding.
Parameters:
tokenizer ([`GPT2Tokenizer`]):
Tokenizer of class
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for
the GPT-like text decoder model.
text_decoder ([`GPT2LMHeadModel`]):
Text decoder model of class
[GPT2LMHeadModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel)
used to generate text from the UniDiffuser text embedding.
prefix_length (`int`):
TODO
Max number of prefix tokens that will be supplied to the model.
hidden_dim (`int`, *optional*):
Hidden dim of the MLP if we encode the prefix.
use_hidden_dim (`bool`, *optional*, defaults to `True`):
Whether or not to use an MLP to encode the prefix.
TODO: add GPT2 config args
"""
super().__init__()
self.prefix_length = prefix_length
self.prefix_hidden_dim = prefix_hidden_dim
self.encode_prefix = nn.Linear(n_embd, self.prefix_hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity()
self.decode_prefix = nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
eos = "<|EOS|>"
special_tokens_dict = {"eos_token": eos}
self.tokenizer = tokenizer
self.tokenizer.add_special_tokens(special_tokens_dict)
self.transformer = text_decoder
# TODO: need to set the eos_token_id correctly
self.transformer.config.eos_token_id = self.tokenizer.eos_token_id
self.transformer.resize_token_embeddings(len(self.tokenizer))
self.use_hidden_dim = use_hidden_dim
self.hidden_dim = hidden_dim if hidden_dim is not None else self.transformer.config.n_embd
self.encode_prefix = nn.Linear(768, self.hidden_dim) if use_hidden_dim else nn.Identity()
self.decode_prefix = nn.Linear(self.hidden_dim, 768) if use_hidden_dim else nn.Identity()
gpt_config = GPT2Config(
n_positions=n_positions,
n_embd=n_embd,
n_layer=n_layer,
n_head=n_head,
n_inner=n_inner,
activation_function=activation_function,
resid_pdrop=resid_pdrop,
embd_pdrop=embd_pdrop,
attn_pdrop=attn_pdrop,
layer_norm_epsilon=layer_norm_epsilon,
initializer_range=initializer_range,
)
self.transformer = GPT2LMHeadModel(gpt_config)
def forward(
self,
@@ -85,7 +89,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
labels = torch.cat((dummy_token, tokens), dim=1)
out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
if self.use_hidden_dim:
if self.prefix_hidden_dim is not None:
return out, hidden
else:
return out
@@ -94,11 +98,15 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
@torch.no_grad()
def generate_captions(self, features, device):
def generate_captions(self, tokenizer, features, device):
"""
Generate captions given text embedding features. Returns list[L].
Args:
tokenizer (`GPT2Tokenizer`):
Tokenizer of class
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
for tokenizing input to the text decoder model.
features (`torch.Tensor` of shape `(B, L, D)`):
Text embedding features to generate captions from.
device:
@@ -110,12 +118,13 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
for feature in features:
feature = self.decode_prefix(feature.to(device)) # back to the clip feature
# Only support beam search for now
generated_captions.append(self.generate_beam(embed=feature, device=device)[0])
generated_captions.append(self.generate_beam(tokenizer, embed=feature, device=device)[0])
return generated_captions
@torch.no_grad()
def generate_beam(
self,
tokenizer,
prompt=None,
embed=None,
device=None,
@@ -124,8 +133,14 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
temperature: float = 1.0,
stop_token: str = "<|EOS|>",
):
"""
Generates text using the given tokenizer and text prompt or token embedding via beam search.
TODO: args
"""
# Generates text until stop_token is reached using beam search with the desired beam size.
stop_token_index = self.tokenizer.encode(stop_token)[0]
# TODO: get the stop token index directly from tokenizer rather than manually specifying the EOS token?
stop_token_index = tokenizer.encode(stop_token)[0]
tokens = None
scores = None
seq_lengths = torch.ones(beam_size, device=device)
@@ -135,7 +150,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
generated = embed
else:
assert prompt is not None
tokens = torch.tensor(self.tokenizer.encode(prompt))
tokens = torch.tensor(tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = self.transformer.transformer.wte(tokens)
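
For orientation (not part of the diff): a minimal standalone sketch of the construction pattern the refactored __init__ now uses, with placeholder sizes. The decoder owns its GPT2LMHeadModel built from config arguments, and the prefix is only projected through a bottleneck when prefix_hidden_dim is set.

import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel

# Placeholder sizes for the sketch; the real model takes these through __init__.
n_embd, prefix_hidden_dim, prefix_length = 768, 64, 77

# Build the GPT-2 language model directly from config args, as the new __init__ does.
transformer = GPT2LMHeadModel(GPT2Config(n_positions=1024, n_embd=n_embd, n_layer=2, n_head=2))

# Optional bottleneck around the prefix embedding.
encode_prefix = nn.Linear(n_embd, prefix_hidden_dim) if prefix_hidden_dim is not None else nn.Identity()
decode_prefix = nn.Linear(prefix_hidden_dim, n_embd) if prefix_hidden_dim is not None else nn.Identity()

prefix = torch.randn(1, prefix_length, n_embd)          # stand-in for text embedding features
hidden = encode_prefix(prefix)                          # (1, 77, 64)
out = transformer(inputs_embeds=decode_prefix(hidden))  # project back to n_embd, run GPT-2
print(hidden.shape, out.logits.shape)                   # (1, 77, 64), (1, 77, 50257)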

View File

@@ -102,20 +102,12 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
# TODO: clean up input cases?
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
# Remove layer_norm/num_embeds_ada_norm deprecation message.
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(

View File

@@ -11,6 +11,7 @@ from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModel,
GPT2Tokenizer,
)
from ...models import AutoencoderKL
@@ -115,10 +116,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
text_decoder: UniDiffuserTextDecoder,
image_encoder: CLIPVisionModel,
tokenizer: CLIPTokenizer,
image_processor: CLIPImageProcessor,
clip_tokenizer: CLIPTokenizer,
text_decoder: UniDiffuserTextDecoder,
text_tokenizer: GPT2Tokenizer,
unet: UniDiffuserModel,
scheduler: KarrasDiffusionSchedulers,
):
@@ -127,10 +129,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_decoder=text_decoder,
image_encoder=image_encoder,
tokenizer=tokenizer,
image_processor=image_processor,
clip_tokenizer=clip_tokenizer,
text_decoder=text_decoder,
text_tokenizer=text_tokenizer,
unet=unet,
scheduler=scheduler,
)
@@ -138,10 +141,12 @@ class UniDiffuserPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.num_channels_latents = vae.latent_channels
self.text_encoder_seq_len = tokenizer.model_max_length
self.text_encoder_seq_len = clip_tokenizer.model_max_length
self.text_encoder_hidden_size = text_encoder.config.hidden_size
self.image_encoder_hidden_size = image_encoder.config.hidden_size
self.mode = None
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -178,6 +183,76 @@ class UniDiffuserPipeline(DiffusionPipeline):
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def _infer_mode(self, prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents):
r"""Infer the mode from the inputs to `__call__`."""
prompt_available = (prompt is not None) or (prompt_embeds is not None)
image_available = image is not None
input_available = prompt_available or image_available
prompt_latents_available = prompt_latents is not None
vae_latents_available = vae_latents is not None
clip_latents_available = clip_latents is not None
image_latents_available = vae_latents_available and clip_latents_available
all_latents_available = prompt_latents_available and image_latents_available
if self.mode is not None:
# Preferentially use the mode set by the user
mode = self.mode
elif prompt_available:
mode = "text2img"
elif image_available:
mode = "img2text"
else:
# Neither prompt nor image supplied, infer based on availability of latents
if all_latents_available:
mode = "joint"
elif prompt_latents_available:
mode = "text"
elif image_latents_available:
mode = "img"
else:
# No inputs or latents available
mode = "img"
# Give warnings for ambiguous cases
if self.mode is None and prompt_available and image_available:
logger.warning(
f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually,"
f" defaulting to mode '{mode}'."
)
if self.mode is None and not input_available:
if vae_latents_available != clip_latents_available:
# Exactly one of vae_latents and clip_latents is supplied
logger.warning(
f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none"
f" are expected to be supplied. Defaulting to mode '{mode}'."
)
elif not prompt_latents_available and not vae_latents_available and not clip_latents_available:
# No inputs or latents supplied
logger.warning(
f"No inputs or latents have been supplied, and mode has not been manually set,"
f" defaulting to mode '{mode}'."
)
return mode
# Functions to manually set the mode
def set_text_mode(self):
self.mode = "text"
def set_img_mode(self):
self.mode = "img"
def set_text_to_image_mode(self):
self.mode = "text2img"
def set_image_to_text_mode(self):
self.mode = "img2text"
def set_joint_mode(self):
self.mode = "joint"
def _infer_batch_size(self, mode, prompt, prompt_embeds, image, num_samples):
r"""Infers the batch size depending on mode."""
@@ -741,7 +816,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
mode: str = "text2img", # text, img, text2img, img2text, joint
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
@@ -772,10 +846,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used when performing image-conditioned
text generation (`img2text`).
mode (`str`):
The generation task to be performed; use `text` for unconditional ("marginal") text generation, `img`
for unconditional ("marginal") image generation, `text2img` for text-conditioned image generation,
`img2text` for image-conditioned text generation, and `joint` for joint image-text generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -855,15 +925,19 @@ class UniDiffuserPipeline(DiffusionPipeline):
# 2. Define call parameters
# Recalculate mode for each call to the pipeline.
mode = self._infer_mode(prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents)
batch_size = self._infer_batch_size(mode, prompt, prompt_embeds, image, num_samples)
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# Note that this diffusers from the formulation in the unidiffusers paper!
# Note that this differs from the formulation in the unidiffusers paper!
# do_classifier_free_guidance = guidance_scale > 1.0
# check if scheduler is in sigmas space
hasattr(self.scheduler, "sigmas")
# scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
# 3. Encode input prompt, if available; otherwise prepare text latents
@@ -1000,13 +1074,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
# Map latent VAE image back to pixel space
gen_image = self.decode_image_latents(image_vae_latents)
# Generate text using the text decoder
gen_text = self.text_decoder.generate_captions(text_latents, device=device)
gen_text = self.text_decoder.generate_captions(self.text_tokenizer, text_latents, device=device)
elif mode in ["text2img", "img"]:
image_vae_latents, image_clip_latents = self._split(latents, height, width)
gen_image = self.decode_image_latents(image_vae_latents)
elif mode in ["img2text", "text"]:
text_latents = latents
gen_text = self.text_decoder.generate_captions(text_latents, device=device)
gen_text = self.text_decoder.generate_captions(self.text_tokenizer, text_latents, device=device)
# 11. Run safety checker
# TODO
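
For readers following the new mode handling: the precedence implemented by _infer_mode is, roughly, an explicitly set mode (via the set_*_mode helpers) first, then prompt, then image, then whichever latents were supplied. A small self-contained sketch of that precedence (prompt_embeds folded into prompt for brevity; not the pipeline code itself):

def infer_mode(mode=None, prompt=None, image=None,
               prompt_latents=None, vae_latents=None, clip_latents=None):
    # A manually set mode (set_text_mode, set_img_mode, set_text_to_image_mode,
    # set_image_to_text_mode, set_joint_mode) always wins.
    if mode is not None:
        return mode
    if prompt is not None:                # prompt or prompt_embeds supplied
        return "text2img"
    if image is not None:
        return "img2text"
    image_latents = vae_latents is not None and clip_latents is not None
    if prompt_latents is not None and image_latents:
        return "joint"
    if prompt_latents is not None:
        return "text"
    if image_latents:
        return "img"
    return "img"                          # nothing supplied: unconditional image generation


print(infer_mode(prompt="an astronaut riding a horse"))  # text2img
print(infer_mode(image="<image placeholder>"))           # img2text
print(infer_mode())                                      # img

In the pipeline itself this decision is re-run on every __call__, so a mode set through one of the set_*_mode helpers persists until a different helper is called.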