Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Refactor pipeline based on review (change __init__ design for UniDiffuserTextDecoder, add text_tokenizer to __init__ for UniDiffuserPipeline) and add mode inference and mode setting functions to UniDiffuserPipeline.
@@ -3,60 +3,64 @@ from typing import Optional

import numpy as np
import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers.modeling_utils import ModuleUtilsMixin

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin


# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    @register_to_config
    def __init__(
        self,
        tokenizer: GPT2Tokenizer,
        text_decoder: GPT2LMHeadModel,
        prefix_length: int,
        hidden_dim: Optional[int] = None,
        use_hidden_dim: bool = True,
        prefix_hidden_dim: Optional[int] = None,
        n_positions: int = 1024,  # Start of GPT2 config args
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        resid_pdrop: float = 0.1,
        embd_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
    ):
        """
        Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
        generate text from the UniDiffuser image-text embedding.

        Parameters:
            tokenizer ([`GPT2Tokenizer`]):
                Tokenizer of class
                [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for
                the GPT-like text decoder model.
            text_decoder ([`GPT2LMHeadModel`]):
                Text decoder model of class
                [GPT2LMHeadModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel)
                used to generate text from the UniDiffuser text embedding.
            prefix_length (`int`):
                TODO
                Max number of prefix tokens that will be supplied to the model.
            hidden_dim (`int`, *optional*):
                Hidden dim of the MLP if we encode the prefix.
            use_hidden_dim (`bool`, *optional*, defaults to `True`):
                Whether or not to use a MLP to encode the prefix.
            TODO: add GPT2 config args
        """
        super().__init__()

        self.prefix_length = prefix_length
        self.prefix_hidden_dim = prefix_hidden_dim
        self.encode_prefix = nn.Linear(n_embd, self.hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity()
        self.decode_prefix = nn.Linear(self.hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()

        eos = "<|EOS|>"
        special_tokens_dict = {"eos_token": eos}
        self.tokenizer = tokenizer
        self.tokenizer.add_special_tokens(special_tokens_dict)

        self.transformer = text_decoder
        # TODO: need to set the eos_token_id correctly
        self.transformer.config.eos_token_id = self.tokenizer.eos_token_id
        self.transformer.resize_token_embeddings(len(self.tokenizer))

        self.use_hidden_dim = use_hidden_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else self.transformer.config.n_embd
        self.encode_prefix = nn.Linear(768, self.hidden_dim) if use_hidden_dim else nn.Identity()
        self.decode_prefix = nn.Linear(self.hidden_dim, 768) if use_hidden_dim else nn.Identity()
        gpt_config = GPT2Config(
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=n_inner,
            activation_function=activation_function,
            resid_pdrop=resid_pdrop,
            embd_pdrop=embd_pdrop,
            attn_pdrop=attn_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
        )
        self.transformer = GPT2LMHeadModel(gpt_config)

    def forward(
        self,
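The net effect of this hunk: `UniDiffuserTextDecoder` no longer wraps a caller-supplied tokenizer and `GPT2LMHeadModel`, but builds its own GPT-2 backbone from plain config arguments so that the module can round-trip through `ConfigMixin`. A minimal usage sketch of the new constructor follows; the import path and argument values are assumptions for illustration, not taken from the PR branch or a released checkpoint.

# Sketch only: import path and values are assumptions, not the PR's final API.
from diffusers.pipelines.unidiffuser import UniDiffuserTextDecoder

text_decoder = UniDiffuserTextDecoder(
    prefix_length=77,        # hypothetical CLIP-style max prefix length
    prefix_hidden_dim=None,  # None keeps the identity prefix encode/decode
    n_positions=1024,        # remaining args mirror GPT2Config defaults
    n_embd=768,
    n_layer=12,
    n_head=12,
)

# Every argument passes through @register_to_config, so the module can be
# saved and reloaded like any other diffusers model.
text_decoder.save_pretrained("./unidiffuser-text-decoder")
reloaded = UniDiffuserTextDecoder.from_pretrained("./unidiffuser-text-decoder")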
@@ -85,7 +89,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
        dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
        labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        if self.use_hidden_dim:
        if self.prefix_hidden_dim is not None:
            return out, hidden
        else:
            return out
@@ -94,11 +98,15 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    @torch.no_grad()
    def generate_captions(self, features, device):
    def generate_captions(self, tokenizer, features, device):
        """
        Generate captions given text embedding features. Returns list[L].

        Args:
            tokenizer (`GPT2Tokenizer`):
                Tokenizer of class
                [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
                for tokenizing input to the text decoder model.
            features (`torch.Tensor` of shape `(B, L, D)`):
                Text embedding features to generate captions from.
            device:
@@ -110,12 +118,13 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
        for feature in features:
            feature = self.decode_prefix(feature.to(device))  # back to the clip feature
            # Only support beam search for now
            generated_captions.append(self.generate_beam(embed=feature, device=device)[0])
            generated_captions.append(self.generate_beam(tokenizer, embed=feature, device=device)[0])
        return generated_captions

    @torch.no_grad()
    def generate_beam(
        self,
        tokenizer,
        prompt=None,
        embed=None,
        device=None,
@@ -124,8 +133,14 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
        temperature: float = 1.0,
        stop_token: str = "<|EOS|>",
    ):
        """
        Generates text using the given tokenizer and text prompt or token embedding via beam search.

        TODO: args
        """
        # Generates text until stop_token is reached using beam search with the desired beam size.
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        # TODO: get the stop token index directly from tokenizer rather than manually specifying the EOS token?
        stop_token_index = tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=device)
@@ -135,7 +150,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin):
            generated = embed
        else:
            assert prompt is not None
            tokens = torch.tensor(self.tokenizer.encode(prompt))
            tokens = torch.tensor(tokenizer.encode(prompt))
            tokens = tokens.unsqueeze(0).to(device)
            generated = self.transformer.transformer.wte(tokens)
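Because the tokenizer now lives outside the decoder, `generate_captions` and `generate_beam` take it as an explicit first argument. A rough sketch of the new call, reusing the `text_decoder` from the sketch above; the `<|EOS|>` special token mirrors what the old `__init__` registered, and the feature shape follows the `(B, L, D)` docstring. All values here are illustrative assumptions.

# Sketch only: random features stand in for the pipeline's text latents.
import torch
from transformers import GPT2Tokenizer

text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
text_tokenizer.add_special_tokens({"eos_token": "<|EOS|>"})

features = torch.randn(2, 77, 768)  # (B, L, D) text embedding features
captions = text_decoder.generate_captions(text_tokenizer, features, device=torch.device("cpu"))
print(captions)  # one generated caption per batch element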
@@ -102,20 +102,12 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        # TODO: clean up input cases?
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"
        # Remove layer_norm/num_embeds_ada_norm deprecation message.

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
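For context, the three flags retained above are driven purely by the config: continuous inputs have `in_channels` without `patch_size`, patched inputs have both, and vectorized inputs are signalled by `num_vector_embeds`. A tiny standalone restatement of that check, as a sketch only (not part of the diffusers API):

# Illustrative helper mirroring the is_input_* flags above; not diffusers code.
def classify_input(in_channels=None, patch_size=None, num_vector_embeds=None):
    is_continuous = (in_channels is not None) and (patch_size is None)
    is_vectorized = num_vector_embeds is not None
    is_patches = in_channels is not None and patch_size is not None
    return {"continuous": is_continuous, "vectorized": is_vectorized, "patches": is_patches}

print(classify_input(in_channels=4))                # continuous latent images
print(classify_input(in_channels=4, patch_size=2))  # patched (ViT-style) input
print(classify_input(num_vector_embeds=8192))       # discrete/VQ image embeddings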
@@ -11,6 +11,7 @@ from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModel,
    GPT2Tokenizer,
)

from ...models import AutoencoderKL
@@ -115,10 +116,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_decoder: UniDiffuserTextDecoder,
        image_encoder: CLIPVisionModel,
        tokenizer: CLIPTokenizer,
        image_processor: CLIPImageProcessor,
        clip_tokenizer: CLIPTokenizer,
        text_decoder: UniDiffuserTextDecoder,
        text_tokenizer: GPT2Tokenizer,
        unet: UniDiffuserModel,
        scheduler: KarrasDiffusionSchedulers,
    ):
@@ -127,10 +129,11 @@ class UniDiffuserPipeline(DiffusionPipeline):
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_decoder=text_decoder,
            image_encoder=image_encoder,
            tokenizer=tokenizer,
            image_processor=image_processor,
            clip_tokenizer=clip_tokenizer,
            text_decoder=text_decoder,
            text_tokenizer=text_tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
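With this change the pipeline keeps two tokenizers: `clip_tokenizer` feeds the CLIP text encoder, while `text_tokenizer` (a `GPT2Tokenizer`) is consumed by the caption decoder. A hedged sketch of wiring the constructor by hand; the tokenizer repo IDs are common defaults used as assumptions, and the remaining modules (`vae`, `text_encoder`, `image_encoder`, `image_processor`, `text_decoder`, `unet`, `scheduler`) are assumed to be instantiated elsewhere:

# Sketch only: repo IDs are assumptions; the other components are assumed to
# already exist (vae, text_encoder, image_encoder, image_processor,
# text_decoder, unet, scheduler).
from transformers import CLIPTokenizer, GPT2Tokenizer

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
text_tokenizer.add_special_tokens({"eos_token": "<|EOS|>"})

pipe = UniDiffuserPipeline(
    vae=vae,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    image_processor=image_processor,
    clip_tokenizer=clip_tokenizer,  # tokenizes prompts for the CLIP text encoder
    text_decoder=text_decoder,      # UniDiffuserTextDecoder
    text_tokenizer=text_tokenizer,  # decodes captions from the text decoder
    unet=unet,
    scheduler=scheduler,
)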
@@ -138,10 +141,12 @@ class UniDiffuserPipeline(DiffusionPipeline):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        self.num_channels_latents = vae.latent_channels
        self.text_encoder_seq_len = tokenizer.model_max_length
        self.text_encoder_seq_len = clip_tokenizer.model_max_length
        self.text_encoder_hidden_size = text_encoder.config.hidden_size
        self.image_encoder_hidden_size = image_encoder.config.hidden_size

        self.mode = None

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
@@ -178,6 +183,76 @@ class UniDiffuserPipeline(DiffusionPipeline):
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def _infer_mode(self, prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents):
        r"""Infer the mode from the inputs to `__call__`."""
        prompt_available = (prompt is not None) or (prompt_embeds is not None)
        image_available = image is not None
        input_available = prompt_available or image_available

        prompt_latents_available = prompt_latents is not None
        vae_latents_available = vae_latents is not None
        clip_latents_available = clip_latents is not None
        image_latents_available = vae_latents_available and clip_latents_available
        all_latents_available = prompt_latents_available and image_latents_available

        if self.mode is not None:
            # Preferentially use the mode set by the user
            mode = self.mode
        elif prompt_available:
            mode = "text2img"
        elif image_available:
            mode = "img2text"
        else:
            # Neither prompt nor image supplied, infer based on availability of latents
            if all_latents_available:
                mode = "joint"
            elif prompt_latents_available:
                mode = "text"
            elif image_latents_available:
                mode = "img"
            else:
                # No inputs or latents available
                mode = "img"

        # Give warnings for ambiguous cases
        if self.mode is None and prompt_available and image_available:
            logger.warning(
                f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually,"
                f" defaulting to mode '{mode}'."
            )

        if self.mode is None and not input_available:
            if vae_latents_available != clip_latents_available:
                # Exactly one of vae_latents and clip_latents is supplied
                logger.warning(
                    f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none"
                    f" are expected to be supplied. Defaulting to mode '{mode}'."
                )
            elif not prompt_latents_available and not vae_latents_available and not clip_latents_available:
                # No inputs or latents supplied
                logger.warning(
                    f"No inputs or latents have been supplied, and mode has not been manually set,"
                    f" defaulting to mode '{mode}'."
                )

        return mode

    # Functions to manually set the mode
    def set_text_mode(self):
        self.mode = "text"

    def set_img_mode(self):
        self.mode = "img"

    def set_text_to_image_mode(self):
        self.mode = "text2img"

    def set_image_to_text_mode(self):
        self.mode = "img2text"

    def set_joint_mode(self):
        self.mode = "joint"

    def _infer_batch_size(self, mode, prompt, prompt_embeds, image, num_samples):
        r"""Infers the batch size depending on mode."""
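Taken together, `_infer_mode` and the setters give two ways to choose the task: let the pipeline infer it from which inputs and latents are supplied, or pin it explicitly before calling. A hedged sketch of both paths, assuming `pipe` is a loaded `UniDiffuserPipeline` and `init_image` is a `PIL.Image`:

# Sketch only: `pipe` and `init_image` are assumed to exist already.

# 1) Implicit: a prompt alone infers "text2img", an image alone infers "img2text",
#    and with neither, the mode falls back to whichever latents were supplied.
out_t2i = pipe(prompt="an astronaut riding a horse")
out_i2t = pipe(image=init_image)

# 2) Explicit: passing both a prompt and an image without setting a mode falls
#    through to "text2img" and logs a warning, so pin the mode up front instead.
pipe.set_image_to_text_mode()
out_caption = pipe(image=init_image)

# Joint image-text generation needs either all latents supplied or an explicit mode.
pipe.set_joint_mode()
out_joint = pipe()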
@@ -741,7 +816,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
        mode: str = "text2img",  # text, img, text2img, img2text, joint
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
@@ -772,10 +846,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used when performing image-conditioned
                text generation (`img2text`).
            mode (`str`):
                The generation task to be performed; use `text` for unconditional ("marginal") text generation, `img`
                for unconditional ("marginal") image generation, `text2img` for text-conditioned image generation,
                `img2text` for image-conditioned text generation, and `joint` for joint image-text generation.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -855,15 +925,19 @@ class UniDiffuserPipeline(DiffusionPipeline):

        # 2. Define call parameters

        # Recalculate mode for each call to the pipeline.
        mode = self._infer_mode(prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents)
        batch_size = self._infer_batch_size(mode, prompt, prompt_embeds, image, num_samples)
        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        # Note that this diffusers from the formulation in the unidiffusers paper!
        # Note that this differs from the formulation in the unidiffusers paper!
        # do_classifier_free_guidance = guidance_scale > 1.0

        # check if scheduler is in sigmas space
        hasattr(self.scheduler, "sigmas")
        # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")

        # 3. Encode input prompt, if available; otherwise prepare text latents

@@ -1000,13 +1074,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
            # Map latent VAE image back to pixel space
            gen_image = self.decode_image_latents(image_vae_latents)
            # Generate text using the text decoder
            gen_text = self.text_decoder.generate_captions(text_latents, device=device)
            gen_text = self.text_decoder.generate_captions(self.text_tokenizer, text_latents, device=device)
        elif mode in ["text2img", "img"]:
            image_vae_latents, image_clip_latents = self._split(latents, height, width)
            gen_image = self.decode_image_latents(image_vae_latents)
        elif mode in ["img2text", "text"]:
            text_latents = latents
            gen_text = self.text_decoder.generate_captions(text_latents, device=device)
            gen_text = self.text_decoder.generate_captions(self.text_tokenizer, text_latents, device=device)

        # 11. Run safety checker
        # TODO
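As a quick reference, the final branch above means each mode populates the outputs as follows: `joint` produces both an image and a caption, `text2img` and `img` produce only an image, and `img2text` and `text` produce only text. A tiny standalone restatement of that mapping (illustrative only, not diffusers API):

# Illustrative mapping from mode to produced outputs, mirroring the branch above.
def outputs_for_mode(mode):
    produces_image = mode in ("joint", "text2img", "img")
    produces_text = mode in ("joint", "img2text", "text")
    return produces_image, produces_text

for m in ("joint", "text2img", "img", "img2text", "text"):
    print(m, outputs_for_mode(m))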