mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Core] Introduce class variants for Transformer2DModel (#7647)
* init for patches * finish patched model. * continuous transformer * vectorized transformer2d. * style. * inits. * fix-copies. * introduce DiTTransformer2DModel. * fixes * use REMAPPING as suggested by @DN6 * better logging. * add pixart transformer model. * inits. * caption_channels. * attention masking. * fix use_additional_conditions. * remove print. * debug * flatten * fix: assertion for sigma * handle remapping for modeling_utils * add tests for dit transformer2d * quality * placeholder for pixart tests * pixart tests * add _no_split_modules * add docs. * check * check * check * check * fix tests * fix tests * move Transformer output to modeling_output * move errors better and bring back use_additional_conditions attribute. * add unnecessary things from DiT. * clean up pixart * fix remapping * fix device_map things in pixart2d. * replace Transformer2DModel with appropriate classes in dit, pixart tests * empty * legacy mixin classes./ * use a remapping dict for fetching class names. * change to specifc model types in the pipeline implementations. * move _fetch_remapped_cls_from_config to modeling_loading_utils.py * fix dependency problems. * add deprecation note.
This commit is contained in:
@@ -234,6 +234,10 @@
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/transformer2d
|
||||
title: Transformer2D
|
||||
- local: api/models/pixart_transformer2d
|
||||
title: PixArtTransformer2D
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2D
|
||||
- local: api/models/transformer_temporal
|
||||
title: Transformer Temporal
|
||||
- local: api/models/prior_transformer
|
||||
|
||||
19
docs/source/en/api/models/dit_transformer2d.md
Normal file
19
docs/source/en/api/models/dit_transformer2d.md
Normal file
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# DiTTransformer2D
|
||||
|
||||
A Transformer model for image-like data from [DiT](https://huggingface.co/papers/2212.09748).
|
||||
|
||||
## DiTTransformer2DModel
|
||||
|
||||
[[autodoc]] DiTTransformer2DModel
|
||||
19
docs/source/en/api/models/pixart_transformer2d.md
Normal file
19
docs/source/en/api/models/pixart_transformer2d.md
Normal file
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# PixArtTransformer2D
|
||||
|
||||
A Transformer model for image-like data from [PixArt-Alpha](https://huggingface.co/papers/2310.00426) and [PixArt-Sigma](https://huggingface.co/papers/2403.04692).
|
||||
|
||||
## PixArtTransformer2DModel
|
||||
|
||||
[[autodoc]] PixArtTransformer2DModel
|
||||
@@ -82,11 +82,13 @@ else:
|
||||
"ConsistencyDecoderVAE",
|
||||
"ControlNetModel",
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
"ModelMixin",
|
||||
"MotionAdapter",
|
||||
"MultiAdapter",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"StableCascadeUNet",
|
||||
"T2IAdapter",
|
||||
@@ -484,11 +486,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetModel,
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
ModelMixin,
|
||||
MotionAdapter,
|
||||
MultiAdapter,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
T2IAdapter,
|
||||
T5FilmDecoder,
|
||||
|
||||
@@ -706,3 +706,20 @@ def flax_register_to_config(cls):
|
||||
|
||||
cls.__init__ = init
|
||||
return cls
|
||||
|
||||
|
||||
class LegacyConfigMixin(ConfigMixin):
|
||||
r"""
|
||||
A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
|
||||
pipeline-specific classes (like `DiTTransformer2DModel`).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
||||
# To prevent depedency import problem.
|
||||
from .models.model_loading_utils import _fetch_remapped_cls_from_config
|
||||
|
||||
# resolve remapping
|
||||
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
||||
|
||||
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
|
||||
|
||||
@@ -36,6 +36,8 @@ if is_torch_available():
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
|
||||
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
|
||||
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
@@ -73,7 +75,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformers import (
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
@@ -32,6 +33,13 @@ from ..utils import (
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CLASS_REMAPPING_DICT = {
|
||||
"Transformer2DModel": {
|
||||
"ada_norm_zero": "DiTTransformer2DModel",
|
||||
"ada_norm_single": "PixArtTransformer2DModel",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import infer_auto_device_map
|
||||
@@ -61,6 +69,22 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
|
||||
return device_map
|
||||
|
||||
|
||||
def _fetch_remapped_cls_from_config(config, old_class):
|
||||
previous_class_name = old_class.__name__
|
||||
remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"])
|
||||
|
||||
# load diffusers library to import compatible and original scheduler
|
||||
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
||||
remapped_class = getattr(diffusers_library, remapped_class_name)
|
||||
logger.info(
|
||||
f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
|
||||
"This is because `previous_class_name` is scheduled to be deprecated in a future version. Note that this"
|
||||
" DOESN'T affect the final results."
|
||||
)
|
||||
|
||||
return remapped_class
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
||||
"""
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
|
||||
@@ -15,3 +15,17 @@ class AutoencoderKLOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: "torch.Tensor" # noqa: F821
|
||||
|
||||
@@ -42,7 +42,11 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
|
||||
from ..utils.hub_utils import (
|
||||
PushToHubMixin,
|
||||
load_or_create_model_card,
|
||||
populate_model_card,
|
||||
)
|
||||
from .model_loading_utils import (
|
||||
_determine_device_map,
|
||||
_load_state_dict_into_model,
|
||||
@@ -1039,3 +1043,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
del module.key
|
||||
del module.value
|
||||
del module.proj_attn
|
||||
|
||||
|
||||
class LegacyModelMixin(ModelMixin):
|
||||
r"""
|
||||
A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
|
||||
pipeline-specific classes (like `DiTTransformer2DModel`).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
# To prevent depedency import problem.
|
||||
from .model_loading_utils import _fetch_remapped_cls_from_config
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
# load config
|
||||
config, _, _ = cls.load_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
return_commit_hash=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
**kwargs,
|
||||
)
|
||||
# resolve remapping
|
||||
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
||||
|
||||
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
@@ -2,7 +2,9 @@ from ...utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .dit_transformer_2d import DiTTransformer2DModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .pixart_transformer_2d import PixArtTransformer2DModel
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
|
||||
240
src/diffusers/models/transformers/dit_transformer_2d.py
Normal file
240
src/diffusers/models/transformers/dit_transformer_2d.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
|
||||
in_channels (int, defaults to 4): The number of channels in the input.
|
||||
out_channels (int, optional):
|
||||
The number of channels in the output. Specify this parameter if the output channel number differs from the
|
||||
input.
|
||||
num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
|
||||
dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
|
||||
norm_num_groups (int, optional, defaults to 32):
|
||||
Number of groups for group normalization within Transformer blocks.
|
||||
attention_bias (bool, optional, defaults to True):
|
||||
Configure if the Transformer blocks' attention should contain a bias parameter.
|
||||
sample_size (int, defaults to 32):
|
||||
The width of the latent images. This parameter is fixed during training.
|
||||
patch_size (int, defaults to 2):
|
||||
Size of the patches the model processes, relevant for architectures working on non-sequential data.
|
||||
activation_fn (str, optional, defaults to "gelu-approximate"):
|
||||
Activation function to use in feed-forward networks within Transformer blocks.
|
||||
num_embeds_ada_norm (int, optional, defaults to 1000):
|
||||
Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
|
||||
inference.
|
||||
upcast_attention (bool, optional, defaults to False):
|
||||
If true, upcasts the attention mechanism dimensions for potentially improved performance.
|
||||
norm_type (str, optional, defaults to "ada_norm_zero"):
|
||||
Specifies the type of normalization used, can be 'ada_norm_zero'.
|
||||
norm_elementwise_affine (bool, optional, defaults to False):
|
||||
If true, enables element-wise affine parameters in the normalization layers.
|
||||
norm_eps (float, optional, defaults to 1e-5):
|
||||
A small constant added to the denominator in normalization layers to prevent division by zero.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 72,
|
||||
in_channels: int = 4,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 28,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
attention_bias: bool = True,
|
||||
sample_size: int = 32,
|
||||
patch_size: int = 2,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: Optional[int] = 1000,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "ada_norm_zero",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Validate inputs.
|
||||
if norm_type != "ada_norm_zero":
|
||||
raise NotImplementedError(
|
||||
f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
|
||||
)
|
||||
elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
|
||||
)
|
||||
|
||||
# Set some common variables used across the board.
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# 2. Initialize the position embedding and transformer blocks.
|
||||
self.height = self.config.sample_size
|
||||
self.width = self.config.sample_size
|
||||
|
||||
self.patch_size = self.config.patch_size
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=self.config.sample_size,
|
||||
width=self.config.sample_size,
|
||||
patch_size=self.config.patch_size,
|
||||
in_channels=self.config.in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
||||
norm_eps=self.config.norm_eps,
|
||||
)
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Output blocks.
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(
|
||||
self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`DiTTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
336
src/diffusers/models/transformers/pixart_transformer_2d.py
Normal file
336
src/diffusers/models/transformers/pixart_transformer_2d.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
|
||||
https://arxiv.org/abs/2403.04692).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
|
||||
in_channels (int, defaults to 4): The number of channels in the input.
|
||||
out_channels (int, optional):
|
||||
The number of channels in the output. Specify this parameter if the output channel number differs from the
|
||||
input.
|
||||
num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
|
||||
dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
|
||||
norm_num_groups (int, optional, defaults to 32):
|
||||
Number of groups for group normalization within Transformer blocks.
|
||||
cross_attention_dim (int, optional):
|
||||
The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
|
||||
attention_bias (bool, optional, defaults to True):
|
||||
Configure if the Transformer blocks' attention should contain a bias parameter.
|
||||
sample_size (int, defaults to 128):
|
||||
The width of the latent images. This parameter is fixed during training.
|
||||
patch_size (int, defaults to 2):
|
||||
Size of the patches the model processes, relevant for architectures working on non-sequential data.
|
||||
activation_fn (str, optional, defaults to "gelu-approximate"):
|
||||
Activation function to use in feed-forward networks within Transformer blocks.
|
||||
num_embeds_ada_norm (int, optional, defaults to 1000):
|
||||
Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
|
||||
inference.
|
||||
upcast_attention (bool, optional, defaults to False):
|
||||
If true, upcasts the attention mechanism dimensions for potentially improved performance.
|
||||
norm_type (str, optional, defaults to "ada_norm_zero"):
|
||||
Specifies the type of normalization used, can be 'ada_norm_zero'.
|
||||
norm_elementwise_affine (bool, optional, defaults to False):
|
||||
If true, enables element-wise affine parameters in the normalization layers.
|
||||
norm_eps (float, optional, defaults to 1e-6):
|
||||
A small constant added to the denominator in normalization layers to prevent division by zero.
|
||||
interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
|
||||
use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
|
||||
attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
|
||||
caption_channels (int, optional, defaults to None):
|
||||
Number of channels to use for projecting the caption embeddings.
|
||||
use_linear_projection (bool, optional, defaults to False):
|
||||
Deprecated argument. Will be removed in a future version.
|
||||
num_vector_embeds (bool, optional, defaults to False):
|
||||
Deprecated argument. Will be removed in a future version.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 72,
|
||||
in_channels: int = 4,
|
||||
out_channels: Optional[int] = 8,
|
||||
num_layers: int = 28,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = 1152,
|
||||
attention_bias: bool = True,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: Optional[int] = 1000,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "ada_norm_single",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
use_additional_conditions: Optional[bool] = None,
|
||||
caption_channels: Optional[int] = None,
|
||||
attention_type: Optional[str] = "default",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Validate inputs.
|
||||
if norm_type != "ada_norm_single":
|
||||
raise NotImplementedError(
|
||||
f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
|
||||
)
|
||||
elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
|
||||
)
|
||||
|
||||
# Set some common variables used across the board.
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
if use_additional_conditions is None:
|
||||
if sample_size == 128:
|
||||
use_additional_conditions = True
|
||||
else:
|
||||
use_additional_conditions = False
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# 2. Initialize the position embedding and transformer blocks.
|
||||
self.height = self.config.sample_size
|
||||
self.width = self.config.sample_size
|
||||
|
||||
interpolation_scale = (
|
||||
self.config.interpolation_scale
|
||||
if self.config.interpolation_scale is not None
|
||||
else max(self.config.sample_size // 64, 1)
|
||||
)
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=self.config.sample_size,
|
||||
width=self.config.sample_size,
|
||||
patch_size=self.config.patch_size,
|
||||
in_channels=self.config.in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
cross_attention_dim=self.config.cross_attention_dim,
|
||||
activation_fn=self.config.activation_fn,
|
||||
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
||||
norm_eps=self.config.norm_eps,
|
||||
attention_type=self.config.attention_type,
|
||||
)
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Output blocks.
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=self.use_additional_conditions
|
||||
)
|
||||
self.caption_projection = None
|
||||
if self.config.caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.config.caption_channels, hidden_size=self.inner_dim
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`PixArtTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep (`torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size = hidden_states.shape[0]
|
||||
height, width = (
|
||||
hidden_states.shape[-2] // self.config.patch_size,
|
||||
hidden_states.shape[-1] // self.config.patch_size,
|
||||
)
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
if self.caption_projection is not None:
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
None,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=None,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
shift, scale = (
|
||||
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
|
||||
).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -11,39 +11,30 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, deprecate, is_torch_version, logging
|
||||
from ...configuration_utils import LegacyConfigMixin, register_to_config
|
||||
from ...utils import deprecate, is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import LegacyModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.Tensor
|
||||
class Transformer2DModelOutput(Transformer2DModelOutput):
|
||||
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
|
||||
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
"""
|
||||
A 2D Transformer model for image-like data.
|
||||
|
||||
@@ -116,40 +107,12 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
|
||||
)
|
||||
|
||||
# Set some common variables used across the board.
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.interpolation_scale = interpolation_scale
|
||||
self.caption_channels = caption_channels
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.gradient_checkpointing = False
|
||||
if use_additional_conditions is None:
|
||||
if norm_type == "ada_norm_single" and sample_size == 128:
|
||||
use_additional_conditions = True
|
||||
else:
|
||||
use_additional_conditions = False
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
self.is_input_patches = in_channels is not None and patch_size is not None
|
||||
|
||||
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
||||
" incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
|
||||
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
||||
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
||||
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
||||
)
|
||||
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
||||
norm_type = "ada_norm"
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
@@ -166,6 +129,35 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
||||
)
|
||||
|
||||
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
||||
" incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
|
||||
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
||||
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
||||
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
||||
)
|
||||
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
||||
norm_type = "ada_norm"
|
||||
|
||||
# Set some common variables used across the board.
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.interpolation_scale = interpolation_scale
|
||||
self.caption_channels = caption_channels
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
if use_additional_conditions is None:
|
||||
if norm_type == "ada_norm_single" and sample_size == 128:
|
||||
use_additional_conditions = True
|
||||
else:
|
||||
use_additional_conditions = False
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
# 2. Initialize the right blocks.
|
||||
# These functions follow a common structure:
|
||||
# a. Initialize the input blocks. b. Initialize the transformer blocks.
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import AutoencoderKL, Transformer2DModel
|
||||
from ...models import AutoencoderKL, DiTTransformer2DModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
@@ -36,8 +36,8 @@ class DiTPipeline(DiffusionPipeline):
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Parameters:
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A class conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
transformer ([`DiTTransformer2DModel`]):
|
||||
A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
scheduler ([`DDIMScheduler`]):
|
||||
@@ -48,7 +48,7 @@ class DiTPipeline(DiffusionPipeline):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: Transformer2DModel,
|
||||
transformer: DiTTransformer2DModel,
|
||||
vae: AutoencoderKL,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
id2label: Optional[Dict[int, str]] = None,
|
||||
|
||||
@@ -608,6 +608,7 @@ def load_sub_model(
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
|
||||
# retrieve class candidates
|
||||
|
||||
class_obj, class_candidates = get_class_obj_and_candidates(
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...image_processor import PixArtImageProcessor
|
||||
from ...models import AutoencoderKL, Transformer2DModel
|
||||
from ...models import AutoencoderKL, PixArtTransformer2DModel
|
||||
from ...schedulers import DPMSolverMultistepScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
@@ -246,8 +246,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
transformer ([`PixArtTransformer2DModel`]):
|
||||
A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
@@ -276,7 +276,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKL,
|
||||
transformer: Transformer2DModel,
|
||||
transformer: PixArtTransformer2DModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...image_processor import PixArtImageProcessor
|
||||
from ...models import AutoencoderKL, Transformer2DModel
|
||||
from ...models import AutoencoderKL, PixArtTransformer2DModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
@@ -202,7 +202,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKL,
|
||||
transformer: Transformer2DModel,
|
||||
transformer: PixArtTransformer2DModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -107,6 +107,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class DiTTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class I2VGenXLUNet(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -182,6 +197,21 @@ class MultiAdapter(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PixArtTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PriorTransformer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -559,7 +559,7 @@ class ModelTesterMixin:
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, expected_max_diff)
|
||||
|
||||
def test_output(self):
|
||||
def test_output(self, expected_output_shape=None):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
@@ -575,8 +575,12 @@ class ModelTesterMixin:
|
||||
|
||||
# input & output have to have the same shape
|
||||
input_tensor = inputs_dict[self.main_input_name]
|
||||
expected_shape = input_tensor.shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
if expected_output_shape is None:
|
||||
expected_shape = input_tensor.shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
else:
|
||||
self.assertEqual(output.shape, expected_output_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_from_pretrained(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
95
tests/models/transformers/test_models_dit_transformer2d.py
Normal file
95
tests/models/transformers/test_models_dit_transformer2d.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import DiTTransformer2DModel, Transformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class DiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = DiTTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
in_channels = 4
|
||||
sample_size = 8
|
||||
scheduler_num_train_steps = 1000
|
||||
num_class_labels = 4
|
||||
|
||||
hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device)
|
||||
timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device)
|
||||
class_label_ids = torch.randint(0, num_class_labels, size=(batch_size,)).to(torch_device)
|
||||
|
||||
return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_label_ids}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (8, 8, 8)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 8,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 4,
|
||||
"attention_bias": True,
|
||||
"num_layers": 1,
|
||||
"norm_type": "ada_norm_zero",
|
||||
"num_embeds_ada_norm": 8,
|
||||
"patch_size": 2,
|
||||
"sample_size": 8,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_output(self):
|
||||
super().test_output(
|
||||
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
|
||||
)
|
||||
|
||||
def test_correct_class_remapping_from_dict_config(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = Transformer2DModel.from_config(init_dict)
|
||||
assert isinstance(model, DiTTransformer2DModel)
|
||||
|
||||
def test_correct_class_remapping_from_pretrained_config(self):
|
||||
config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
|
||||
model = Transformer2DModel.from_config(config)
|
||||
assert isinstance(model, DiTTransformer2DModel)
|
||||
|
||||
@slow
|
||||
def test_correct_class_remapping(self):
|
||||
model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer")
|
||||
assert isinstance(model, DiTTransformer2DModel)
|
||||
108
tests/models/transformers/test_models_pixart_transformer2d.py
Normal file
108
tests/models/transformers/test_models_pixart_transformer2d.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import PixArtTransformer2DModel, Transformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class PixArtTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = PixArtTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
in_channels = 4
|
||||
sample_size = 8
|
||||
scheduler_num_train_steps = 1000
|
||||
cross_attention_dim = 8
|
||||
seq_len = 8
|
||||
|
||||
hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device)
|
||||
timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, seq_len, cross_attention_dim)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timesteps,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_cond_kwargs": {"aspect_ratio": None, "resolution": None},
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (8, 8, 8)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 8,
|
||||
"num_layers": 1,
|
||||
"patch_size": 2,
|
||||
"attention_head_dim": 2,
|
||||
"num_attention_heads": 2,
|
||||
"in_channels": 4,
|
||||
"cross_attention_dim": 8,
|
||||
"out_channels": 8,
|
||||
"attention_bias": True,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"num_embeds_ada_norm": 8,
|
||||
"norm_type": "ada_norm_single",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"use_additional_conditions": False,
|
||||
"caption_channels": None,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_output(self):
|
||||
super().test_output(
|
||||
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
|
||||
)
|
||||
|
||||
def test_correct_class_remapping_from_dict_config(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = Transformer2DModel.from_config(init_dict)
|
||||
assert isinstance(model, PixArtTransformer2DModel)
|
||||
|
||||
def test_correct_class_remapping_from_pretrained_config(self):
|
||||
config = PixArtTransformer2DModel.load_config("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
|
||||
model = Transformer2DModel.from_config(config)
|
||||
assert isinstance(model, PixArtTransformer2DModel)
|
||||
|
||||
@slow
|
||||
def test_correct_class_remapping(self):
|
||||
model = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
|
||||
assert isinstance(model, PixArtTransformer2DModel)
|
||||
@@ -19,7 +19,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler
|
||||
from diffusers.utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
|
||||
|
||||
@@ -46,7 +46,7 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = Transformer2DModel(
|
||||
transformer = DiTTransformer2DModel(
|
||||
sample_size=16,
|
||||
num_layers=2,
|
||||
patch_size=4,
|
||||
|
||||
@@ -25,7 +25,7 @@ from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
PixArtAlphaPipeline,
|
||||
Transformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
@@ -53,7 +53,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = Transformer2DModel(
|
||||
transformer = PixArtTransformer2DModel(
|
||||
sample_size=8,
|
||||
num_layers=2,
|
||||
patch_size=2,
|
||||
|
||||
@@ -25,7 +25,7 @@ from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
PixArtSigmaPipeline,
|
||||
Transformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
@@ -53,7 +53,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = Transformer2DModel(
|
||||
transformer = PixArtTransformer2DModel(
|
||||
sample_size=8,
|
||||
num_layers=2,
|
||||
patch_size=2,
|
||||
@@ -344,7 +344,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_pixart_512(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
transformer = Transformer2DModel.from_pretrained(
|
||||
transformer = PixArtTransformer2DModel.from_pretrained(
|
||||
self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = PixArtSigmaPipeline.from_pretrained(
|
||||
@@ -399,7 +399,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_pixart_512_without_resolution_binning(self):
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
transformer = Transformer2DModel.from_pretrained(
|
||||
transformer = PixArtTransformer2DModel.from_pretrained(
|
||||
self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = PixArtSigmaPipeline.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user