mirror of https://github.com/huggingface/diffusers.git
Add UniDiffuser classes to __init__ files, modify transformer block to support pre- and post-LN, add fast default tests, fix some bugs.
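For orientation, this change makes the new classes importable from the top-level diffusers namespace. The following is a minimal usage sketch based on the mode-setting API exercised by the fast tests added below; the checkpoint path is a placeholder and is not referenced anywhere in this commit:

import torch

from diffusers import UniDiffuserPipeline

# Placeholder path; no public UniDiffuser checkpoint is named in this commit.
pipe = UniDiffuserPipeline.from_pretrained("path/to/unidiffuser-checkpoint")
generator = torch.Generator("cpu").manual_seed(0)

# Text-to-image: condition on a prompt only.
pipe.set_text_to_image_mode()
image = pipe(
    prompt="an elephant under the sea",
    num_inference_steps=20,
    guidance_scale=6.0,
    generator=generator,
).images[0]

# Joint mode: sample an (image, text) pair with no conditioning inputs.
pipe.set_joint_mode()
sample = pipe(num_inference_steps=20, guidance_scale=6.0, generator=generator)
image, text = sample.images[0], sample.text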
@@ -115,6 +115,7 @@ else:
AltDiffusionPipeline,
AudioLDMPipeline,
CycleDiffusionPipeline,
ImageTextPipelineOutput,
LDMTextToImagePipeline,
PaintByExamplePipeline,
SemanticStableDiffusionPipeline,
@@ -139,6 +140,9 @@ else:
TextToVideoSDPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,

@@ -70,6 +70,7 @@ else:
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .text_to_video_synthesis import TextToVideoSDPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,

@@ -3,7 +3,7 @@ from typing import Optional
import numpy as np
import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.modeling_utils import ModuleUtilsMixin

from ...configuration_utils import ConfigMixin, register_to_config
@@ -36,7 +36,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
|
||||
Parameters:
|
||||
prefix_length (`int`):
|
||||
Max number of prefix tokens that will be supplied to the model.
|
||||
hidden_dim (`int`, *optional*):
|
||||
prefix_hidden_dim (`int`, *optional*):
|
||||
Hidden dim of the MLP if we encode the prefix.
|
||||
TODO: add GPT2 config args
|
||||
"""
|
||||
@@ -44,8 +44,12 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
|
||||
|
||||
self.prefix_length = prefix_length
|
||||
self.prefix_hidden_dim = prefix_hidden_dim
|
||||
self.encode_prefix = nn.Linear(n_embd, self.hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity()
|
||||
self.decode_prefix = nn.Linear(self.hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
|
||||
self.encode_prefix = (
|
||||
nn.Linear(n_embd, self.prefix_hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity()
|
||||
)
|
||||
self.decode_prefix = (
|
||||
nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
|
||||
)
|
||||
|
||||
gpt_config = GPT2Config(
|
||||
n_positions=n_positions,
|
||||
@@ -105,8 +109,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
|
||||
Args:
|
||||
tokenizer (`GPT2Tokenizer`):
|
||||
Tokenizer of class
|
||||
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer)
|
||||
for tokenizing input to the text decoder model.
|
||||
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for
|
||||
tokenizing input to the text decoder model.
|
||||
features (`torch.Tensor` of shape `(B, L, D)`):
|
||||
Text embedding features to generate captions from.
|
||||
device:
|
||||
|
||||
@@ -6,21 +6,20 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models import ModelMixin
|
||||
from ...models.attention import AdaLayerNorm, BasicTransformerBlock
|
||||
from ...models.attention import AdaLayerNorm, FeedForward
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.embeddings import ImagePositionalEmbeddings, PatchEmbed, TimestepEmbedding, Timesteps
|
||||
from ...models.transformer_2d import Transformer2DModelOutput
|
||||
from ...utils import deprecate
|
||||
|
||||
|
||||
class SkipBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_embeds_ada_norm: Optional[int] = None):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.skip_linear = nn.Linear(2 * dim, dim)
|
||||
|
||||
# Use AdaLayerNorm for now, maybe support using other forms of LayerNorm?
|
||||
# Original code uses torch.nn.LayerNorm
|
||||
self.norm = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
# Use torch.nn.LayerNorm for now, following the original code
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, x, skip):
|
||||
x = self.skip_linear(torch.cat([x, skip], dim=-1))
|
||||
@@ -29,8 +28,216 @@ class SkipBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Modified to support both pre-LayerNorm and post-LayerNorm configurations
|
||||
# Don't support AdaLayerNormZero for now
|
||||
# Modified from diffusers.models.attention.BasicTransformerBlock
|
||||
class UTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
pre_layer_norm (`bool`, *optional*, defaults to `True`):
Whether to apply LayerNorm before each attention and feed-forward block (pre-LN) or to their outputs (post-LN).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
pre_layer_norm: bool = True,
|
||||
final_dropout: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
|
||||
self.pre_layer_norm = pre_layer_norm
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
# 1. Self-Attn
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
self.attn2 = None
|
||||
|
||||
if self.use_ada_layer_norm:
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
# elif self.use_ada_layer_norm_zero:
|
||||
# self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
self.norm2 = (
|
||||
AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
if self.use_ada_layer_norm
|
||||
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
)
|
||||
else:
|
||||
self.norm2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
timestep=None,
|
||||
cross_attention_kwargs=None,
|
||||
class_labels=None,
|
||||
):
|
||||
# Pre-LayerNorm
|
||||
if self.pre_layer_norm:
|
||||
if self.use_ada_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
# elif self.use_ada_layer_norm_zero:
|
||||
# norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
# hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
# )
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
else:
|
||||
norm_hidden_states = hidden_states
|
||||
|
||||
# 1. Self-Attention
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if not self.pre_layer_norm:
|
||||
# Post-LayerNorm
|
||||
if self.use_ada_layer_norm:
|
||||
attn_output = self.norm1(attn_output, timestep)
|
||||
# elif self.use_ada_layer_norm_zero:
|
||||
# attn_output, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
# attn_output, timestep, class_labels, hidden_dtype=attn_output.dtype
|
||||
# )
|
||||
else:
|
||||
attn_output = self.norm1(attn_output)
|
||||
# else:
|
||||
# # Pre-LayerNorm post-processing
|
||||
# if self.use_ada_layer_norm_zero:
|
||||
# attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
if self.attn2 is not None:
|
||||
# Pre-LayerNorm
|
||||
if self.pre_layer_norm:
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = hidden_states
|
||||
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
|
||||
# prepare attention mask here
|
||||
|
||||
# 2. Cross-Attention
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
# Post-LayerNorm
|
||||
if not self.pre_layer_norm:
|
||||
attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
# Pre-LayerNorm
|
||||
if self.pre_layer_norm:
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
# if self.use_ada_layer_norm_zero:
|
||||
# norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
else:
|
||||
norm_hidden_states = hidden_states
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if not self.pre_layer_norm:
|
||||
# Post-LayerNorm
|
||||
ff_output = self.norm3(ff_output)
|
||||
|
||||
# if self.use_ada_layer_norm_zero:
|
||||
# ff_output = ff_output * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
# else:
|
||||
# # Pre-LayerNorm post-processing
|
||||
# if self.use_ada_layer_norm_zero:
|
||||
# ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
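As a minimal illustration of what `pre_layer_norm` toggles in the forward pass above, the same residual sub-layer can be written two ways; this standalone sketch uses a generic sublayer standing in for the attention and feed-forward modules:

import torch
from torch import nn


def residual_sublayer(x, sublayer, norm, pre_layer_norm=True):
    if pre_layer_norm:
        # Pre-LN: normalize the input, apply the sublayer, then add the residual.
        return sublayer(norm(x)) + x
    # Post-LN (the ordering used above): apply the sublayer, normalize its output,
    # then add the residual.
    return norm(sublayer(x)) + x


dim = 32
ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
x = torch.randn(2, 16, dim)
assert residual_sublayer(x, ff, nn.LayerNorm(dim), pre_layer_norm=False).shape == x.shape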
|
||||
|
||||
|
||||
# Modified from diffusers.models.transformer_2d.Transformer2DModel
|
||||
# Modify the transformer block structure to be U-Net like following U-ViT
|
||||
# Only supports patch-style input currently
|
||||
# https://github.com/baofff/U-ViT
|
||||
class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
@@ -92,6 +299,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
pre_layer_norm: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -167,7 +375,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
# a "U"-shaped fashion (e.g. first in_block to last out_block, etc.).
|
||||
self.transformer_in_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
UTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
@@ -179,13 +387,14 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
pre_layer_norm=pre_layer_norm,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
)
|
||||
for d in range(num_layers // 2)
|
||||
]
|
||||
)
|
||||
|
||||
self.transformer_mid_block = BasicTransformerBlock(
|
||||
self.transformer_mid_block = UTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
@@ -197,6 +406,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
pre_layer_norm=pre_layer_norm,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
)
|
||||
|
||||
@@ -208,9 +418,8 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
{
|
||||
"skip": SkipBlock(
|
||||
inner_dim,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
),
|
||||
"block": BasicTransformerBlock(
|
||||
"block": UTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
@@ -222,6 +431,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
pre_layer_norm=pre_layer_norm,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
),
|
||||
}
|
||||
@@ -319,7 +529,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.transformer_mid_block(hidden_states)
|
||||
|
||||
# Out ("upsample") blocks
|
||||
for out_block in self.transformer_in_blocks:
|
||||
for out_block in self.transformer_out_blocks:
|
||||
hidden_states = out_block["skip"](hidden_states, skips.pop())
|
||||
hidden_states = out_block["block"](
|
||||
hidden_states,
|
||||
@@ -349,7 +559,8 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
elif self.is_input_patches:
|
||||
# TODO: cleanup!
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
# Use self.transformer_in_blocks for now??
|
||||
conditioning = self.transformer_in_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
@@ -426,6 +637,7 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
pre_layer_norm: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -496,6 +708,7 @@ class UniDiffuserModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
pre_layer_norm=pre_layer_norm,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
)
|
||||
|
||||
|
||||
@@ -183,13 +183,13 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
|
||||
def _infer_mode(self, prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents):
|
||||
r"""Infer the mode from the inputs to `__call__`."""
|
||||
prompt_available = (prompt is not None) or (prompt_embeds is not None)
|
||||
image_available = image is not None
|
||||
input_available = prompt_available or image_available
|
||||
|
||||
|
||||
prompt_latents_available = prompt_latents is not None
|
||||
vae_latents_available = vae_latents is not None
|
||||
clip_latents_available = clip_latents is not None
|
||||
@@ -214,14 +214,14 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
else:
|
||||
# No inputs or latents available
|
||||
mode = "img"
|
||||
|
||||
|
||||
# Give warnings for ambiguous cases
|
||||
if self.mode is None and prompt_available and image_available:
|
||||
logger.warning(
|
||||
f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually,"
|
||||
f" defaulting to mode '{mode}'."
|
||||
)
|
||||
|
||||
|
||||
if self.mode is None and not input_available:
|
||||
if vae_latents_available != clip_latents_available:
|
||||
# Exactly one of vae_latents and clip_latents is supplied
|
||||
@@ -234,23 +234,23 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
logger.warning(
|
||||
f"No inputs or latents have been supplied, and mode has not been manually set,"
|
||||
f" defaulting to mode '{mode}'."
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
return mode
|
||||
|
||||
|
||||
# Functions to manually set the mode
|
||||
def set_text_mode(self):
|
||||
self.mode = "text"
|
||||
|
||||
|
||||
def set_img_mode(self):
|
||||
self.mode = "img"
|
||||
|
||||
|
||||
def set_text_to_image_mode(self):
|
||||
self.mode = "text2img"
|
||||
|
||||
|
||||
def set_image_to_text_mode(self):
|
||||
self.mode = "img2text"
|
||||
|
||||
|
||||
def set_joint_mode(self):
|
||||
self.mode = "joint"
|
||||
|
||||
@@ -634,7 +634,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_hidden_size, text_dim], dim=1)
|
||||
|
||||
img_vae = einops.rearrange(
|
||||
img_vae, "B (C H W) -> B C H W", C=self.image_encoder_hidden_size, H=latent_height, W=latent_width
|
||||
img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width
|
||||
)
|
||||
img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size)
|
||||
text = einops.rearrange(text, "B (L D) -> B L D", L=self.text_encoder_seq_len, D=self.text_encoder_hidden_size)
|
||||
@@ -681,10 +681,9 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
return x_out
|
||||
|
||||
# Classifier-free guidance
|
||||
# TODO: need to replace this with the appropriate generator logic and randn_tensor
|
||||
img_vae_T = torch.randn_like(img_vae, device=device)
|
||||
img_clip_T = torch.randn_like(img_clip, device=device)
|
||||
text_T = torch.randn_like(prompt_embeds, device=device)
|
||||
img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
|
||||
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
|
||||
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
|
||||
t_img_uncond = torch.ones_like(t) * timesteps
|
||||
t_text_uncond = torch.ones_like(t) * timesteps
|
||||
|
||||
@@ -711,8 +710,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
return img_out
|
||||
|
||||
# Classifier-free guidance
|
||||
# TODO: need to replace this with the appropriate generator logic and randn_tensor
|
||||
text_T = torch.randn_like(prompt_embeds)
|
||||
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
|
||||
t_text_uncond = torch.ones_like(t) * timesteps
|
||||
|
||||
img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
|
||||
@@ -732,9 +730,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
return text_out
|
||||
|
||||
# Classifier-free guidance
|
||||
# TODO: need to replace this with the appropriate generator logic and randn_tensor
|
||||
img_vae_T = torch.randn_like(img_vae)
|
||||
img_clip_T = torch.randn_like(img_clip)
|
||||
img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
|
||||
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
|
||||
t_img_uncond = torch.ones_like(t) * timesteps
|
||||
|
||||
img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
|
||||
@@ -761,8 +758,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
img_out = self._combine(img_vae_out, img_clip_out)
|
||||
return img_out
|
||||
|
||||
# Temporarily copied from StableDiffusionPipeline.
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -785,17 +781,18 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
if self.mode == "text2img":
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -954,17 +951,19 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
else:
|
||||
# 3.2. Prepare text image latent variables, if necessary
|
||||
# 3.2. Prepare text latent variables, if input not available
|
||||
prompt_embeds = self.prepare_text_latents(
|
||||
batch_size,
|
||||
self.text_encoder_seq_len,
|
||||
self.text_encoder_hidden_size,
|
||||
torch.float32, # Placeholder, need to determine correct thing to do for dtype
|
||||
torch.float32, # TODO: Placeholder, need to determine correct thing to do for dtype
|
||||
device,
|
||||
generator,
|
||||
prompt_latents,
|
||||
)
|
||||
|
||||
# print(f"Prompt embeds shape: {prompt_embeds.shape}")
|
||||
|
||||
# 4. Encode image, if available; otherwise prepare image latents
|
||||
if mode in ["img2text"]:
|
||||
# 4.1. Encode images, if available
|
||||
@@ -993,7 +992,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
generator,
|
||||
)
|
||||
else:
|
||||
# 4.2. Prepare image latent variables, if necessary
|
||||
# 4.2. Prepare image latent variables, if input not available
|
||||
# Prepare image VAE latents
|
||||
image_vae_latents = self.prepare_image_vae_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
@@ -1005,6 +1004,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
generator,
|
||||
vae_latents,
|
||||
)
|
||||
# print(f"Image vae latent shape: {image_vae_latents.shape}")
|
||||
|
||||
# Prepare image CLIP latents
|
||||
image_clip_latents = self.prepare_image_clip_latents(
|
||||
@@ -1015,6 +1015,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
generator,
|
||||
clip_latents,
|
||||
)
|
||||
# print(f"Image clip latent shape: {image_clip_latents.shape}")
|
||||
|
||||
# 5. Set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -1028,6 +1029,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
elif mode in ["img2text", "text"]:
|
||||
latents = prompt_embeds
|
||||
|
||||
# print(f"Latents shape: {latents.shape}")
|
||||
|
||||
# 7. Check that shapes of latents and image match the UNet channels.
|
||||
# TODO
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextConfig,
|
||||
@@ -8,48 +12,55 @@ from transformers import (
|
||||
CLIPTokenizer,
|
||||
CLIPVisionConfig,
|
||||
CLIPVisionModel,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
UniDiffuserModel,
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from diffusers.utils import slow
|
||||
from diffusers.utils import floats_tensor, slow
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
|
||||
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = UniDiffuserPipeline
|
||||
params = None # TODO
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UniDiffuserModel(
|
||||
sample_size=16,
|
||||
num_layers=2,
|
||||
patch_size=4,
|
||||
attention_head_dim=8,
|
||||
text_dim=32,
|
||||
clip_img_dim=32,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=8,
|
||||
in_channels=4,
|
||||
out_channels=8,
|
||||
attention_bias=True,
|
||||
activation_fn="gelu-approximate",
|
||||
out_channels=4,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
norm_num_groups=32,
|
||||
attention_bias=False,
|
||||
sample_size=8,
|
||||
patch_size=2,
|
||||
activation_fn="gelu",
|
||||
num_embeds_ada_norm=1000,
|
||||
norm_type="ada_norm_zero",
|
||||
norm_type="layer_norm",
|
||||
pre_layer_norm=False,
|
||||
norm_elementwise_affine=False,
|
||||
text_dim=32, # TODO: needs to line up with CLIPTextConfig
|
||||
clip_img_dim=32, # TODO: needs to line up with CLIPVisionConfig
|
||||
)
|
||||
|
||||
scheduler = DDIMScheduler()
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
@@ -77,29 +88,51 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
torch.manual_seed(0)
|
||||
# TODO: get appropriate testing version for these
|
||||
text_decoder_tokenizer = GPT2Tokenizer()
|
||||
text_decoder_model_config = GPT2Config()
|
||||
text_decoder_model = GPT2LMHeadModel(text_decoder_model_config)
|
||||
text_decoder = UniDiffuserTextDecoder(
|
||||
text_decoder_tokenizer,
|
||||
text_decoder_model,
|
||||
prefix_length=77, # TODO: fix
|
||||
image_encoder_config = CLIPVisionConfig(
|
||||
image_size=32,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
initializer_range=0.02,
|
||||
)
|
||||
image_encoder = CLIPVisionModel(image_encoder_config)
|
||||
# From the Stable Diffusion Image Variation pipeline tests
|
||||
image_processor = CLIPImageProcessor(crop_size=32, size=32)
|
||||
# image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
torch.manual_seed(0)
|
||||
image_encoder_config = CLIPVisionConfig()
|
||||
image_encoder = CLIPVisionModel(image_encoder_config)
|
||||
# TODO: does this actually work?
|
||||
image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
# From https://huggingface.co/hf-internal-testing/tiny-random-GPT2Model/blob/main/config.json
|
||||
text_decoder = UniDiffuserTextDecoder(
|
||||
prefix_length=77,
|
||||
prefix_hidden_dim=32,
|
||||
n_positions=1024,
|
||||
n_embd=32,
|
||||
n_layer=5,
|
||||
n_head=4,
|
||||
n_inner=37,
|
||||
activation_function="gelu",
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
)
|
||||
text_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
|
||||
|
||||
components = {
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"text_decoder": text_decoder,
|
||||
"image_encoder": image_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"image_processor": image_processor,
|
||||
"clip_tokenizer": tokenizer,
|
||||
"text_decoder": text_decoder,
|
||||
"text_tokenizer": text_tokenizer,
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
@@ -107,10 +140,136 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
pass
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "an elephant under the sea",
|
||||
"image": image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_unidiffuser_default_case(self):
|
||||
pass
|
||||
@pytest.mark.xfail(reason="not finished debugging")
|
||||
def test_unidiffuser_default_joint(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
unidiffuser_pipe = UniDiffuserPipeline(**components)
|
||||
unidiffuser_pipe = unidiffuser_pipe.to(device)
|
||||
unidiffuser_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Set mode to 'joint'
|
||||
unidiffuser_pipe.set_joint_mode()
|
||||
assert unidiffuser_pipe.mode == "joint"
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
# Delete prompt and image for joint inference.
|
||||
del inputs["prompt"]
|
||||
del inputs["image"]
|
||||
image = unidiffuser_pipe(**inputs).images
|
||||
text = unidiffuser_pipe(**inputs).text
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.6646, 0.5723, 0.6812, 0.5742, 0.3872, 0.5137, 0.6206, 0.5986, 0.4983])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
# TODO: need to figure out correct text output
|
||||
print(text)
|
||||
|
||||
@pytest.mark.xfail(reason="haven't begun debugging")
|
||||
def test_unidiffuser_default_text2img(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
unidiffuser_pipe = UniDiffuserPipeline(**components)
|
||||
unidiffuser_pipe = unidiffuser_pipe.to(device)
|
||||
unidiffuser_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Set mode to 'text2img'
|
||||
unidiffuser_pipe.set_text_to_image_mode()
|
||||
assert unidiffuser_pipe.mode == "text2img"
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
# Delete image for text-conditioned image generation
|
||||
del inputs["image"]
|
||||
image = unidiffuser_pipe(**inputs).images
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.6641, 0.5718, 0.6807, 0.5747, 0.3870, 0.5132, 0.6206, 0.5986, 0.4980])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
@pytest.mark.xfail(reason="haven't begun debugging")
|
||||
def test_unidiffuser_default_img2text(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
unidiffuser_pipe = UniDiffuserPipeline(**components)
|
||||
unidiffuser_pipe = unidiffuser_pipe.to(device)
|
||||
unidiffuser_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Set mode to 'img2text'
|
||||
unidiffuser_pipe.set_image_to_text_mode()
|
||||
assert unidiffuser_pipe.mode == "img2text"
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
# Delete text for image-conditioned text generation
|
||||
del inputs["prompt"]
|
||||
text = unidiffuser_pipe(**inputs).text
|
||||
|
||||
# TODO: need to figure out correct text output
|
||||
print(text)
|
||||
|
||||
@pytest.mark.xfail(reason="haven't begun debugging")
|
||||
def test_unidiffuser_default_text(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
unidiffuser_pipe = UniDiffuserPipeline(**components)
|
||||
unidiffuser_pipe = unidiffuser_pipe.to(device)
|
||||
unidiffuser_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Set mode to 'text'
|
||||
unidiffuser_pipe.set_text_mode()
|
||||
assert unidiffuser_pipe.mode == "text"
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
# Delete prompt and image for unconditional ("marginal") text generation.
|
||||
del inputs["prompt"]
|
||||
del inputs["image"]
|
||||
text = unidiffuser_pipe(**inputs).text
|
||||
|
||||
# TODO: need to figure out correct text output
|
||||
print(text)
|
||||
|
||||
@pytest.mark.xfail(reason="haven't begun debugging")
|
||||
def test_unidiffuser_default_image(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
unidiffuser_pipe = UniDiffuserPipeline(**components)
|
||||
unidiffuser_pipe = unidiffuser_pipe.to(device)
|
||||
unidiffuser_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Set mode to 'img'
|
||||
unidiffuser_pipe.set_image_mode()
|
||||
assert unidiffuser_pipe.mode == "img"
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
# Delete prompt and image for unconditional ("marginal") image generation.
|
||||
del inputs["prompt"]
|
||||
del inputs["image"]
|
||||
image = unidiffuser_pipe(**inputs).images
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
|
||||
# TODO: get expected slice of image output
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.6641, 0.5723, 0.6812, 0.5742, 0.3867, 0.5132, 0.6206, 0.5986, 0.4983])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
|
||||
@slow
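The fast tests above are built on tiny components and run on CPU. A sketch of selecting only this test class locally; the test file path is an assumption, as it is not shown in this view:

import pytest

# Assumed location of the UniDiffuserPipelineFastTests class added in this commit.
# xfail-marked cases are reported but do not fail the run while still being debugged.
pytest.main(["-q", "-k", "UniDiffuserPipelineFastTests", "tests/pipelines/unidiffuser/test_unidiffuser.py"])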