diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f8ac91c0eb..438c9d1958 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -115,6 +115,7 @@ else: AltDiffusionPipeline, AudioLDMPipeline, CycleDiffusionPipeline, + ImageTextPipelineOutput, LDMTextToImagePipeline, PaintByExamplePipeline, SemanticStableDiffusionPipeline, @@ -139,6 +140,9 @@ else: TextToVideoSDPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 421099a6d7..f4431a4356 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -70,6 +70,7 @@ else: from .stable_diffusion_safe import StableDiffusionPipelineSafe from .text_to_video_synthesis import TextToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline + from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py index 3782e9e5ff..097d14dcd9 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -3,7 +3,7 @@ from typing import Optional import numpy as np import torch from torch import nn -from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer +from transformers import GPT2Config, GPT2LMHeadModel from transformers.modeling_utils import ModuleUtilsMixin from ...configuration_utils import ConfigMixin, register_to_config @@ -36,7 +36,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): Parameters: prefix_length (`int`): Max number of prefix tokens that will be supplied to the model. - hidden_dim (`int`, *optional*): + prefix_hidden_dim (`int`, *optional*): Hidden dim of the MLP if we encode the prefix. TODO: add GPT2 config args """ @@ -44,8 +44,12 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): self.prefix_length = prefix_length self.prefix_hidden_dim = prefix_hidden_dim - self.encode_prefix = nn.Linear(n_embd, self.hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity() - self.decode_prefix = nn.Linear(self.hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity() + self.encode_prefix = ( + nn.Linear(n_embd, self.prefix_hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity() + ) + self.decode_prefix = ( + nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity() + ) gpt_config = GPT2Config( n_positions=n_positions, @@ -105,8 +109,8 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): Args: tokenizer (`GPT2Tokenizer`): Tokenizer of class - [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) - for tokenizing input to the text decoder model. + [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) for + tokenizing input to the text decoder model. features (`torch.Tensor` of shape `(B, L, D)`): Text embedding features to generate captions from. 
device: diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index 977c534ab4..f1147c56b4 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -6,21 +6,20 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin -from ...models.attention import AdaLayerNorm, BasicTransformerBlock +from ...models.attention import AdaLayerNorm, FeedForward +from ...models.attention_processor import Attention from ...models.embeddings import ImagePositionalEmbeddings, PatchEmbed, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModelOutput -from ...utils import deprecate class SkipBlock(nn.Module): - def __init__(self, dim: int, num_embeds_ada_norm: Optional[int] = None): + def __init__(self, dim: int): super().__init__() self.skip_linear = nn.Linear(2 * dim, dim) - # Use AdaLayerNorm for now, maybe support using other forms of LayerNorm? - # Original code uses torch.nn.LayerNorm - self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) + # Use torch.nn.LayerNorm for now, following the original code + self.norm = nn.LayerNorm(dim) def forward(self, x, skip): x = self.skip_linear(torch.cat([x, skip], dim=-1)) @@ -29,8 +28,216 @@ class SkipBlock(nn.Module): return x +# Modified to support both pre-LayerNorm and post-LayerNorm configurations +# Don't support AdaLayerNormZero for now +# Modified from diffusers.models.attention.BasicTransformerBlock +class UTransformerBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = True, + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + # self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + # elif self.use_ada_layer_norm_zero: + # self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + # elif self.use_ada_layer_norm_zero: + # norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + # hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + # ) + else: + norm_hidden_states = self.norm1(hidden_states) + else: + norm_hidden_states = hidden_states + + # 1. 
Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if not self.pre_layer_norm: + # Post-LayerNorm + if self.use_ada_layer_norm: + attn_output = self.norm1(attn_output, timestep) + # elif self.use_ada_layer_norm_zero: + # attn_output, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + # attn_output, timestep, class_labels, hidden_dtype=attn_output.dtype + # ) + else: + attn_output = self.norm1(hidden_states) + # else: + # # Pre-LayerNorm post-processing + # if self.use_ada_layer_norm_zero: + # attn_output = gate_msa.unsqueeze(1) * attn_output + + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + else: + norm_hidden_states = hidden_states + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output) + + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = self.norm3(hidden_states) + + # if self.use_ada_layer_norm_zero: + # norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + else: + norm_hidden_states = hidden_states + + ff_output = self.ff(norm_hidden_states) + + if not self.pre_layer_norm: + # Post-LayerNorm + ff_output = self.norm3(ff_output) + + # if self.use_ada_layer_norm_zero: + # ff_output = ff_output * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + # else: + # # Pre-LayerNorm post-processing + # if self.use_ada_layer_norm_zero: + # ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + # Modified from diffusers.models.transformer_2d.Transformer2DModel # Modify the transformer block structure to be U-Net like following U-ViT +# Only supports patch-style input currently # https://github.com/baofff/U-ViT class UTransformer2DModel(ModelMixin, ConfigMixin): """ @@ -92,6 +299,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): only_cross_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", + pre_layer_norm: bool = False, norm_elementwise_affine: bool = True, ): super().__init__() @@ -167,7 +375,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). 
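
The `pre_layer_norm` flag handled above only changes where the LayerNorm sits relative to each residual branch. As a minimal, self-contained illustration of the two orderings (plain PyTorch; `ResidualSublayer` is a hypothetical name, not a diffusers class):

import torch
from torch import nn

class ResidualSublayer(nn.Module):
    # Illustrative wrapper: runs a sublayer (attention or feed-forward) inside a
    # residual connection, normalizing either before the sublayer (pre-LN) or on
    # its output (post-LN), mirroring the `pre_layer_norm` switch above.
    def __init__(self, dim: int, sublayer: nn.Module, pre_layer_norm: bool = True):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer
        self.pre_layer_norm = pre_layer_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pre_layer_norm:
            # Pre-LN: normalize the input, run the sublayer, add the raw residual.
            return x + self.sublayer(self.norm(x))
        # Post-LN: run the sublayer on the raw input, normalize its output, then
        # add the residual (the ordering used by the original U-ViT code).
        return x + self.norm(self.sublayer(x))

x = torch.randn(2, 16, 32)
print(ResidualSublayer(32, nn.Linear(32, 32), pre_layer_norm=False)(x).shape)  # torch.Size([2, 16, 32])
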
self.transformer_in_blocks = nn.ModuleList( [ - BasicTransformerBlock( + UTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, @@ -179,13 +387,14 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, + pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, ) for d in range(num_layers // 2) ] ) - self.transformer_mid_block = BasicTransformerBlock( + self.transformer_mid_block = UTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, @@ -197,6 +406,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, + pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, ) @@ -208,9 +418,8 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): { "skip": SkipBlock( inner_dim, - num_embeds_ada_norm=num_embeds_ada_norm, ), - "block": BasicTransformerBlock( + "block": UTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, @@ -222,6 +431,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, + pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, ), } @@ -319,7 +529,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): hidden_states = self.transformer_mid_block(hidden_states) # Out ("upsample") blocks - for out_block in self.transformer_in_blocks: + for out_block in self.transformer_out_blocks: hidden_states = out_block["skip"](hidden_states, skips.pop()) hidden_states = out_block["block"]( hidden_states, @@ -349,7 +559,8 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): output = F.log_softmax(logits.double(), dim=1).float() elif self.is_input_patches: # TODO: cleanup! - conditioning = self.transformer_blocks[0].norm1.emb( + # Use self.transformer_in_blocks for now?? 
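
This in/mid/out arrangement is the U-ViT "U" shape: each in-block's output is pushed onto a stack, and each out-block first fuses the current hidden states with a popped skip through `SkipBlock` (concatenate along the feature dimension, project 2*dim back to dim, then LayerNorm). A minimal sketch of that control flow, with plain linear layers standing in for the transformer blocks (all names here are illustrative only):

import torch
from torch import nn

class TinySkipBlock(nn.Module):
    # Mirrors SkipBlock: fuse the current features with a saved skip connection.
    def __init__(self, dim: int):
        super().__init__()
        self.skip_linear = nn.Linear(2 * dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, skip):
        return self.norm(self.skip_linear(torch.cat([x, skip], dim=-1)))

class TinyUViT(nn.Module):
    # In-blocks push their outputs onto a stack; out-blocks pop them in reverse
    # order (first in-block pairs with the last out-block) and fuse before running.
    def __init__(self, dim: int = 32, depth: int = 4):
        super().__init__()
        self.in_blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth // 2))
        self.mid_block = nn.Linear(dim, dim)
        self.out_blocks = nn.ModuleList(
            nn.ModuleDict({"skip": TinySkipBlock(dim), "block": nn.Linear(dim, dim)})
            for _ in range(depth // 2)
        )

    def forward(self, x):
        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)
        x = self.mid_block(x)
        for blk in self.out_blocks:
            x = blk["skip"](x, skips.pop())
            x = blk["block"](x)
        return x

print(TinyUViT()(torch.randn(2, 16, 32)).shape)  # torch.Size([2, 16, 32])
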
+ conditioning = self.transformer_in_blocks[0].norm1.emb( timestep, class_labels, hidden_dtype=hidden_states.dtype ) shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) @@ -426,6 +637,7 @@ class UniDiffuserModel(ModelMixin, ConfigMixin): only_cross_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", + pre_layer_norm: bool = False, norm_elementwise_affine: bool = True, ): super().__init__() @@ -496,6 +708,7 @@ class UniDiffuserModel(ModelMixin, ConfigMixin): only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, + pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, ) diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py index 5e65824d94..1fb33e3bb8 100644 --- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -183,13 +183,13 @@ class UniDiffuserPipeline(DiffusionPipeline): if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs - + def _infer_mode(self, prompt, prompt_embeds, image, prompt_latents, vae_latents, clip_latents): r"""Infer the mode from the inputs to `__call__`.""" prompt_available = (prompt is not None) or (prompt_embeds is not None) image_available = image is not None input_available = prompt_available or image_available - + prompt_latents_available = prompt_latents is not None vae_latents_available = vae_latents is not None clip_latents_available = clip_latents is not None @@ -214,14 +214,14 @@ class UniDiffuserPipeline(DiffusionPipeline): else: # No inputs or latents available mode = "img" - + # Give warnings for ambiguous cases if self.mode is None and prompt_available and image_available: logger.warning( f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," f" defaulting to mode '{mode}'." ) - + if self.mode is None and not input_available: if vae_latents_available != clip_latents_available: # Exactly one of vae_latents and clip_latents is supplied @@ -234,23 +234,23 @@ class UniDiffuserPipeline(DiffusionPipeline): logger.warning( f"No inputs or latents have been supplied, and mode has not been manually set," f" defaulting to mode '{mode}'." 
- ) - + ) + return mode - + # Functions to manually set the mode def set_text_mode(self): self.mode = "text" - + def set_img_mode(self): self.mode = "img" - + def set_text_to_image_mode(self): self.mode = "text2img" - + def set_image_to_text_mode(self): self.mode = "img2text" - + def set_joint_mode(self): self.mode = "joint" @@ -634,7 +634,7 @@ class UniDiffuserPipeline(DiffusionPipeline): img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_hidden_size, text_dim], dim=1) img_vae = einops.rearrange( - img_vae, "B (C H W) -> B C H W", C=self.image_encoder_hidden_size, H=latent_height, W=latent_width + img_vae, "B (C H W) -> B C H W", C=self.num_channels_latents, H=latent_height, W=latent_width ) img_clip = einops.rearrange(img_clip, "B (L D) -> B L D", L=1, D=self.image_encoder_hidden_size) text = einops.rearrange(text, "B (L D) -> B L D", L=self.text_encoder_seq_len, D=self.text_encoder_hidden_size) @@ -681,10 +681,9 @@ class UniDiffuserPipeline(DiffusionPipeline): return x_out # Classifier-free guidance - # TODO: need to replace this with the appropriate generator logic and randn_tensor - img_vae_T = torch.randn_like(img_vae, device=device) - img_clip_T = torch.randn_like(img_clip, device=device) - text_T = torch.randn_like(prompt_embeds, device=device) + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) t_img_uncond = torch.ones_like(t) * timesteps t_text_uncond = torch.ones_like(t) * timesteps @@ -711,8 +710,7 @@ class UniDiffuserPipeline(DiffusionPipeline): return img_out # Classifier-free guidance - # TODO: need to replace this with the appropriate generator logic and randn_tensor - text_T = torch.randn_like(prompt_embeds) + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) t_text_uncond = torch.ones_like(t) * timesteps img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( @@ -732,9 +730,8 @@ class UniDiffuserPipeline(DiffusionPipeline): return text_out # Classifier-free guidance - # TODO: need to replace this with the appropriate generator logic and randn_tensor - img_vae_T = torch.randn_like(img_vae) - img_clip_T = torch.randn_like(img_clip) + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) t_img_uncond = torch.ones_like(t) * timesteps img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( @@ -761,8 +758,7 @@ class UniDiffuserPipeline(DiffusionPipeline): img_out = self._combine(img_vae_out, img_clip_out) return img_out - # Temporarily copied from StableDiffusionPipeline. - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, @@ -785,17 +781,18 @@ class UniDiffuserPipeline(DiffusionPipeline): f" {type(callback_steps)}." ) - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." 
- ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if self.mode == "text2img": + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -954,17 +951,19 @@ class UniDiffuserPipeline(DiffusionPipeline): negative_prompt_embeds=negative_prompt_embeds, ) else: - # 3.2. Prepare text image latent variables, if necessary + # 3.2. Prepare text latent variables, if input not available prompt_embeds = self.prepare_text_latents( batch_size, self.text_encoder_seq_len, self.text_encoder_hidden_size, - torch.float32, # Placeholder, need to determine correct thing to do for dtype + torch.float32, # TODO: Placeholder, need to determine correct thing to do for dtype device, generator, prompt_latents, ) + # print(f"Prompt embeds shape: {prompt_embeds.shape}") + # 4. Encode image, if available; otherwise prepare image latents if mode in ["img2text"]: # 4.1. Encode images, if available @@ -993,7 +992,7 @@ class UniDiffuserPipeline(DiffusionPipeline): generator, ) else: - # 4.2. Prepare image latent variables, if necessary + # 4.2. Prepare image latent variables, if input not available # Prepare image VAE latents image_vae_latents = self.prepare_image_vae_latents( batch_size * num_images_per_prompt, @@ -1005,6 +1004,7 @@ class UniDiffuserPipeline(DiffusionPipeline): generator, vae_latents, ) + # print(f"Image vae latent shape: {image_vae_latents.shape}") # Prepare image CLIP latents image_clip_latents = self.prepare_image_clip_latents( @@ -1015,6 +1015,7 @@ class UniDiffuserPipeline(DiffusionPipeline): generator, clip_latents, ) + # print(f"Image clip latent shape: {image_clip_latents.shape}") # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1028,6 +1029,8 @@ class UniDiffuserPipeline(DiffusionPipeline): elif mode in ["img2text", "text"]: latents = prompt_embeds + # print(f"Latents shape: {latents.shape}") + # 7. Check that shapes of latents and image match the UNet channels. 
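
The classifier-free guidance edits above replace `torch.randn_like` with diffusers' `randn_tensor`, so the unconditional-branch noise is drawn from the caller's `generator` and becomes reproducible. A small sanity check of that behaviour (assuming `randn_tensor` is importable from `diffusers.utils`, as this version of the library exposes it; later releases move it to `diffusers.utils.torch_utils`):

import torch
from diffusers.utils import randn_tensor

shape = (1, 4, 8, 8)
gen_a = torch.Generator("cpu").manual_seed(0)
gen_b = torch.Generator("cpu").manual_seed(0)

# Same seed -> identical noise, independent of torch's global RNG state.
noise_a = randn_tensor(shape, generator=gen_a, device=torch.device("cpu"), dtype=torch.float32)
noise_b = randn_tensor(shape, generator=gen_b, device=torch.device("cpu"), dtype=torch.float32)
assert torch.equal(noise_a, noise_b)

# torch.randn_like ignores the pipeline's generator and consumes the global RNG,
# which is what the removed TODO comments were flagging.
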
# TODO diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py index 9f353f3ff1..c7dbefbf23 100644 --- a/tests/pipelines/unidiffuser/test_unidiffuser.py +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -1,6 +1,10 @@ +import random import unittest +import numpy as np +import pytest import torch +from PIL import Image from transformers import ( CLIPImageProcessor, CLIPTextConfig, @@ -8,48 +12,55 @@ from transformers import ( CLIPTokenizer, CLIPVisionConfig, CLIPVisionModel, - GPT2Config, - GPT2LMHeadModel, GPT2Tokenizer, ) from diffusers import ( AutoencoderKL, - DDIMScheduler, + DPMSolverMultistepScheduler, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder, ) -from diffusers.utils import slow +from diffusers.utils import floats_tensor, slow from diffusers.utils.testing_utils import require_torch_gpu +from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ...test_pipelines_common import PipelineTesterMixin class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = UniDiffuserPipeline - params = None # TODO + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS def get_dummy_components(self): torch.manual_seed(0) unet = UniDiffuserModel( - sample_size=16, - num_layers=2, - patch_size=4, - attention_head_dim=8, + text_dim=32, + clip_img_dim=32, num_attention_heads=2, + attention_head_dim=8, in_channels=4, - out_channels=8, - attention_bias=True, - activation_fn="gelu-approximate", + out_channels=4, + num_layers=2, + dropout=0.0, + norm_num_groups=32, + attention_bias=False, + sample_size=8, + patch_size=2, + activation_fn="gelu", num_embeds_ada_norm=1000, - norm_type="ada_norm_zero", + norm_type="layer_norm", + pre_layer_norm=False, norm_elementwise_affine=False, - text_dim=32, # TODO: needs to line up with CLIPTextConfig - clip_img_dim=32, # TODO: needs to line up with CLIPVisionConfig ) - scheduler = DDIMScheduler() + scheduler = DPMSolverMultistepScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + ) torch.manual_seed(0) vae = AutoencoderKL( @@ -77,29 +88,51 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") torch.manual_seed(0) - # TODO: get appropriate testing version for these - text_decoder_tokenizer = GPT2Tokenizer() - text_decoder_model_config = GPT2Config() - text_decoder_model = GPT2LMHeadModel(text_decoder_model_config) - text_decoder = UniDiffuserTextDecoder( - text_decoder_tokenizer, - text_decoder_model, - prefix_length=77, # TODO: fix + image_encoder_config = CLIPVisionConfig( + image_size=32, + patch_size=2, + num_channels=3, + hidden_size=32, + projection_dim=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, ) + image_encoder = CLIPVisionModel(image_encoder_config) + # From the Stable Diffusion Image Variation pipeline tests + image_processor = CLIPImageProcessor(crop_size=32, size=32) + # image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip") torch.manual_seed(0) - image_encoder_config = CLIPVisionConfig() - image_encoder = CLIPVisionModel(image_encoder_config) - # TODO: does this actually work? 
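
As an aside on the dummy `CLIPImageProcessor(crop_size=32, size=32)` used above: it resizes, center-crops to 32x32, and normalizes, producing the pixel values presumably handed to `CLIPVisionModel` when the pipeline encodes an image for img2text. A quick sketch of that preprocessing step:

import numpy as np
from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor(crop_size=32, size=32)
dummy = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
pixel_values = processor(images=dummy, return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 32, 32])
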
- image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip") + # From https://huggingface.co/hf-internal-testing/tiny-random-GPT2Model/blob/main/config.json + text_decoder = UniDiffuserTextDecoder( + prefix_length=77, + prefix_hidden_dim=32, + n_positions=1024, + n_embd=32, + n_layer=5, + n_head=4, + n_inner=37, + activation_function="gelu", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + ) + text_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model") components = { "vae": vae, "text_encoder": text_encoder, - "text_decoder": text_decoder, "image_encoder": image_encoder, - "tokenizer": tokenizer, "image_processor": image_processor, + "clip_tokenizer": tokenizer, + "text_decoder": text_decoder, + "text_tokenizer": text_tokenizer, "unet": unet, "scheduler": scheduler, } @@ -107,10 +140,136 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): return components def get_dummy_inputs(self, device, seed=0): - pass + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB") + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "an elephant under the sea", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs - def test_unidiffuser_default_case(self): - pass + @pytest.mark.xfail(reason="not finished debugging") + def test_unidiffuser_default_joint(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + inputs = self.get_dummy_inputs(device) + # Delete prompt and image for joint inference. 
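
For reference, the arguments passed to `UniDiffuserTextDecoder` in `get_dummy_components` above mirror the `hf-internal-testing/tiny-random-GPT2Model` configuration; the decoder forwards them to `GPT2Config` when building its internal `GPT2LMHeadModel`, with `prefix_hidden_dim=32` adding the linear encode/decode prefix projections on top. The equivalent config is roughly:

from transformers import GPT2Config

# Equivalent config for the dummy text decoder's underlying GPT-2 LM head model.
tiny_gpt2_config = GPT2Config(
    n_positions=1024,
    n_embd=32,
    n_layer=5,
    n_head=4,
    n_inner=37,
    activation_function="gelu",
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
    layer_norm_epsilon=1e-5,
    initializer_range=0.02,
)
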
+ del inputs["prompt"] + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + text = unidiffuser_pipe(**inputs).text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.6646, 0.5723, 0.6812, 0.5742, 0.3872, 0.5137, 0.6206, 0.5986, 0.4983]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + # TODO: need to figure out correct text output + print(text) + + @pytest.mark.xfail(reason="haven't begun debugging") + def test_unidiffuser_default_text2img(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs(device) + # Delete image for text-conditioned image generation + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.6641, 0.5718, 0.6807, 0.5747, 0.3870, 0.5132, 0.6206, 0.5986, 0.4980]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + @pytest.mark.xfail(reason="haven't begun debugging") + def test_unidiffuser_default_img2text(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + text = unidiffuser_pipe(**inputs).text + + # TODO: need to figure out correct text output + print(text) + + @pytest.mark.xfail(reason="haven't begun debugging") + def test_unidiffuser_default_text(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text' + unidiffuser_pipe.set_text_mode() + assert unidiffuser_pipe.mode == "text" + + inputs = self.get_dummy_inputs(device) + # Delete prompt and image for unconditional ("marginal") text generation. + del inputs["prompt"] + del inputs["image"] + text = unidiffuser_pipe(**inputs).text + + # TODO: need to figure out correct text output + print(text) + + @pytest.mark.xfail(reason="haven't begun debugging") + def test_unidiffuser_default_image(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img' + unidiffuser_pipe.set_image_mode() + assert unidiffuser_pipe.mode == "img" + + inputs = self.get_dummy_inputs(device) + # Delete prompt and image for unconditional ("marginal") text generation. 
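
As a quick summary of the five modes exercised by these tests (which `__call__` inputs each mode consumes and which outputs the assertions read back), derived from `_infer_mode` and the `set_*_mode` helpers in the pipeline:

# Illustrative summary only, not part of the test suite.
UNIDIFFUSER_MODES = {
    "joint":    {"inputs": (),          "outputs": ("images", "text")},
    "text2img": {"inputs": ("prompt",), "outputs": ("images",)},
    "img2text": {"inputs": ("image",),  "outputs": ("text",)},
    "text":     {"inputs": (),          "outputs": ("text",)},
    "img":      {"inputs": (),          "outputs": ("images",)},
}
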
+ del inputs["prompt"] + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + # TODO: get expected slice of image output + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.6641, 0.5723, 0.6812, 0.5742, 0.3867, 0.5132, 0.6206, 0.5986, 0.4983]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @slow
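
Putting the pieces together, a hedged end-to-end sketch of driving the pipeline once real weights are available. The checkpoint path below is a placeholder (this diff does not ship one); the call signature, mode setters, and output fields follow `pipeline_unidiffuser.py` as added in this PR:

import torch
from diffusers import UniDiffuserPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder checkpoint path -- substitute converted UniDiffuser weights.
pipe = UniDiffuserPipeline.from_pretrained("path/to/unidiffuser-checkpoint")
pipe = pipe.to(device)
generator = torch.Generator(device).manual_seed(0)

# Joint generation: no prompt and no image; the output carries both fields.
pipe.set_joint_mode()
out = pipe(num_inference_steps=20, guidance_scale=7.0, generator=generator)
image, caption = out.images[0], out.text[0]

# Text-to-image: supply only a prompt.
pipe.set_text_to_image_mode()
image = pipe(prompt="an elephant under the sea", num_inference_steps=20, generator=generator).images[0]

# Image-to-text: supply only an image.
pipe.set_image_to_text_mode()
caption = pipe(image=image, num_inference_steps=20, generator=generator).text[0]
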