mirror of https://github.com/huggingface/diffusers.git

Add UniDiffuser classes to __init__ files, modify transformer block to support pre- and post-LN, add fast default tests, fix some bugs.

Daniel Gu
2023-04-14 00:54:02 -07:00
parent 0140e33564
commit a492e0c86f
6 changed files with 472 additions and 88 deletions
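For orientation before the diffs: a minimal, self-contained sketch (not part of this commit) of the two residual orderings that the new pre_layer_norm flag in UTransformerBlock switches between, with a single linear layer standing in for the attention and feed-forward sublayers.

import torch
from torch import nn

class ToySublayer(nn.Module):
    # Toy stand-in for one UTransformerBlock sublayer (attention or feed-forward).
    def __init__(self, dim: int, pre_layer_norm: bool):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = nn.Linear(dim, dim)
        self.pre_layer_norm = pre_layer_norm

    def forward(self, x):
        if self.pre_layer_norm:
            # Pre-LN: normalize the input, run the sublayer, then add the residual.
            return self.fn(self.norm(x)) + x
        # Post-LN (the ordering the block in this commit uses when pre_layer_norm=False):
        # run the sublayer on the raw input, normalize its output, then add the residual.
        return self.norm(self.fn(x)) + x

x = torch.randn(2, 16, 32)
print(ToySublayer(32, pre_layer_norm=True)(x).shape)   # torch.Size([2, 16, 32])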

View File

@@ -115,6 +115,7 @@ else:
AltDiffusionPipeline,
AudioLDMPipeline,
CycleDiffusionPipeline,
ImageTextPipelineOutput,
LDMTextToImagePipeline,
PaintByExamplePipeline,
SemanticStableDiffusionPipeline,
@@ -139,6 +140,9 @@ else:
TextToVideoSDPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,

View File

@@ -70,6 +70,7 @@ else:
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .text_to_video_synthesis import TextToVideoSDPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
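With both __init__ files updated, the new UniDiffuser symbols become importable from the package root. A usage sketch, assuming an installed diffusers build that includes this commit:

from diffusers import (
    ImageTextPipelineOutput,
    UniDiffuserModel,
    UniDiffuserPipeline,
    UniDiffuserTextDecoder,
)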

View File

@@ -3,7 +3,7 @@ from typing import Optional
import numpy as np
import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.modeling_utils import ModuleUtilsMixin
from ...configuration_utils import ConfigMixin, register_to_config
@@ -36,7 +36,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
Parameters:
prefix_length (`int`):
Max number of prefix tokens that will be supplied to the model.
hidden_dim (`int`, *optional*):
prefix_hidden_dim (`int`, *optional*):
Hidden dim of the MLP if we encode the prefix.
TODO: add GPT2 config args
"""
@@ -44,8 +44,12 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
self.prefix_length = prefix_length
self.prefix_hidden_dim = prefix_hidden_dim
self.encode_prefix = nn.Linear(n_embd, self.hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity()
self.decode_prefix = nn.Linear(self.hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
self.encode_prefix = (
nn.Linear(n_embd, self.prefix_hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity()
)
self.decode_prefix = (
nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
)
gpt_config = GPT2Config(
n_positions=n_positions,
@@ -105,8 +109,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
Args:
tokenizer (`GPT2Tokenizer`):
Tokenizer of class
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
for tokenizing input to the text decoder model.
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for
tokenizing input to the text decoder model.
features (`torch.Tensor` of shape `(B, L, D)`):
Text embedding features to generate captions from.
device:
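The encode_prefix/decode_prefix fix earlier in this file matters because the class only assigns self.prefix_hidden_dim; referencing a nonexistent self.hidden_dim would fail at construction whenever a prefix MLP is requested. A minimal sketch of the resulting bottleneck, with toy dimensions rather than the real defaults:

import torch
from torch import nn

n_embd, prefix_hidden_dim = 32, 16                       # toy sizes for illustration
encode_prefix = nn.Linear(n_embd, prefix_hidden_dim)     # text features -> prefix bottleneck
decode_prefix = nn.Linear(prefix_hidden_dim, n_embd)     # bottleneck -> GPT-2 embedding width

features = torch.randn(1, 77, n_embd)                    # (batch, prefix_length, n_embd)
prefix = decode_prefix(encode_prefix(features))          # shape is preserved: (1, 77, n_embd)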

View File

@@ -6,21 +6,20 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.attention import AdaLayerNorm, BasicTransformerBlock
from ...models.attention import AdaLayerNorm, FeedForward
from ...models.attention_processor import Attention
from ...models.embeddings import ImagePositionalEmbeddings, PatchEmbed, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModelOutput
from ...utils import deprecate
class SkipBlock(nn.Module):
def __init__(self, dim: int, num_embeds_ada_norm: Optional[int] = None):
def __init__(self, dim: int):
super().__init__()
self.skip_linear = nn.Linear(2 * dim, dim)
# Use AdaLayerNorm for now, maybe support using other forms of LayerNorm?
# Original code uses torch.nn.LayerNorm
self.norm = AdaLayerNorm(dim, num_embeds_ada_norm)
# Use torch.nn.LayerNorm for now, following the original code
self.norm = nn.LayerNorm(dim)
def forward(self, x, skip):
x = self.skip_linear(torch.cat([x, skip], dim=-1))
@@ -29,8 +28,216 @@ class SkipBlock(nn.Module):
return x
# Modified to support both pre-LayerNorm and post-LayerNorm configurations
# Don't support AdaLayerNormZero for now
# Modified from diffusers.models.attention.BasicTransformerBlock
class UTransformerBlock(nn.Module):
r"""
A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
pre_layer_norm: bool = True,
final_dropout: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
# self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.pre_layer_norm = pre_layer_norm
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
# 1. Self-Attn
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.attn2 = None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
# elif self.use_ada_layer_norm_zero:
# self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
else:
self.norm2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
cross_attention_kwargs=None,
class_labels=None,
):
# Pre-LayerNorm
if self.pre_layer_norm:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
# elif self.use_ada_layer_norm_zero:
# norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
# hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
# )
else:
norm_hidden_states = self.norm1(hidden_states)
else:
norm_hidden_states = hidden_states
# 1. Self-Attention
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if not self.pre_layer_norm:
# Post-LayerNorm
if self.use_ada_layer_norm:
attn_output = self.norm1(attn_output, timestep)
# elif self.use_ada_layer_norm_zero:
# attn_output, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
# attn_output, timestep, class_labels, hidden_dtype=attn_output.dtype
# )
else:
attn_output = self.norm1(hidden_states)
# else:
# # Pre-LayerNorm post-processing
# if self.use_ada_layer_norm_zero:
# attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
if self.attn2 is not None:
# Pre-LayerNorm
if self.pre_layer_norm:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
else:
norm_hidden_states = hidden_states
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here
# 2. Cross-Attention
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
# Post-LayerNorm
if not self.pre_layer_norm:
attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
# Pre-LayerNorm
if self.pre_layer_norm:
norm_hidden_states = self.norm3(hidden_states)
# if self.use_ada_layer_norm_zero:
# norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
else:
norm_hidden_states = hidden_states
ff_output = self.ff(norm_hidden_states)
if not self.pre_layer_norm:
# Post-LayerNorm
ff_output = self.norm3(ff_output)
# if self.use_ada_layer_norm_zero:
# ff_output = ff_output * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
# else:
# # Pre-LayerNorm post-processing
# if self.use_ada_layer_norm_zero:
# ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
# Modified from diffusers.models.transformer_2d.Transformer2DModel
# Modify the transformer block structure to be U-Net like following U-ViT
# Only supports patch-style input currently
# https://github.com/baofff/U-ViT
class UTransformer2DModel(ModelMixin, ConfigMixin):
"""
@@ -92,6 +299,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
pre_layer_norm: bool = False,
norm_elementwise_affine: bool = True,
):
super().__init__()
@@ -167,7 +375,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
# a "U"-shaped fashion (e.g. first in_block to last out_block, etc.).
self.transformer_in_blocks = nn.ModuleList(
[
BasicTransformerBlock(
UTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
@@ -179,13 +387,14 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
pre_layer_norm=pre_layer_norm,
norm_elementwise_affine=norm_elementwise_affine,
)
for d in range(num_layers // 2)
]
)
self.transformer_mid_block = BasicTransformerBlock(
self.transformer_mid_block = UTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
@@ -197,6 +406,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
pre_layer_norm=pre_layer_norm,
norm_elementwise_affine=norm_elementwise_affine,
)
@@ -208,9 +418,8 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
{
"skip": SkipBlock(
inner_dim,
num_embeds_ada_norm=num_embeds_ada_norm,
),
"block": BasicTransformerBlock(
"block": UTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
@@ -222,6 +431,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
pre_layer_norm=pre_layer_norm,
norm_elementwise_affine=norm_elementwise_affine,
),
}
@@ -319,7 +529,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
hidden_states = self.transformer_mid_block(hidden_states)
# Out ("upsample") blocks
for out_block in self.transformer_in_blocks:
for out_block in self.transformer_out_blocks:
hidden_states = out_block["skip"](hidden_states, skips.pop())
hidden_states = out_block["block"](
hidden_states,
@@ -349,7 +559,8 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
output = F.log_softmax(logits.double(), dim=1).float()
elif self.is_input_patches:
# TODO: cleanup!
conditioning = self.transformer_blocks[0].norm1.emb(
# Use self.transformer_in_blocks for now??
conditioning = self.transformer_in_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
@@ -426,6 +637,7 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
pre_layer_norm: bool = False,
norm_elementwise_affine: bool = True,
):
super().__init__()
@@ -496,6 +708,7 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
pre_layer_norm=pre_layer_norm,
norm_elementwise_affine=norm_elementwise_affine,
)
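To recap the wiring changed in this file: UTransformer2DModel arranges its UTransformerBlocks in a U-ViT-style "U", and the one-line fix above makes the second loop walk transformer_out_blocks instead of re-running the in blocks. A schematic sketch of that traversal (not code from this commit), with identity modules and a toy skip module standing in for the real blocks:

import torch
from torch import nn

class ToySkip(nn.Module):
    # Mirrors SkipBlock: fuse the long skip with a linear layer over the concatenation, then LayerNorm.
    def __init__(self, dim: int):
        super().__init__()
        self.skip_linear = nn.Linear(2 * dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, skip):
        return self.norm(self.skip_linear(torch.cat([x, skip], dim=-1)))

def u_traversal(hidden_states, in_blocks, mid_block, out_blocks):
    skips = []
    for block in in_blocks:                      # "down" half: stash each output
        hidden_states = block(hidden_states)
        skips.append(hidden_states)
    hidden_states = mid_block(hidden_states)
    for out in out_blocks:                       # "up" half: fuse the popped skip, then run the block
        hidden_states = out["skip"](hidden_states, skips.pop())
        hidden_states = out["block"](hidden_states)
    return hidden_states

dim = 32
out_blocks = [{"skip": ToySkip(dim), "block": nn.Identity()} for _ in range(2)]
x = u_traversal(torch.randn(1, 4, dim), [nn.Identity(), nn.Identity()], nn.Identity(), out_blocks)
print(x.shape)   # torch.Size([1, 4, 32])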

View File

@@ -183,13 +183,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def _infer_mode(self, prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents):
r"""Infer the mode from the inputs to `__call__`."""
prompt_available = (prompt is not None) or (prompt_embeds is not None)
image_available = image is not None
input_available = prompt_available or image_available
prompt_latents_available = prompt_latents is not None
vae_latents_available = vae_latents is not None
clip_latents_available = clip_latents is not None
@@ -214,14 +214,14 @@ class UniDiffuserPipeline(DiffusionPipeline):
else:
# No inputs or latents available
mode = "img"
# Give warnings for ambiguous cases
if self.mode is None and prompt_available and image_available:
logger.warning(
f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually,"
f" defaulting to mode '{mode}'."
)
if self.mode is None and not input_available:
if vae_latents_available != clip_latents_available:
# Exactly one of vae_latents and clip_latents is supplied
@@ -234,23 +234,23 @@ class UniDiffuserPipeline(DiffusionPipeline):
logger.warning(
f"No inputs or latents have been supplied, and mode has not been manually set,"
f" defaulting to mode '{mode}'."
)
)
return mode
# Functions to manually set the mode
def set_text_mode(self):
self.mode = "text"
def set_img_mode(self):
self.mode = "img"
def set_text_to_image_mode(self):
self.mode = "text2img"
def set_image_to_text_mode(self):
self.mode = "img2text"
def set_joint_mode(self):
self.mode = "joint"
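The mode helpers above give callers explicit control when inferring the mode from the supplied prompt/image/latents would be ambiguous. A usage sketch (not part of the diff), assuming pipe is an already-loaded UniDiffuserPipeline:

pipe.set_joint_mode()            # mode == "joint": sample an (image, text) pair
pipe.set_text_to_image_mode()    # mode == "text2img": condition on `prompt`
pipe.set_image_to_text_mode()    # mode == "img2text": condition on `image`
pipe.set_text_mode()             # mode == "text": unconditional text generation
pipe.set_img_mode()              # mode == "img": unconditional image generation
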
@@ -634,7 +634,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_hidden_size, text_dim], dim=1)
img_vae = einops.rearrange(
img_vae, "B (C H W) -> B C H W", C=self.image_encoder_hidden_size, H=latent_height, W=latent_width
img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width
)
img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size)
text = einops.rearrange(text, "B (L D) -> B L D", L=self.text_encoder_seq_len, D=self.text_encoder_hidden_size)
@@ -681,10 +681,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
return x_out
# Classifier-free guidance
# TODO: need to replace this with the appropriate generator logic and randn_tensor
img_vae_T = torch.randn_like(img_vae, device=device)
img_clip_T = torch.randn_like(img_clip, device=device)
text_T = torch.randn_like(prompt_embeds, device=device)
img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
t_img_uncond = torch.ones_like(t) * timesteps
t_text_uncond = torch.ones_like(t) * timesteps
@@ -711,8 +710,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
return img_out
# Classifier-free guidance
# TODO: need to replace this with the appropriate generator logic and randn_tensor
text_T = torch.randn_like(prompt_embeds)
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
t_text_uncond = torch.ones_like(t) * timesteps
img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
@@ -732,9 +730,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
return text_out
# Classifier-free guidance
# TODO: need to replace this with the appropriate generator logic and randn_tensor
img_vae_T = torch.randn_like(img_vae)
img_clip_T = torch.randn_like(img_clip)
img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
t_img_uncond = torch.ones_like(t) * timesteps
img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
@@ -761,8 +758,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
img_out = self._combine(img_vae_out, img_clip_out)
return img_out
# Temporarily copied from StableDiffusionPipeline.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -785,17 +781,18 @@ class UniDiffuserPipeline(DiffusionPipeline):
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if self.mode == "text2img":
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -954,17 +951,19 @@ class UniDiffuserPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds,
)
else:
# 3.2. Prepare text image latent variables, if necessary
# 3.2. Prepare text latent variables, if input not available
prompt_embeds = self.prepare_text_latents(
batch_size,
self.text_encoder_seq_len,
self.text_encoder_hidden_size,
torch.float32, # Placeholder, need to determine correct thing to do for dtype
torch.float32, # TODO: Placeholder, need to determine correct thing to do for dtype
device,
generator,
prompt_latents,
)
# print(f"Prompt embeds shape: {prompt_embeds.shape}")
# 4. Encode image, if available; otherwise prepare image latents
if mode in ["img2text"]:
# 4.1. Encode images, if available
@@ -993,7 +992,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
generator,
)
else:
# 4.2. Prepare image latent variables, if necessary
# 4.2. Prepare image latent variables, if input not available
# Prepare image VAE latents
image_vae_latents = self.prepare_image_vae_latents(
batch_size * num_images_per_prompt,
@@ -1005,6 +1004,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
generator,
vae_latents,
)
# print(f"Image vae latent shape: {image_vae_latents.shape}")
# Prepare image CLIP latents
image_clip_latents = self.prepare_image_clip_latents(
@@ -1015,6 +1015,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
generator,
clip_latents,
)
# print(f"Image clip latent shape: {image_clip_latents.shape}")
# 5. Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1028,6 +1029,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
elif mode in ["img2text", "text"]:
latents = prompt_embeds
# print(f"Latents shape: {latents.shape}")
# 7. Check that shapes of latents and image match the UNet channels.
# TODO
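The noise changes above swap torch.randn_like for randn_tensor so the classifier-free-guidance noise respects the pipeline's generator and is reproducible across runs; torch.randn_like has no generator argument and draws from the global RNG. A minimal sketch of the difference, assuming the diffusers.utils import path used around this release (newer releases also expose it from diffusers.utils.torch_utils):

import torch
from diffusers.utils import randn_tensor

prompt_embeds = torch.zeros(1, 77, 32)
generator = torch.Generator("cpu").manual_seed(0)

# Seeded: identical whenever the same generator state is passed in.
text_T = randn_tensor(
    prompt_embeds.shape, generator=generator, device=torch.device("cpu"), dtype=prompt_embeds.dtype
)
# Unseeded: depends on (and advances) the global RNG state instead.
text_T_old = torch.randn_like(prompt_embeds)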

View File

@@ -1,6 +1,10 @@
import random
import unittest
import numpy as np
import pytest
import torch
from PIL import Image
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
@@ -8,48 +12,55 @@ from transformers import (
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModel,
GPT2Config,
GPT2LMHeadModel,
GPT2Tokenizer,
)
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from diffusers.utils import slow
from diffusers.utils import floats_tensor, slow
from diffusers.utils.testing_utils import require_torch_gpu
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin
class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = UniDiffuserPipeline
params = None # TODO
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
unet = UniDiffuserModel(
sample_size=16,
num_layers=2,
patch_size=4,
attention_head_dim=8,
text_dim=32,
clip_img_dim=32,
num_attention_heads=2,
attention_head_dim=8,
in_channels=4,
out_channels=8,
attention_bias=True,
activation_fn="gelu-approximate",
out_channels=4,
num_layers=2,
dropout=0.0,
norm_num_groups=32,
attention_bias=False,
sample_size=8,
patch_size=2,
activation_fn="gelu",
num_embeds_ada_norm=1000,
norm_type="ada_norm_zero",
norm_type="layer_norm",
pre_layer_norm=False,
norm_elementwise_affine=False,
text_dim=32, # TODO: needs to line up with CLIPTextConfig
clip_img_dim=32, # TODO: needs to line up with CLIPVisionConfig
)
scheduler = DDIMScheduler()
scheduler = DPMSolverMultistepScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
)
torch.manual_seed(0)
vae = AutoencoderKL(
@@ -77,29 +88,51 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
# TODO: get appropriate testing version for these
text_decoder_tokenizer = GPT2Tokenizer()
text_decoder_model_config = GPT2Config()
text_decoder_model = GPT2LMHeadModel(text_decoder_model_config)
text_decoder = UniDiffuserTextDecoder(
text_decoder_tokenizer,
text_decoder_model,
prefix_length=77, # TODO: fix
image_encoder_config = CLIPVisionConfig(
image_size=32,
patch_size=2,
num_channels=3,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
initializer_range=0.02,
)
image_encoder = CLIPVisionModel(image_encoder_config)
# From the Stable Diffusion Image Variation pipeline tests
image_processor = CLIPImageProcessor(crop_size=32, size=32)
# image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig()
image_encoder = CLIPVisionModel(image_encoder_config)
# TODO: does this actually work?
image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
# From https://huggingface.co/hf-internal-testing/tiny-random-GPT2Model/blob/main/config.json
text_decoder = UniDiffuserTextDecoder(
prefix_length=77,
prefix_hidden_dim=32,
n_positions=1024,
n_embd=32,
n_layer=5,
n_head=4,
n_inner=37,
activation_function="gelu",
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
)
text_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
components = {
"vae": vae,
"text_encoder": text_encoder,
"text_decoder": text_decoder,
"image_encoder": image_encoder,
"tokenizer": tokenizer,
"image_processor": image_processor,
"clip_tokenizer": tokenizer,
"text_decoder": text_decoder,
"text_tokenizer": text_tokenizer,
"unet": unet,
"scheduler": scheduler,
}
@@ -107,10 +140,136 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
return components
def get_dummy_inputs(self, device, seed=0):
pass
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
image = Image.fromarray(np.uint8(image)).convert("RGB")
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "an elephant under the sea",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
}
return inputs
def test_unidiffuser_default_case(self):
pass
@pytest.mark.xfail(reason="not finished debugging")
def test_unidiffuser_default_joint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'joint'
unidiffuser_pipe.set_joint_mode()
assert unidiffuser_pipe.mode == "joint"
inputs = self.get_dummy_inputs(device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
image = unidiffuser_pipe(**inputs).images
text = unidiffuser_pipe(**inputs).text
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.6646, 0.5723, 0.6812, 0.5742, 0.3872, 0.5137, 0.6206, 0.5986, 0.4983])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
# TODO: need to figure out correct text output
print(text)
@pytest.mark.xfail(reason="haven't begun debugging")
def test_unidiffuser_default_text2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text2img'
unidiffuser_pipe.set_text_to_image_mode()
assert unidiffuser_pipe.mode == "text2img"
inputs = self.get_dummy_inputs(device)
# Delete image for text-conditioned image generation
del inputs["image"]
image = unidiffuser_pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.6641, 0.5718, 0.6807, 0.5747, 0.3870, 0.5132, 0.6206, 0.5986, 0.4980])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@pytest.mark.xfail(reason="haven't begun debugging")
def test_unidiffuser_default_img2text(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img2text'
unidiffuser_pipe.set_image_to_text_mode()
assert unidiffuser_pipe.mode == "img2text"
inputs = self.get_dummy_inputs(device)
# Delete text for image-conditioned text generation
del inputs["prompt"]
text = unidiffuser_pipe(**inputs).text
# TODO: need to figure out correct text output
print(text)
@pytest.mark.xfail(reason="haven't begun debugging")
def test_unidiffuser_default_text(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text'
unidiffuser_pipe.set_text_mode()
assert unidiffuser_pipe.mode == "text"
inputs = self.get_dummy_inputs(device)
# Delete prompt and image for unconditional ("marginal") text generation.
del inputs["prompt"]
del inputs["image"]
text = unidiffuser_pipe(**inputs).text
# TODO: need to figure out correct text output
print(text)
@pytest.mark.xfail(reason="haven't begun debugging")
def test_unidiffuser_default_image(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img'
unidiffuser_pipe.set_image_mode()
assert unidiffuser_pipe.mode == "img"
inputs = self.get_dummy_inputs(device)
# Delete prompt and image for unconditional ("marginal") text generation.
del inputs["prompt"]
del inputs["image"]
image = unidiffuser_pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
# TODO: get expected slice of image output
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.6641, 0.5723, 0.6812, 0.5742, 0.3867, 0.5132, 0.6206, 0.5986, 0.4983])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow