Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[refactor] SD3 docs & remove additional code (#10882)
* update
* update
* update
@@ -1410,7 +1410,7 @@ class JointAttnProcessor2_0:

     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(
         self,
@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):


 class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+    Parameters:
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        extra_conditioning_channels (`int`, defaults to `0`):
+            The number of extra channels to use for conditioning for patch embedding.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+        pos_embed_type (`str`, defaults to `"sincos"`):
+            The type of positional embedding to use. Choose between `"sincos"` and `None`.
+        use_pos_embed (`bool`, defaults to `True`):
+            Whether to use positional embeddings.
+        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+            Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
+            config value of the ControlNet model.
+    """
+
     _supports_gradient_checkpointing = True

     @register_to_config
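For orientation, here is a minimal usage sketch of this ControlNet class through the public pipeline API; the checkpoint names, the `control_image` input, and the `controlnet_conditioning_scale` argument are assumptions and do not appear in this diff.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Assumed checkpoint names; any SD3 ControlNet / base-model pair with matching configs should work.
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image("canny_edges.png")  # placeholder path: a pre-computed Canny edge map

image = pipe(
    prompt="a photo of a futuristic city at dusk",
    control_image=control_image,
    controlnet_conditioning_scale=0.7,  # forwarded to `conditioning_scale` in SD3ControlNetModel.forward
).images[0]
image.save("sd3_controlnet_out.png")
```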
@@ -93,7 +135,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=False,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                 SD3SingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                 )
                 for _ in range(num_layers)
             ]
@@ -297,28 +339,28 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.

         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
             controlnet_cond (`torch.Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
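A schematic, plain-PyTorch illustration of what `conditioning_scale` means (this is not the library code): the ControlNet emits one residual per transformer block, and each residual is scaled by this factor before being added back into the base model's hidden states.

```python
import torch

# Shapes are arbitrary; only the scaling/injection pattern matters here.
batch, seq_len, dim = 2, 16, 64
block_samples = [torch.randn(batch, seq_len, dim) for _ in range(4)]  # per-block ControlNet outputs
conditioning_scale = 0.7

# Scale every per-block residual by the single float.
scaled = [sample * conditioning_scale for sample in block_samples]

hidden_states = torch.randn(batch, seq_len, dim)
for residual in scaled:
    hidden_states = hidden_states + residual  # injected block by block in the base transformer
```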
@@ -437,11 +479,11 @@ class SD3MultiControlNetModel(ModelMixin):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: List[torch.tensor],
         conditioning_scale: List[float],
-        pooled_projections: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
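For reference, a sketch of how several SD3 ControlNets are wrapped so that `controlnet_cond` and `conditioning_scale` become per-net lists; the tiny configuration values below are assumptions chosen only to make the example cheap to instantiate.

```python
from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel

# Illustration-only config: the parameter names come from the docstring above,
# but the values are deliberately tiny (inner_dim = 4 * 8 = 32 = caption_projection_dim).
cn_kwargs = dict(
    sample_size=16, patch_size=2, in_channels=4, num_layers=1,
    attention_head_dim=8, num_attention_heads=4,
    joint_attention_dim=32, caption_projection_dim=32,
    pooled_projection_dim=16, out_channels=4, pos_embed_max_size=16,
)
multi = SD3MultiControlNetModel([SD3ControlNetModel(**cn_kwargs), SD3ControlNetModel(**cn_kwargs)])

# With the wrapper, `controlnet_cond` and `conditioning_scale` are lists with one
# entry per wrapped ControlNet, matching the signature shown in the hunk above.
```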
@@ -15,7 +15,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
@@ -39,17 +38,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name

 @maybe_allow_in_graph
 class SD3SingleTransformerBlock(nn.Module):
-    r"""
-    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-    """
-
     def __init__(
         self,
         dim: int,
@@ -59,21 +47,13 @@ class SD3SingleTransformerBlock(nn.Module):
         super().__init__()

         self.norm1 = AdaLayerNormZero(dim)

-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
-
         self.attn = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=processor,
+            processor=JointAttnProcessor2_0(),
             eps=1e-6,
         )
@@ -81,23 +61,17 @@ class SD3SingleTransformerBlock(nn.Module):
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

     def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        # 1. Attention
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        # Attention.
-        attn_output = self.attn(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=None,
-        )
-
-        # Process attention outputs for the `hidden_states`.
+        attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output

+        # 2. Feed Forward
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
-
         hidden_states = hidden_states + ff_output

         return hidden_states
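The refactored forward above follows the adaLN-Zero recipe: a timestep-conditioned norm emits shift/scale/gate tensors for the attention and feed-forward branches. Below is a self-contained sketch of the same gating pattern, with a plain `LayerNorm` plus a modulation linear standing in for diffusers' `AdaLayerNormZero` and a linear layer standing in for attention; it is an illustration, not the diffusers implementation.

```python
import torch
import torch.nn as nn


class TinyAdaLNZeroBlock(nn.Module):
    """Simplified sketch of the gating pattern in SD3SingleTransformerBlock.forward."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_modulation = nn.Linear(dim, 6 * dim)  # shift/scale/gate for both branches
        self.attn = nn.Linear(dim, dim)               # stand-in for self-attention
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.to_modulation(temb).chunk(6, dim=-1)

        # 1. Attention branch: modulate the normed states, then gate the residual add
        norm_hidden_states = self.norm1(hidden_states) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        hidden_states = hidden_states + gate_msa.unsqueeze(1) * self.attn(norm_hidden_states)

        # 2. Feed-forward branch: same shift/scale/gate recipe
        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * self.ff(norm_hidden_states)
        return hidden_states


block = TinyAdaLNZeroBlock(dim=32)
out = block(torch.randn(2, 16, 32), torch.randn(2, 32))  # (batch, seq, dim), (batch, dim)
print(out.shape)  # torch.Size([2, 16, 32])
```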
@@ -107,26 +81,40 @@ class SD3Transformer2DModel(
     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
 ):
     """
-    The Transformer model introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
+    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

     Parameters:
-        sample_size (`int`): The width of the latent images. This is fixed during training since
-            it is used to learn a number of position embeddings.
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        out_channels (`int`, defaults to 16): Number of output channels.
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
     """

     _supports_gradient_checkpointing = True
     _no_split_modules = ["JointTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

     @register_to_config
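In practice the documented defaults are overridden by whatever config ships with a checkpoint. A hedged loading sketch follows; the repo id and the `transformer` subfolder are assumptions about the published SD3 checkpoint layout rather than something stated in this diff.

```python
import torch
from diffusers import SD3Transformer2DModel

# Assumed (gated) checkpoint; any SD3-family repo with a `transformer/` subfolder should load the same way.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
)

# Config values come from the checkpoint, not from the class defaults documented above.
print(transformer.config.num_layers, transformer.config.joint_attention_dim)
```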
@@ -149,36 +137,33 @@ class SD3Transformer2DModel(
         qk_norm: Optional[str] = None,
     ):
         super().__init__()
-        default_out_channels = in_channels
-        self.out_channels = out_channels if out_channels is not None else default_out_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = out_channels if out_channels is not None else in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim

         self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
             embed_dim=self.inner_dim,
             pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
         )
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
         self.transformer_blocks = nn.ModuleList(
             [
                 JointTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=i == num_layers - 1,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
                 )
-                for i in range(self.config.num_layers)
+                for i in range(num_layers)
             ]
         )
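The switch from `self.config.*` to the plain constructor arguments is behavior-preserving because `@register_to_config` records the `__init__` arguments on the model config before the method body runs, so both spellings refer to the same values. A toy sketch of that mechanism (the `ToyModel` class is hypothetical, not part of diffusers):

```python
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, hidden_dim: int = 8, num_layers: int = 2):
        super().__init__()
        # Inside a decorated __init__, the local argument and the registered config
        # value are interchangeable: `hidden_dim` == `self.config.hidden_dim` here.
        self.layers = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers))


model = ToyModel(hidden_dim=16)
print(model.config.hidden_dim, model.config.num_layers)  # 16 2
```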
@@ -331,24 +316,24 @@ class SD3Transformer2DModel(

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         skip_layers: Optional[List[int]] = None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.

         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                 Embeddings projected from the embeddings of input conditions.
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
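A small smoke-test sketch of the forward signature documented above. The shapes follow the Args section; the tiny configuration values are assumptions chosen only so the example runs quickly on CPU, not the real SD3 defaults.

```python
import torch
from diffusers import SD3Transformer2DModel

# Illustration-only configuration (real defaults: sample_size=128, num_layers=18, joint_attention_dim=4096).
model = SD3Transformer2DModel(
    sample_size=16,
    patch_size=2,
    in_channels=4,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,       # inner_dim = 4 * 8 = 32
    joint_attention_dim=32,
    caption_projection_dim=32,   # must match inner_dim so the joint blocks line up
    pooled_projection_dim=16,
    out_channels=4,
    pos_embed_max_size=16,
)

batch = 2
out = model(
    hidden_states=torch.randn(batch, 4, 16, 16),       # (batch, channel, height, width)
    encoder_hidden_states=torch.randn(batch, 77, 32),  # (batch, sequence_len, embed_dims)
    pooled_projections=torch.randn(batch, 16),         # (batch, projection_dim)
    timestep=torch.tensor([10, 10]),                   # one denoising step index per batch element
)
print(out.sample.shape)  # torch.Size([2, 4, 16, 16])
```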