1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Core] Introduce class variants for Transformer2DModel (#7647)

* init for patches

* finish patched model.

* continuous transformer

* vectorized transformer2d.

* style.

* inits.

* fix-copies.

* introduce DiTTransformer2DModel.

* fixes

* use REMAPPING as suggested by @DN6

* better logging.

* add pixart transformer model.

* inits.

* caption_channels.

* attention masking.

* fix use_additional_conditions.

* remove print.

* debug

* flatten

* fix: assertion for sigma

* handle remapping for modeling_utils

* add tests for dit transformer2d

* quality

* placeholder for pixart tests

* pixart tests

* add _no_split_modules

* add docs.

* check

* check

* check

* check

* fix tests

* fix tests

* move Transformer output to modeling_output

* move errors better and bring back use_additional_conditions attribute.

* add unnecessary things from DiT.

* clean up pixart

* fix remapping

* fix device_map things in pixart2d.

* replace Transformer2DModel with appropriate classes in dit, pixart tests

* empty

* legacy mixin classes./

* use a remapping dict for fetching class names.

* change to specifc model types in the pipeline implementations.

* move _fetch_remapped_cls_from_config to modeling_loading_utils.py

* fix dependency problems.

* add deprecation note.
This commit is contained in:
Sayak Paul
2024-05-31 13:40:27 +05:30
committed by sayakpaul
parent 7828d4eb00
commit 137403ff31
24 changed files with 1036 additions and 67 deletions

View File

@@ -234,6 +234,10 @@
title: ConsistencyDecoderVAE
- local: api/models/transformer2d
title: Transformer2D
- local: api/models/pixart_transformer2d
title: PixArtTransformer2D
- local: api/models/dit_transformer2d
title: DiTTransformer2D
- local: api/models/transformer_temporal
title: Transformer Temporal
- local: api/models/prior_transformer

View File

@@ -0,0 +1,19 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# DiTTransformer2D
A Transformer model for image-like data from [DiT](https://huggingface.co/papers/2212.09748).
## DiTTransformer2DModel
[[autodoc]] DiTTransformer2DModel

View File

@@ -0,0 +1,19 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PixArtTransformer2D
A Transformer model for image-like data from [PixArt-Alpha](https://huggingface.co/papers/2310.00426) and [PixArt-Sigma](https://huggingface.co/papers/2403.04692).
## PixArtTransformer2DModel
[[autodoc]] PixArtTransformer2DModel

View File

@@ -82,11 +82,13 @@ else:
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
"PixArtTransformer2DModel",
"PriorTransformer",
"StableCascadeUNet",
"T2IAdapter",
@@ -484,11 +486,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
I2VGenXLUNet,
Kandinsky3UNet,
ModelMixin,
MotionAdapter,
MultiAdapter,
PixArtTransformer2DModel,
PriorTransformer,
T2IAdapter,
T5FilmDecoder,

View File

@@ -706,3 +706,20 @@ def flax_register_to_config(cls):
cls.__init__ = init
return cls
class LegacyConfigMixin(ConfigMixin):
r"""
A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
pipeline-specific classes (like `DiTTransformer2DModel`).
"""
@classmethod
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
# To prevent depedency import problem.
from .models.model_loading_utils import _fetch_remapped_cls_from_config
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)

View File

@@ -36,6 +36,8 @@ if is_torch_available():
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
@@ -73,7 +75,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
DiTTransformer2DModel,
DualTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
T5FilmDecoder,
Transformer2DModel,

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
from collections import OrderedDict
@@ -32,6 +33,13 @@ from ..utils import (
logger = logging.get_logger(__name__)
_CLASS_REMAPPING_DICT = {
"Transformer2DModel": {
"ada_norm_zero": "DiTTransformer2DModel",
"ada_norm_single": "PixArtTransformer2DModel",
}
}
if is_accelerate_available():
from accelerate import infer_auto_device_map
@@ -61,6 +69,22 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
return device_map
def _fetch_remapped_cls_from_config(config, old_class):
previous_class_name = old_class.__name__
remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"])
# load diffusers library to import compatible and original scheduler
diffusers_library = importlib.import_module(__name__.split(".")[0])
remapped_class = getattr(diffusers_library, remapped_class_name)
logger.info(
f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
"This is because `previous_class_name` is scheduled to be deprecated in a future version. Note that this"
" DOESN'T affect the final results."
)
return remapped_class
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.

View File

@@ -15,3 +15,17 @@ class AutoencoderKLOutput(BaseOutput):
"""
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: "torch.Tensor" # noqa: F821

View File

@@ -42,7 +42,11 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
from ..utils.hub_utils import (
PushToHubMixin,
load_or_create_model_card,
populate_model_card,
)
from .model_loading_utils import (
_determine_device_map,
_load_state_dict_into_model,
@@ -1039,3 +1043,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
del module.key
del module.value
del module.proj_attn
class LegacyModelMixin(ModelMixin):
r"""
A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
pipeline-specific classes (like `DiTTransformer2DModel`).
"""
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# To prevent depedency import problem.
from .model_loading_utils import _fetch_remapped_cls_from_config
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
# load config
config, _, _ = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
return_commit_hash=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
**kwargs,
)
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

View File

@@ -2,7 +2,9 @@ from ...utils import is_torch_available
if is_torch_available():
from .dit_transformer_2d import DiTTransformer2DModel
from .dual_transformer_2d import DualTransformer2DModel
from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel

View File

@@ -0,0 +1,240 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class DiTTransformer2DModel(ModelMixin, ConfigMixin):
r"""
A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
Parameters:
num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
in_channels (int, defaults to 4): The number of channels in the input.
out_channels (int, optional):
The number of channels in the output. Specify this parameter if the output channel number differs from the
input.
num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
norm_num_groups (int, optional, defaults to 32):
Number of groups for group normalization within Transformer blocks.
attention_bias (bool, optional, defaults to True):
Configure if the Transformer blocks' attention should contain a bias parameter.
sample_size (int, defaults to 32):
The width of the latent images. This parameter is fixed during training.
patch_size (int, defaults to 2):
Size of the patches the model processes, relevant for architectures working on non-sequential data.
activation_fn (str, optional, defaults to "gelu-approximate"):
Activation function to use in feed-forward networks within Transformer blocks.
num_embeds_ada_norm (int, optional, defaults to 1000):
Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
inference.
upcast_attention (bool, optional, defaults to False):
If true, upcasts the attention mechanism dimensions for potentially improved performance.
norm_type (str, optional, defaults to "ada_norm_zero"):
Specifies the type of normalization used, can be 'ada_norm_zero'.
norm_elementwise_affine (bool, optional, defaults to False):
If true, enables element-wise affine parameters in the normalization layers.
norm_eps (float, optional, defaults to 1e-5):
A small constant added to the denominator in normalization layers to prevent division by zero.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 72,
in_channels: int = 4,
out_channels: Optional[int] = None,
num_layers: int = 28,
dropout: float = 0.0,
norm_num_groups: int = 32,
attention_bias: bool = True,
sample_size: int = 32,
patch_size: int = 2,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: Optional[int] = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm_zero",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
):
super().__init__()
# Validate inputs.
if norm_type != "ada_norm_zero":
raise NotImplementedError(
f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
)
elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
raise ValueError(
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
)
# Set some common variables used across the board.
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = in_channels if out_channels is None else out_channels
self.gradient_checkpointing = False
# 2. Initialize the position embedding and transformer blocks.
self.height = self.config.sample_size
self.width = self.config.sample_size
self.patch_size = self.config.patch_size
self.pos_embed = PatchEmbed(
height=self.config.sample_size,
width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
embed_dim=self.inner_dim,
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
)
for _ in range(self.config.num_layers)
]
)
# 3. Output blocks.
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(
self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict: bool = True,
):
"""
The [`DiTTransformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 1. Input
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
None,
None,
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -0,0 +1,336 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
r"""
A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
https://arxiv.org/abs/2403.04692).
Parameters:
num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
in_channels (int, defaults to 4): The number of channels in the input.
out_channels (int, optional):
The number of channels in the output. Specify this parameter if the output channel number differs from the
input.
num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
norm_num_groups (int, optional, defaults to 32):
Number of groups for group normalization within Transformer blocks.
cross_attention_dim (int, optional):
The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
attention_bias (bool, optional, defaults to True):
Configure if the Transformer blocks' attention should contain a bias parameter.
sample_size (int, defaults to 128):
The width of the latent images. This parameter is fixed during training.
patch_size (int, defaults to 2):
Size of the patches the model processes, relevant for architectures working on non-sequential data.
activation_fn (str, optional, defaults to "gelu-approximate"):
Activation function to use in feed-forward networks within Transformer blocks.
num_embeds_ada_norm (int, optional, defaults to 1000):
Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
inference.
upcast_attention (bool, optional, defaults to False):
If true, upcasts the attention mechanism dimensions for potentially improved performance.
norm_type (str, optional, defaults to "ada_norm_zero"):
Specifies the type of normalization used, can be 'ada_norm_zero'.
norm_elementwise_affine (bool, optional, defaults to False):
If true, enables element-wise affine parameters in the normalization layers.
norm_eps (float, optional, defaults to 1e-6):
A small constant added to the denominator in normalization layers to prevent division by zero.
interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
caption_channels (int, optional, defaults to None):
Number of channels to use for projecting the caption embeddings.
use_linear_projection (bool, optional, defaults to False):
Deprecated argument. Will be removed in a future version.
num_vector_embeds (bool, optional, defaults to False):
Deprecated argument. Will be removed in a future version.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 72,
in_channels: int = 4,
out_channels: Optional[int] = 8,
num_layers: int = 28,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = 1152,
attention_bias: bool = True,
sample_size: int = 128,
patch_size: int = 2,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: Optional[int] = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
use_additional_conditions: Optional[bool] = None,
caption_channels: Optional[int] = None,
attention_type: Optional[str] = "default",
):
super().__init__()
# Validate inputs.
if norm_type != "ada_norm_single":
raise NotImplementedError(
f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
)
elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
raise ValueError(
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
)
# Set some common variables used across the board.
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = in_channels if out_channels is None else out_channels
if use_additional_conditions is None:
if sample_size == 128:
use_additional_conditions = True
else:
use_additional_conditions = False
self.use_additional_conditions = use_additional_conditions
self.gradient_checkpointing = False
# 2. Initialize the position embedding and transformer blocks.
self.height = self.config.sample_size
self.width = self.config.sample_size
interpolation_scale = (
self.config.interpolation_scale
if self.config.interpolation_scale is not None
else max(self.config.sample_size // 64, 1)
)
self.pos_embed = PatchEmbed(
height=self.config.sample_size,
width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
embed_dim=self.inner_dim,
interpolation_scale=interpolation_scale,
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
cross_attention_dim=self.config.cross_attention_dim,
activation_fn=self.config.activation_fn,
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
attention_type=self.config.attention_type,
)
for _ in range(self.config.num_layers)
]
)
# 3. Output blocks.
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=self.use_additional_conditions
)
self.caption_projection = None
if self.config.caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.config.caption_channels, hidden_size=self.inner_dim
)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`PixArtTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep (`torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size = hidden_states.shape[0]
height, width = (
hidden_states.shape[-2] // self.config.patch_size,
hidden_states.shape[-1] // self.config.patch_size,
)
hidden_states = self.pos_embed(hidden_states)
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
if self.caption_projection is not None:
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
# 2. Blocks
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=None,
)
# 3. Output
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -11,39 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, deprecate, is_torch_version, logging
from ...configuration_utils import LegacyConfigMixin, register_to_config
from ...utils import deprecate, is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
from ..modeling_utils import ModelMixin
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import LegacyModelMixin
from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.Tensor
class Transformer2DModelOutput(Transformer2DModelOutput):
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
class Transformer2DModel(ModelMixin, ConfigMixin):
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
"""
A 2D Transformer model for image-like data.
@@ -116,40 +107,12 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
)
# Set some common variables used across the board.
self.use_linear_projection = use_linear_projection
self.interpolation_scale = interpolation_scale
self.caption_channels = caption_channels
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.gradient_checkpointing = False
if use_additional_conditions is None:
if norm_type == "ada_norm_single" and sample_size == 128:
use_additional_conditions = True
else:
use_additional_conditions = False
self.use_additional_conditions = use_additional_conditions
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
@@ -166,6 +129,35 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
# Set some common variables used across the board.
self.use_linear_projection = use_linear_projection
self.interpolation_scale = interpolation_scale
self.caption_channels = caption_channels
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.gradient_checkpointing = False
if use_additional_conditions is None:
if norm_type == "ada_norm_single" and sample_size == 128:
use_additional_conditions = True
else:
use_additional_conditions = False
self.use_additional_conditions = use_additional_conditions
# 2. Initialize the right blocks.
# These functions follow a common structure:
# a. Initialize the input blocks. b. Initialize the transformer blocks.

View File

@@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
from ...models import AutoencoderKL, Transformer2DModel
from ...models import AutoencoderKL, DiTTransformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -36,8 +36,8 @@ class DiTPipeline(DiffusionPipeline):
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Parameters:
transformer ([`Transformer2DModel`]):
A class conditioned `Transformer2DModel` to denoise the encoded image latents.
transformer ([`DiTTransformer2DModel`]):
A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
scheduler ([`DDIMScheduler`]):
@@ -48,7 +48,7 @@ class DiTPipeline(DiffusionPipeline):
def __init__(
self,
transformer: Transformer2DModel,
transformer: DiTTransformer2DModel,
vae: AutoencoderKL,
scheduler: KarrasDiffusionSchedulers,
id2label: Optional[Dict[int, str]] = None,

View File

@@ -608,6 +608,7 @@ def load_sub_model(
cached_folder: Union[str, os.PathLike],
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(

View File

@@ -22,7 +22,7 @@ import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...image_processor import PixArtImageProcessor
from ...models import AutoencoderKL, Transformer2DModel
from ...models import AutoencoderKL, PixArtTransformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -246,8 +246,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
transformer ([`PixArtTransformer2DModel`]):
A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
@@ -276,7 +276,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: Transformer2DModel,
transformer: PixArtTransformer2DModel,
scheduler: DPMSolverMultistepScheduler,
):
super().__init__()

View File

@@ -22,7 +22,7 @@ import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...image_processor import PixArtImageProcessor
from ...models import AutoencoderKL, Transformer2DModel
from ...models import AutoencoderKL, PixArtTransformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
BACKENDS_MAPPING,
@@ -202,7 +202,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: Transformer2DModel,
transformer: PixArtTransformer2DModel,
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()

View File

@@ -107,6 +107,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class DiTTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"]
@@ -182,6 +197,21 @@ class MultiAdapter(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PriorTransformer(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -559,7 +559,7 @@ class ModelTesterMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, expected_max_diff)
def test_output(self):
def test_output(self, expected_output_shape=None):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
@@ -575,8 +575,12 @@ class ModelTesterMixin:
# input & output have to have the same shape
input_tensor = inputs_dict[self.main_input_name]
expected_shape = input_tensor.shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
if expected_output_shape is None:
expected_shape = input_tensor.shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
else:
self.assertEqual(output.shape, expected_output_shape, "Input and output shapes do not match")
def test_model_from_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@@ -0,0 +1,95 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import DiTTransformer2DModel, Transformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
slow,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class DiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = DiTTransformer2DModel
main_input_name = "hidden_states"
@property
def dummy_input(self):
batch_size = 4
in_channels = 4
sample_size = 8
scheduler_num_train_steps = 1000
num_class_labels = 4
hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device)
timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device)
class_label_ids = torch.randint(0, num_class_labels, size=(batch_size,)).to(torch_device)
return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_label_ids}
@property
def input_shape(self):
return (4, 8, 8)
@property
def output_shape(self):
return (8, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 8,
"activation_fn": "gelu-approximate",
"num_attention_heads": 2,
"attention_head_dim": 4,
"attention_bias": True,
"num_layers": 1,
"norm_type": "ada_norm_zero",
"num_embeds_ada_norm": 8,
"patch_size": 2,
"sample_size": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)
def test_correct_class_remapping_from_dict_config(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = Transformer2DModel.from_config(init_dict)
assert isinstance(model, DiTTransformer2DModel)
def test_correct_class_remapping_from_pretrained_config(self):
config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
model = Transformer2DModel.from_config(config)
assert isinstance(model, DiTTransformer2DModel)
@slow
def test_correct_class_remapping(self):
model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer")
assert isinstance(model, DiTTransformer2DModel)

View File

@@ -0,0 +1,108 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import PixArtTransformer2DModel, Transformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
slow,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class PixArtTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = PixArtTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
@property
def dummy_input(self):
batch_size = 4
in_channels = 4
sample_size = 8
scheduler_num_train_steps = 1000
cross_attention_dim = 8
seq_len = 8
hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device)
timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, seq_len, cross_attention_dim)).to(torch_device)
return {
"hidden_states": hidden_states,
"timestep": timesteps,
"encoder_hidden_states": encoder_hidden_states,
"added_cond_kwargs": {"aspect_ratio": None, "resolution": None},
}
@property
def input_shape(self):
return (4, 8, 8)
@property
def output_shape(self):
return (8, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 8,
"num_layers": 1,
"patch_size": 2,
"attention_head_dim": 2,
"num_attention_heads": 2,
"in_channels": 4,
"cross_attention_dim": 8,
"out_channels": 8,
"attention_bias": True,
"activation_fn": "gelu-approximate",
"num_embeds_ada_norm": 8,
"norm_type": "ada_norm_single",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"use_additional_conditions": False,
"caption_channels": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)
def test_correct_class_remapping_from_dict_config(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = Transformer2DModel.from_config(init_dict)
assert isinstance(model, PixArtTransformer2DModel)
def test_correct_class_remapping_from_pretrained_config(self):
config = PixArtTransformer2DModel.load_config("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
model = Transformer2DModel.from_config(config)
assert isinstance(model, PixArtTransformer2DModel)
@slow
def test_correct_class_remapping(self):
model = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
assert isinstance(model, PixArtTransformer2DModel)

View File

@@ -19,7 +19,7 @@ import unittest
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
@@ -46,7 +46,7 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
transformer = Transformer2DModel(
transformer = DiTTransformer2DModel(
sample_size=16,
num_layers=2,
patch_size=4,

View File

@@ -25,7 +25,7 @@ from diffusers import (
AutoencoderKL,
DDIMScheduler,
PixArtAlphaPipeline,
Transformer2DModel,
PixArtTransformer2DModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
@@ -53,7 +53,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
transformer = Transformer2DModel(
transformer = PixArtTransformer2DModel(
sample_size=8,
num_layers=2,
patch_size=2,

View File

@@ -25,7 +25,7 @@ from diffusers import (
AutoencoderKL,
DDIMScheduler,
PixArtSigmaPipeline,
Transformer2DModel,
PixArtTransformer2DModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
@@ -53,7 +53,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
transformer = Transformer2DModel(
transformer = PixArtTransformer2DModel(
sample_size=8,
num_layers=2,
patch_size=2,
@@ -344,7 +344,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
def test_pixart_512(self):
generator = torch.Generator("cpu").manual_seed(0)
transformer = Transformer2DModel.from_pretrained(
transformer = PixArtTransformer2DModel.from_pretrained(
self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16
)
pipe = PixArtSigmaPipeline.from_pretrained(
@@ -399,7 +399,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
def test_pixart_512_without_resolution_binning(self):
generator = torch.manual_seed(0)
transformer = Transformer2DModel.from_pretrained(
transformer = PixArtTransformer2DModel.from_pretrained(
self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16
)
pipe = PixArtSigmaPipeline.from_pretrained(