diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index f551c05297..982ba499da 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -234,6 +234,10 @@
       title: ConsistencyDecoderVAE
     - local: api/models/transformer2d
       title: Transformer2D
+    - local: api/models/pixart_transformer2d
+      title: PixArtTransformer2D
+    - local: api/models/dit_transformer2d
+      title: DiTTransformer2D
     - local: api/models/transformer_temporal
       title: Transformer Temporal
     - local: api/models/prior_transformer
diff --git a/docs/source/en/api/models/dit_transformer2d.md b/docs/source/en/api/models/dit_transformer2d.md
new file mode 100644
index 0000000000..1bf48e3da9
--- /dev/null
+++ b/docs/source/en/api/models/dit_transformer2d.md
@@ -0,0 +1,19 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# DiTTransformer2D
+
+A Transformer model for image-like data from [DiT](https://huggingface.co/papers/2212.09748).
+
+## DiTTransformer2DModel
+
+[[autodoc]] DiTTransformer2DModel
diff --git a/docs/source/en/api/models/pixart_transformer2d.md b/docs/source/en/api/models/pixart_transformer2d.md
new file mode 100644
index 0000000000..982122207a
--- /dev/null
+++ b/docs/source/en/api/models/pixart_transformer2d.md
@@ -0,0 +1,19 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# PixArtTransformer2D
+
+A Transformer model for image-like data from [PixArt-Alpha](https://huggingface.co/papers/2310.00426) and [PixArt-Sigma](https://huggingface.co/papers/2403.04692).
+
+## PixArtTransformer2DModel
+
+[[autodoc]] PixArtTransformer2DModel
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index d510f76ebf..2897a8371b 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -82,11 +82,13 @@ else:
         "ConsistencyDecoderVAE",
         "ControlNetModel",
         "ControlNetXSAdapter",
+        "DiTTransformer2DModel",
         "I2VGenXLUNet",
         "Kandinsky3UNet",
         "ModelMixin",
         "MotionAdapter",
         "MultiAdapter",
+        "PixArtTransformer2DModel",
         "PriorTransformer",
         "StableCascadeUNet",
         "T2IAdapter",
@@ -484,11 +486,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         ConsistencyDecoderVAE,
         ControlNetModel,
         ControlNetXSAdapter,
+        DiTTransformer2DModel,
         I2VGenXLUNet,
         Kandinsky3UNet,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,
+        PixArtTransformer2DModel,
         PriorTransformer,
         T2IAdapter,
         T5FilmDecoder,
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index d601ce2356..be74ae0619 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -706,3 +706,20 @@ def flax_register_to_config(cls):
 
     cls.__init__ = init
     return cls
+
+
+class LegacyConfigMixin(ConfigMixin):
+    r"""
+    A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+    pipeline-specific classes (like `DiTTransformer2DModel`).
+    """
+
+    @classmethod
+    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+        # Import here to prevent a circular dependency problem.
+        from .models.model_loading_utils import _fetch_remapped_cls_from_config
+
+        # resolve remapping
+        remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+        return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 78b0efff92..de5f3d53c3 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -36,6 +36,8 @@ if is_torch_available():
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
+    _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
+    _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
@@ -73,7 +75,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .transformers import (
+        DiTTransformer2DModel,
         DualTransformer2DModel,
+        PixArtTransformer2DModel,
         PriorTransformer,
         T5FilmDecoder,
         Transformer2DModel,
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 635cd0ba57..516557e2df 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
 import inspect
 import os
 from collections import OrderedDict
@@ -32,6 +33,13 @@ from ..utils import (
 
 logger = logging.get_logger(__name__)
 
+_CLASS_REMAPPING_DICT = {
+    "Transformer2DModel": {
+        "ada_norm_zero": "DiTTransformer2DModel",
+        "ada_norm_single": "PixArtTransformer2DModel",
+    }
+}
+
 
 if is_accelerate_available():
     from accelerate import infer_auto_device_map
@@ -61,6 +69,22 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
     return device_map
 
 
+def _fetch_remapped_cls_from_config(config, old_class):
+    previous_class_name = old_class.__name__
+    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"])
+
+    # Load the diffusers library to import the remapped class.
+    diffusers_library = importlib.import_module(__name__.split(".")[0])
+    remapped_class = getattr(diffusers_library, remapped_class_name)
+    logger.info(
+        f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
+        f" This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
+        " DOESN'T affect the final results."
+    )
+
+    return remapped_class
+
+
 def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
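Taken together, `LegacyConfigMixin.from_config` and `_fetch_remapped_cls_from_config` make instantiation through the legacy class transparently return the specialized class, keyed on the config's `norm_type`. A minimal sketch of the intended behavior — the tiny config values below are illustrative and mirror the fast tests added later in this diff, not a real checkpoint:

```python
from diffusers import DiTTransformer2DModel, Transformer2DModel

# `norm_type` is the discriminator `_CLASS_REMAPPING_DICT` keys on:
# "ada_norm_zero" -> DiTTransformer2DModel, "ada_norm_single" -> PixArtTransformer2DModel.
dit_config = {
    "norm_type": "ada_norm_zero",
    "num_embeds_ada_norm": 8,
    "num_attention_heads": 2,
    "attention_head_dim": 4,
    "num_layers": 1,
    "in_channels": 4,
    "patch_size": 2,
    "sample_size": 8,
}

model = Transformer2DModel.from_config(dit_config)
assert isinstance(model, DiTTransformer2DModel)  # remapped, with an info log about the change
```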
diff --git a/src/diffusers/models/modeling_outputs.py b/src/diffusers/models/modeling_outputs.py
index 8dfee5fec1..0120a34d90 100644
--- a/src/diffusers/models/modeling_outputs.py
+++ b/src/diffusers/models/modeling_outputs.py
@@ -15,3 +15,17 @@ class AutoencoderKLOutput(BaseOutput):
     """
 
     latent_dist: "DiagonalGaussianDistribution"  # noqa: F821
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+    """
+    The output of [`Transformer2DModel`].
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+            distributions for the unnoised latent pixels.
+    """
+
+    sample: "torch.Tensor"  # noqa: F821
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 2ed5655c84..744972cde0 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -42,7 +42,11 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
+from ..utils.hub_utils import (
+    PushToHubMixin,
+    load_or_create_model_card,
+    populate_model_card,
+)
 from .model_loading_utils import (
     _determine_device_map,
     _load_state_dict_into_model,
@@ -1039,3 +1043,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 del module.key
                 del module.value
                 del module.proj_attn
+
+
+class LegacyModelMixin(ModelMixin):
+    r"""
+    A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+    pipeline-specific classes (like `DiTTransformer2DModel`).
+    """
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        # Import here to prevent a circular dependency problem.
+        from .model_loading_utils import _fetch_remapped_cls_from_config
+
+        # Keep a pristine copy of `kwargs` so the delegated `from_pretrained` call below still receives
+        # everything the caller passed in (the `pop` calls would otherwise strip the loading arguments).
+        kwargs_copy = kwargs.copy()
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        # load config
+        config, _, _ = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            **kwargs,
+        )
+        # resolve remapping
+        remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index dc78a72b2f..f668bc6642 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -2,7 +2,9 @@ from ...utils import is_torch_available
 
 
 if is_torch_available():
+    from .dit_transformer_2d import DiTTransformer2DModel
     from .dual_transformer_2d import DualTransformer2DModel
+    from .pixart_transformer_2d import PixArtTransformer2DModel
     from .prior_transformer import PriorTransformer
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
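The `from_pretrained` path works the same way: `LegacyModelMixin` loads only the config, resolves the remapped class, and delegates the actual weight loading to it. A sketch mirroring the slow test added at the end of this diff (it downloads the real DiT checkpoint from the Hub):

```python
from diffusers import DiTTransformer2DModel, Transformer2DModel

# Old user code keeps working: the legacy entry point reads the checkpoint's
# config (norm_type="ada_norm_zero") and dispatches to the new class.
model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer")
assert isinstance(model, DiTTransformer2DModel)
```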
diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py
new file mode 100644
index 0000000000..9f8957737d
--- /dev/null
+++ b/src/diffusers/models/transformers/dit_transformer_2d.py
@@ -0,0 +1,240 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version, logging
+from ..attention import BasicTransformerBlock
+from ..embeddings import PatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class DiTTransformer2DModel(ModelMixin, ConfigMixin):
+    r"""
+    A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
+
+    Parameters:
+        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
+        in_channels (int, defaults to 4): The number of channels in the input.
+        out_channels (int, optional):
+            The number of channels in the output. Specify this parameter if the output channel number differs from the
+            input.
+        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
+        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
+        norm_num_groups (int, optional, defaults to 32):
+            Number of groups for group normalization within Transformer blocks.
+        attention_bias (bool, optional, defaults to True):
+            Configure if the Transformer blocks' attention should contain a bias parameter.
+        sample_size (int, defaults to 32):
+            The width of the latent images. This parameter is fixed during training.
+        patch_size (int, defaults to 2):
+            Size of the patches the model processes, relevant for architectures working on non-sequential data.
+        activation_fn (str, optional, defaults to "gelu-approximate"):
+            Activation function to use in feed-forward networks within Transformer blocks.
+        num_embeds_ada_norm (int, optional, defaults to 1000):
+            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
+            inference.
+        upcast_attention (bool, optional, defaults to False):
+            If true, upcasts the attention mechanism dimensions for potentially improved performance.
+        norm_type (str, optional, defaults to "ada_norm_zero"):
+            Specifies the type of normalization used, can be 'ada_norm_zero'.
+        norm_elementwise_affine (bool, optional, defaults to False):
+            If true, enables element-wise affine parameters in the normalization layers.
+        norm_eps (float, optional, defaults to 1e-5):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 72,
+        in_channels: int = 4,
+        out_channels: Optional[int] = None,
+        num_layers: int = 28,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        attention_bias: bool = True,
+        sample_size: int = 32,
+        patch_size: int = 2,
+        activation_fn: str = "gelu-approximate",
+        num_embeds_ada_norm: Optional[int] = 1000,
+        upcast_attention: bool = False,
+        norm_type: str = "ada_norm_zero",
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-5,
+    ):
+        super().__init__()
+
+        # Validate inputs.
+        if norm_type != "ada_norm_zero":
+            raise NotImplementedError(
+                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+            )
+        elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+            )
+
+        # 1. Set some common variables used across the board.
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+
+        # 2. Initialize the position embedding and transformer blocks.
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.inner_dim,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        # 3. Output blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+        self.proj_out_2 = nn.Linear(
+            self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+        )
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: Optional[torch.LongTensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        return_dict: bool = True,
+    ):
+        """
+        The [`DiTTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            timestep ( `torch.LongTensor`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerNormZero`.
+            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.modeling_outputs.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        # 1. Input
+        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+        hidden_states = self.pos_embed(hidden_states)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    None,
+                    None,
+                    timestep,
+                    cross_attention_kwargs,
+                    class_labels,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=None,
+                    encoder_hidden_states=None,
+                    encoder_attention_mask=None,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=class_labels,
+                )
+
+        # 3. Output
+        conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
+        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+        hidden_states = self.proj_out_2(hidden_states)
+
+        # unpatchify
+        height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
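For reference, a minimal smoke test of the extracted class in isolation; the tiny dimensions below are illustrative (they mirror the fast-test config added later in this diff), not a real checkpoint:

```python
import torch

from diffusers import DiTTransformer2DModel

model = DiTTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=4,
    num_layers=1,
    in_channels=4,
    patch_size=2,
    sample_size=8,
    num_embeds_ada_norm=8,  # class labels must be < this value
)
latents = torch.randn(2, 4, 8, 8)
timesteps = torch.randint(0, 1000, (2,))
class_labels = torch.randint(0, 8, (2,))

out = model(latents, timestep=timesteps, class_labels=class_labels).sample
assert out.shape == (2, 4, 8, 8)  # out_channels defaults to in_channels
```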
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
new file mode 100644
index 0000000000..9c8f9b0908
--- /dev/null
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -0,0 +1,336 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version, logging
+from ..attention import BasicTransformerBlock
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
+    r"""
+    A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
+    https://arxiv.org/abs/2403.04692).
+
+    Parameters:
+        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
+        in_channels (int, defaults to 4): The number of channels in the input.
+        out_channels (int, optional):
+            The number of channels in the output. Specify this parameter if the output channel number differs from the
+            input.
+        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
+        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
+        norm_num_groups (int, optional, defaults to 32):
+            Number of groups for group normalization within Transformer blocks.
+        cross_attention_dim (int, optional):
+            The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
+        attention_bias (bool, optional, defaults to True):
+            Configure if the Transformer blocks' attention should contain a bias parameter.
+        sample_size (int, defaults to 128):
+            The width of the latent images. This parameter is fixed during training.
+        patch_size (int, defaults to 2):
+            Size of the patches the model processes, relevant for architectures working on non-sequential data.
+        activation_fn (str, optional, defaults to "gelu-approximate"):
+            Activation function to use in feed-forward networks within Transformer blocks.
+        num_embeds_ada_norm (int, optional, defaults to 1000):
+            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
+            inference.
+        upcast_attention (bool, optional, defaults to False):
+            If true, upcasts the attention mechanism dimensions for potentially improved performance.
+        norm_type (str, optional, defaults to "ada_norm_single"):
+            Specifies the type of normalization used, can be 'ada_norm_single'.
+        norm_elementwise_affine (bool, optional, defaults to False):
+            If true, enables element-wise affine parameters in the normalization layers.
+        norm_eps (float, optional, defaults to 1e-6):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
+        interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
+        use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
+        attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
+        caption_channels (int, optional, defaults to None):
+            Number of channels to use for projecting the caption embeddings.
+        use_linear_projection (bool, optional, defaults to False):
+            Deprecated argument. Will be removed in a future version.
+        num_vector_embeds (int, optional):
+            Deprecated argument. Will be removed in a future version.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 72,
+        in_channels: int = 4,
+        out_channels: Optional[int] = 8,
+        num_layers: int = 28,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = 1152,
+        attention_bias: bool = True,
+        sample_size: int = 128,
+        patch_size: int = 2,
+        activation_fn: str = "gelu-approximate",
+        num_embeds_ada_norm: Optional[int] = 1000,
+        upcast_attention: bool = False,
+        norm_type: str = "ada_norm_single",
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
+        use_additional_conditions: Optional[bool] = None,
+        caption_channels: Optional[int] = None,
+        attention_type: Optional[str] = "default",
+    ):
+        super().__init__()
+
+        # Validate inputs.
+        if norm_type != "ada_norm_single":
+            raise NotImplementedError(
+                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+            )
+        elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+            )
+
+        # 1. Set some common variables used across the board.
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = in_channels if out_channels is None else out_channels
+        if use_additional_conditions is None:
+            if sample_size == 128:
+                use_additional_conditions = True
+            else:
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions
+
+        self.gradient_checkpointing = False
+
+        # 2. Initialize the position embedding and transformer blocks.
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        # 3. Output blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
+
+        self.adaln_single = AdaLayerNormSingle(
+            self.inner_dim, use_additional_conditions=self.use_additional_conditions
+        )
+        self.caption_projection = None
+        if self.config.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.config.caption_channels, hidden_size=self.inner_dim
+            )
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        added_cond_kwargs: Dict[str, torch.Tensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        """
+        The [`PixArtTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep (`torch.LongTensor`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
+            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            attention_mask ( `torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            encoder_attention_mask ( `torch.Tensor`, *optional*):
+                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                    * Mask `(batch, sequence_length)` True = keep, False = discard.
+                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.modeling_outputs.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if self.use_additional_conditions and added_cond_kwargs is None:
+            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch_size = hidden_states.shape[0]
+        height, width = (
+            hidden_states.shape[-2] // self.config.patch_size,
+            hidden_states.shape[-1] // self.config.patch_size,
+        )
+        hidden_states = self.pos_embed(hidden_states)
+
+        timestep, embedded_timestep = self.adaln_single(
+            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    cross_attention_kwargs,
+                    None,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=None,
+                )
+
+        # 3. Output
+        shift, scale = (
+            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+        ).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+        # Modulation
+        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
+        )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
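And the PixArt counterpart, which is text-conditioned through `encoder_hidden_states` rather than class labels; again the tiny dimensions are illustrative and mirror the fast test added later in this diff:

```python
import torch

from diffusers import PixArtTransformer2DModel

model = PixArtTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=4,
    num_layers=1,
    in_channels=4,
    out_channels=4,
    cross_attention_dim=8,
    patch_size=2,
    sample_size=8,
    caption_channels=None,  # skip the caption projection
    use_additional_conditions=False,  # no resolution/aspect-ratio conditioning
)
latents = torch.randn(2, 4, 8, 8)
timesteps = torch.randint(0, 1000, (2,))
text_states = torch.randn(2, 8, 8)  # (batch, seq_len, cross_attention_dim)

out = model(
    latents,
    encoder_hidden_states=text_states,
    timestep=timesteps,
    added_cond_kwargs={"resolution": None, "aspect_ratio": None},
).sample
assert out.shape == (2, 4, 8, 8)
```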
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index ef9e0de0b6..70f2337290 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -11,39 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import BaseOutput, deprecate, is_torch_version, logging
+from ...configuration_utils import LegacyConfigMixin, register_to_config
+from ...utils import deprecate, is_torch_version, logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
-from ..modeling_utils import ModelMixin
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import LegacyModelMixin
 from ..normalization import AdaLayerNormSingle
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-@dataclass
-class Transformer2DModelOutput(BaseOutput):
-    """
-    The output of [`Transformer2DModel`].
-
-    Args:
-        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
-            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
-            distributions for the unnoised latent pixels.
-    """
-
-    sample: torch.Tensor
+class Transformer2DModelOutput(Transformer2DModelOutput):
+    deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
+    deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
 
 
-class Transformer2DModel(ModelMixin, ConfigMixin):
+class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
     """
     A 2D Transformer model for image-like data.
 
@@ -116,40 +107,12 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                 f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
             )
 
-        # Set some common variables used across the board.
-        self.use_linear_projection = use_linear_projection
-        self.interpolation_scale = interpolation_scale
-        self.caption_channels = caption_channels
-        self.num_attention_heads = num_attention_heads
-        self.attention_head_dim = attention_head_dim
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-        self.in_channels = in_channels
-        self.out_channels = in_channels if out_channels is None else out_channels
-        self.gradient_checkpointing = False
-        if use_additional_conditions is None:
-            if norm_type == "ada_norm_single" and sample_size == 128:
-                use_additional_conditions = True
-            else:
-                use_additional_conditions = False
-        self.use_additional_conditions = use_additional_conditions
-
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
         self.is_input_vectorized = num_vector_embeds is not None
         self.is_input_patches = in_channels is not None and patch_size is not None
 
-        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
-            deprecation_message = (
-                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
-                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
-                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
-                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
-                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
-            )
-            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
-            norm_type = "ada_norm"
-
         if self.is_input_continuous and self.is_input_vectorized:
             raise ValueError(
                 f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
@@ -166,6 +129,35 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )
 
+        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+            deprecation_message = (
+                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
+                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+            )
+            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+            norm_type = "ada_norm"
+
+        # Set some common variables used across the board.
+        self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+
+        if use_additional_conditions is None:
+            if norm_type == "ada_norm_single" and sample_size == 128:
+                use_additional_conditions = True
+            else:
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions
+
         # 2. Initialize the right blocks.
         # These functions follow a common structure:
         # a. Initialize the input blocks. b. Initialize the transformer blocks.
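Since `Transformer2DModelOutput` now lives in `modeling_outputs` and the definition above only shadows it for backward compatibility, the two import paths behave as follows — a sketch assuming only what this diff shows (importing the legacy module still works, but triggers the deprecation warning at import time):

```python
# Preferred location after this change:
from diffusers.models.modeling_outputs import Transformer2DModelOutput

# Legacy location; still resolves, but the module emits the 1.0.0 deprecation warning:
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput  # noqa: F811
```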
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index a3ea90874a..14321b5f33 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 
-from ...models import AutoencoderKL, Transformer2DModel
+from ...models import AutoencoderKL, DiTTransformer2DModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -36,8 +36,8 @@ class DiTPipeline(DiffusionPipeline):
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
 
     Parameters:
-        transformer ([`Transformer2DModel`]):
-            A class conditioned `Transformer2DModel` to denoise the encoded image latents.
+        transformer ([`DiTTransformer2DModel`]):
+            A class-conditioned `DiTTransformer2DModel` to denoise the encoded image latents.
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
         scheduler ([`DDIMScheduler`]):
@@ -48,7 +48,7 @@ class DiTPipeline(DiffusionPipeline):
 
     def __init__(
         self,
-        transformer: Transformer2DModel,
+        transformer: DiTTransformer2DModel,
         vae: AutoencoderKL,
         scheduler: KarrasDiffusionSchedulers,
         id2label: Optional[Dict[int, str]] = None,
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 941de2b47a..0043bec65d 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -608,6 +608,7 @@ def load_sub_model(
     cached_folder: Union[str, os.PathLike],
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
+
     # retrieve class candidates
     class_obj, class_candidates = get_class_obj_and_candidates(
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 6d3f5c1e27..6a75090761 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -22,7 +22,7 @@ import torch
 from transformers import T5EncoderModel, T5Tokenizer
 
 from ...image_processor import PixArtImageProcessor
-from ...models import AutoencoderKL, Transformer2DModel
+from ...models import AutoencoderKL, PixArtTransformer2DModel
 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -246,8 +246,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         tokenizer (`T5Tokenizer`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
-        transformer ([`Transformer2DModel`]):
-            A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+        transformer ([`PixArtTransformer2DModel`]):
+            A text-conditioned `PixArtTransformer2DModel` to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
     """
@@ -276,7 +276,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         tokenizer: T5Tokenizer,
         text_encoder: T5EncoderModel,
         vae: AutoencoderKL,
-        transformer: Transformer2DModel,
+        transformer: PixArtTransformer2DModel,
         scheduler: DPMSolverMultistepScheduler,
     ):
         super().__init__()
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 1db7e5d9ab..9c17757791 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -22,7 +22,7 @@ import torch
 from transformers import T5EncoderModel, T5Tokenizer
 
 from ...image_processor import PixArtImageProcessor
-from ...models import AutoencoderKL, Transformer2DModel
+from ...models import AutoencoderKL, PixArtTransformer2DModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     BACKENDS_MAPPING,
@@ -202,7 +202,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
         tokenizer: T5Tokenizer,
         text_encoder: T5EncoderModel,
         vae: AutoencoderKL,
-        transformer: Transformer2DModel,
+        transformer: PixArtTransformer2DModel,
         scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index b04006cb5e..d3b79be1cb 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -107,6 +107,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class DiTTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class I2VGenXLUNet(metaclass=DummyObject):
     _backends = ["torch"]
 
@@ -182,6 +197,21 @@ class MultiAdapter(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class PixArtTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class PriorTransformer(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 59369b5098..a72290b75f 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -559,7 +559,7 @@ class ModelTesterMixin:
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, expected_max_diff)
 
-    def test_output(self):
+    def test_output(self, expected_output_shape=None):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -575,8 +575,12 @@ class ModelTesterMixin:
 
         # input & output have to have the same shape
         input_tensor = inputs_dict[self.main_input_name]
-        expected_shape = input_tensor.shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+        if expected_output_shape is None:
+            expected_shape = input_tensor.shape
+            self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+        else:
+            self.assertEqual(output.shape, expected_output_shape, "Input and output shapes do not match")
 
     def test_model_from_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
shapes do not match") def test_model_from_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/models/transformers/test_models_dit_transformer2d.py b/tests/models/transformers/test_models_dit_transformer2d.py new file mode 100644 index 0000000000..b12cae1a88 --- /dev/null +++ b/tests/models/transformers/test_models_dit_transformer2d.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import DiTTransformer2DModel, Transformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + slow, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class DiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase): + model_class = DiTTransformer2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 4 + in_channels = 4 + sample_size = 8 + scheduler_num_train_steps = 1000 + num_class_labels = 4 + + hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device) + timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device) + class_label_ids = torch.randint(0, num_class_labels, size=(batch_size,)).to(torch_device) + + return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_label_ids} + + @property + def input_shape(self): + return (4, 8, 8) + + @property + def output_shape(self): + return (8, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 8, + "activation_fn": "gelu-approximate", + "num_attention_heads": 2, + "attention_head_dim": 4, + "attention_bias": True, + "num_layers": 1, + "norm_type": "ada_norm_zero", + "num_embeds_ada_norm": 8, + "patch_size": 2, + "sample_size": 8, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + super().test_output( + expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape + ) + + def test_correct_class_remapping_from_dict_config(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = Transformer2DModel.from_config(init_dict) + assert isinstance(model, DiTTransformer2DModel) + + def test_correct_class_remapping_from_pretrained_config(self): + config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer") + model = Transformer2DModel.from_config(config) + assert isinstance(model, DiTTransformer2DModel) + + @slow + def test_correct_class_remapping(self): + model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer") + assert isinstance(model, DiTTransformer2DModel) diff --git a/tests/models/transformers/test_models_pixart_transformer2d.py b/tests/models/transformers/test_models_pixart_transformer2d.py new file mode 100644 index 
0000000000..30293f5d35 --- /dev/null +++ b/tests/models/transformers/test_models_pixart_transformer2d.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import PixArtTransformer2DModel, Transformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + slow, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class PixArtTransformer2DModelTests(ModelTesterMixin, unittest.TestCase): + model_class = PixArtTransformer2DModel + main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. + model_split_percents = [0.7, 0.6, 0.6] + + @property + def dummy_input(self): + batch_size = 4 + in_channels = 4 + sample_size = 8 + scheduler_num_train_steps = 1000 + cross_attention_dim = 8 + seq_len = 8 + + hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device) + timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, seq_len, cross_attention_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "timestep": timesteps, + "encoder_hidden_states": encoder_hidden_states, + "added_cond_kwargs": {"aspect_ratio": None, "resolution": None}, + } + + @property + def input_shape(self): + return (4, 8, 8) + + @property + def output_shape(self): + return (8, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 8, + "num_layers": 1, + "patch_size": 2, + "attention_head_dim": 2, + "num_attention_heads": 2, + "in_channels": 4, + "cross_attention_dim": 8, + "out_channels": 8, + "attention_bias": True, + "activation_fn": "gelu-approximate", + "num_embeds_ada_norm": 8, + "norm_type": "ada_norm_single", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "use_additional_conditions": False, + "caption_channels": None, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_output(self): + super().test_output( + expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape + ) + + def test_correct_class_remapping_from_dict_config(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = Transformer2DModel.from_config(init_dict) + assert isinstance(model, PixArtTransformer2DModel) + + def test_correct_class_remapping_from_pretrained_config(self): + config = PixArtTransformer2DModel.load_config("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer") + model = Transformer2DModel.from_config(config) + assert isinstance(model, PixArtTransformer2DModel) + + @slow + def test_correct_class_remapping(self): + model = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer") + assert isinstance(model, PixArtTransformer2DModel) diff --git 
a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py index 937265ab05..30883ac4a6 100644 --- a/tests/pipelines/dit/test_dit.py +++ b/tests/pipelines/dit/test_dit.py @@ -19,7 +19,7 @@ import unittest import numpy as np import torch -from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel +from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler from diffusers.utils import is_xformers_available from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device @@ -46,7 +46,7 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) - transformer = Transformer2DModel( + transformer = DiTTransformer2DModel( sample_size=16, num_layers=2, patch_size=4, diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index dd358af083..e7039c61a4 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -25,7 +25,7 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, PixArtAlphaPipeline, - Transformer2DModel, + PixArtTransformer2DModel, ) from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -53,7 +53,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) - transformer = Transformer2DModel( + transformer = PixArtTransformer2DModel( sample_size=8, num_layers=2, patch_size=2, diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index 58833d15fe..a4cc60d125 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -25,7 +25,7 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, PixArtSigmaPipeline, - Transformer2DModel, + PixArtTransformer2DModel, ) from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -53,7 +53,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) - transformer = Transformer2DModel( + transformer = PixArtTransformer2DModel( sample_size=8, num_layers=2, patch_size=2, @@ -344,7 +344,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase): def test_pixart_512(self): generator = torch.Generator("cpu").manual_seed(0) - transformer = Transformer2DModel.from_pretrained( + transformer = PixArtTransformer2DModel.from_pretrained( self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16 ) pipe = PixArtSigmaPipeline.from_pretrained( @@ -399,7 +399,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase): def test_pixart_512_without_resolution_binning(self): generator = torch.manual_seed(0) - transformer = Transformer2DModel.from_pretrained( + transformer = PixArtTransformer2DModel.from_pretrained( self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16 ) pipe = PixArtSigmaPipeline.from_pretrained(