mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[refactor] SD3 docs & remove additional code (#10882)

* update

* update

* update
Authored by Aryan on 2025-02-25 03:08:47 +05:30, committed by GitHub
parent 87599691b9
commit 13f20c7fe8
3 changed files with 107 additions and 80 deletions

View File

@@ -1410,7 +1410,7 @@ class JointAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,

View File

@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
Parameters:
sample_size (`int`, defaults to `128`):
The width/height of the latents. This is fixed during training since it is used to learn a number of
position embeddings.
patch_size (`int`, defaults to `2`):
Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `16`):
The number of latent channels in the input.
num_layers (`int`, defaults to `18`):
The number of layers of transformer blocks to use.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `18`):
The number of heads to use for multi-head attention.
joint_attention_dim (`int`, defaults to `4096`):
The embedding dimension to use for joint text-image attention.
caption_projection_dim (`int`, defaults to `1152`):
The embedding dimension of caption embeddings.
pooled_projection_dim (`int`, defaults to `2048`):
The embedding dimension of pooled text projections.
out_channels (`int`, defaults to `16`):
The number of latent channels in the output.
pos_embed_max_size (`int`, defaults to `96`):
The maximum latent height/width of positional embeddings.
extra_conditioning_channels (`int`, defaults to `0`):
The number of extra channels to use for conditioning for patch embedding.
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
The number of dual-stream transformer blocks to use.
qk_norm (`str`, *optional*, defaults to `None`):
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
pos_embed_type (`str`, defaults to `"sincos"`):
The type of positional embedding to use. Choose between `"sincos"` and `None`.
use_pos_embed (`bool`, defaults to `True`):
Whether to use positional embeddings.
force_zeros_for_pooled_projection (`bool`, defaults to `True`):
Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
config value of the ControlNet model.
"""
_supports_gradient_checkpointing = True
@register_to_config
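The parameter list above doubles as the model's default configuration. A minimal construction sketch, assuming only that `SD3ControlNetModel` is importable from the top-level `diffusers` namespace; this builds a randomly initialized model of the documented default size, whereas real checkpoints would instead be loaded with `from_pretrained`:

```python
from diffusers import SD3ControlNetModel

# Randomly initialized ControlNet using the defaults documented above.
controlnet = SD3ControlNetModel(
    sample_size=128,
    patch_size=2,
    in_channels=16,
    num_layers=18,
    attention_head_dim=64,
    num_attention_heads=18,
    joint_attention_dim=4096,
    caption_projection_dim=1152,
    pooled_projection_dim=2048,
    out_channels=16,
    pos_embed_max_size=96,
    extra_conditioning_channels=0,
)
```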
@@ -93,7 +135,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
attention_head_dim=attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
SD3SingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
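Both hunks above swap `self.config.attention_head_dim` for the plain constructor argument, and the `__init__` refactor further down does the same for the transformer. With the `@register_to_config` decorator the two spellings are interchangeable inside `__init__`, because the decorator records the arguments into the config before the method body runs. A toy sketch of that behavior (the class below is hypothetical, only to illustrate the decorator):

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyBlockConfig(ConfigMixin):
    # Hypothetical class, used only to show the decorator's behavior.
    config_name = "toy_config.json"

    @register_to_config
    def __init__(self, attention_head_dim: int = 64, num_attention_heads: int = 18):
        # The decorator has already stored the arguments in `self.config`,
        # so the local name and the config attribute hold the same value.
        assert attention_head_dim == self.config.attention_head_dim
        assert num_attention_heads == self.config.num_attention_heads


ToyBlockConfig()  # passes: both access paths agree
```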
@@ -297,28 +339,28 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
def forward(
self,
hidden_states: torch.FloatTensor,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
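As a shape-level illustration of this forward signature, the sketch below builds a deliberately tiny, randomly initialized ControlNet (non-default sizes, chosen only so the example is cheap to run) and calls it with dummy tensors; the returned `controlnet_block_samples` are the per-block residuals consumed by the main transformer:

```python
import torch

from diffusers import SD3ControlNetModel

# Deliberately tiny config (illustrative only; real checkpoints use the defaults above).
controlnet = SD3ControlNetModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    pos_embed_max_size=32,
)

batch, seq_len = 1, 16
output = controlnet(
    hidden_states=torch.randn(batch, 16, 32, 32),        # latents: (batch, channel, height, width)
    controlnet_cond=torch.randn(batch, 16, 32, 32),       # latent-space conditioning input (4D for the default PatchEmbed path)
    conditioning_scale=0.7,
    encoder_hidden_states=torch.randn(batch, seq_len, 32),
    pooled_projections=torch.randn(batch, 64),
    timestep=torch.randint(0, 1000, (batch,)),
)
block_residuals = output.controlnet_block_samples         # one tensor per ControlNet block
```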
@@ -437,11 +479,11 @@ class SD3MultiControlNetModel(ModelMixin):
def forward(
self,
hidden_states: torch.FloatTensor,
hidden_states: torch.Tensor,
controlnet_cond: List[torch.tensor],
conditioning_scale: List[float],
pooled_projections: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
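`SD3MultiControlNetModel` receives one conditioning tensor and one scale per wrapped ControlNet, which is why both arguments become lists here; conceptually it calls each wrapped model and sums the resulting block residuals. A rough standalone sketch of that accumulation logic (the idea only, not the actual implementation):

```python
from typing import Callable, List, Sequence

import torch


def accumulate_controlnet_residuals(
    controlnets: Sequence[Callable[..., List[torch.Tensor]]],
    conds: List[torch.Tensor],
    scales: List[float],
) -> List[torch.Tensor]:
    """Call every ControlNet and sum the per-block residuals (sketch)."""
    total: List[torch.Tensor] = []
    for i, (net, cond, scale) in enumerate(zip(controlnets, conds, scales)):
        block_samples = net(controlnet_cond=cond, conditioning_scale=scale)
        total = list(block_samples) if i == 0 else [a + b for a, b in zip(total, block_samples)]
    return total
```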

View File

@@ -15,7 +15,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
@@ -39,17 +38,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
r"""
A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
"""
def __init__(
self,
dim: int,
@@ -59,21 +47,13 @@ class SD3SingleTransformerBlock(nn.Module):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
processor=JointAttnProcessor2_0(),
eps=1e-6,
)
@@ -81,23 +61,17 @@ class SD3SingleTransformerBlock(nn.Module):
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
# 1. Attention
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
)
# Process attention outputs for the `hidden_states`.
attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
# 2. Feed Forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
return hidden_states
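The simplified forward above keeps the adaLN-Zero modulation scheme (introduced in DiT and used throughout the MMDiT blocks): the conditioning embedding `temb` yields per-sample shift, scale, and gate vectors, the normalized activations are modulated by shift and scale, and each sub-layer's output is gated before the residual add. A stripped-down, self-contained sketch of that pattern for a single sub-layer (toy shapes; the real `AdaLayerNormZero` produces six vectors, covering both the attention and feed-forward paths):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, tokens, dim = 2, 8, 16
hidden_states = torch.randn(batch, tokens, dim)
temb = torch.randn(batch, dim)  # combined timestep/text conditioning embedding

# adaLN-Zero: SiLU + Linear on the conditioning embedding produces the modulation vectors.
shift, scale, gate = nn.Linear(dim, 3 * dim)(F.silu(temb)).chunk(3, dim=-1)

norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
modulated = norm(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

sublayer_out = modulated  # stands in for the attention or feed-forward output
hidden_states = hidden_states + gate.unsqueeze(1) * sublayer_out
```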
@@ -107,26 +81,40 @@ class SD3Transformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
"""
The Transformer model introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
Parameters:
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
out_channels (`int`, defaults to 16): Number of output channels.
sample_size (`int`, defaults to `128`):
The width/height of the latents. This is fixed during training since it is used to learn a number of
position embeddings.
patch_size (`int`, defaults to `2`):
Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `16`):
The number of latent channels in the input.
num_layers (`int`, defaults to `18`):
The number of layers of transformer blocks to use.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `18`):
The number of heads to use for multi-head attention.
joint_attention_dim (`int`, defaults to `4096`):
The embedding dimension to use for joint text-image attention.
caption_projection_dim (`int`, defaults to `1152`):
The embedding dimension of caption embeddings.
pooled_projection_dim (`int`, defaults to `2048`):
The embedding dimension of pooled text projections.
out_channels (`int`, defaults to `16`):
The number of latent channels in the output.
pos_embed_max_size (`int`, defaults to `96`):
The maximum latent height/width of positional embeddings.
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
The number of dual-stream transformer blocks to use.
qk_norm (`str`, *optional*, defaults to `None`):
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["JointTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
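In practice this transformer is loaded from a pretrained SD3 checkpoint rather than built from scratch; a minimal loading sketch (the repo id below is the public SD3-medium checkpoint, used here as an assumption, and is gated behind a license acceptance on the Hub):

```python
import torch

from diffusers import SD3Transformer2DModel

# Assumed checkpoint id; the transformer weights live in the "transformer" subfolder.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
)
```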
@@ -149,36 +137,33 @@ class SD3Transformer2DModel(
qk_norm: Optional[str] = None,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.out_channels = out_channels if out_channels is not None else in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = PatchEmbed(
height=self.config.sample_size,
width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
# `attention_head_dim` is doubled to account for the mixing.
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
context_pre_only=i == num_layers - 1,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(self.config.num_layers)
for i in range(num_layers)
]
)
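A small consistency note on the refactored `__init__`: with the documented defaults the inner dimension equals the caption projection width, which is what lets `context_embedder` feed the text tokens straight into the joint blocks. The arithmetic, using the default values from the docstring:

```python
num_attention_heads = 18   # default from the docstring
attention_head_dim = 64    # default from the docstring

inner_dim = num_attention_heads * attention_head_dim  # 18 * 64 = 1152
assert inner_dim == 1152                              # matches caption_projection_dim
```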
@@ -331,24 +316,24 @@ class SD3Transformer2DModel(
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
skip_layers: Optional[List[int]] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
Embeddings projected from the embeddings of input conditions.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
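To make the documented shapes concrete, the sketch below runs this forward pass on a deliberately tiny, randomly initialized transformer (non-default sizes, chosen only so the example runs quickly); the dummy tensors follow the documented argument shapes, and the output latents keep the input spatial size:

```python
import torch

from diffusers import SD3Transformer2DModel

# Deliberately tiny config (illustrative only; real SD3 checkpoints use the defaults above).
transformer = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    out_channels=16,
    pos_embed_max_size=32,
)

batch, seq_len = 2, 16
output = transformer(
    hidden_states=torch.randn(batch, 16, 32, 32),           # (batch, channel, height, width)
    encoder_hidden_states=torch.randn(batch, seq_len, 32),   # (batch, sequence_len, embed_dims)
    pooled_projections=torch.randn(batch, 64),               # (batch, projection_dim)
    timestep=torch.randint(0, 1000, (batch,)),
    skip_layers=[1],  # e.g. for skip-layer guidance: block indices listed here are bypassed
)
print(output.sample.shape)  # torch.Size([2, 16, 32, 32])
```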