mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
minor doc/test update (#9734)
* update some docs and tests!

---------

Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
@@ -54,6 +54,11 @@ image = pipe(
image.save("sd3_hello_world.png")
```

**Note:** Stable Diffusion 3.5 can also be run using the SD3 pipeline, and all of the optimizations and techniques mentioned here apply to it as well (see the example after the model list below). In total there are three official models in the SD3 family:

- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
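For example, a minimal sketch of running Stable Diffusion 3.5 Large through the same `StableDiffusion3Pipeline` class; the prompt and sampler settings below are illustrative placeholders rather than tuned recommendations:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Load one of the SD3.5 checkpoints listed above with the regular SD3 pipeline.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=4.5,
).images[0]
image.save("sd3_5_hello_world.png")
```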
## Memory Optimizations for SD3

SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low-resource hardware.
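For instance, two commonly used options are loading the pipeline without the T5-XXL encoder and enabling model CPU offload. A minimal sketch combining both (the generation settings are illustrative, and prompt adherence may degrade slightly without T5):

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Skip loading T5-XXL entirely and offload submodules to the CPU between forward passes.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_hello_world_no_t5.png")
```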
@@ -16,10 +16,9 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="fp16")
parser.add_argument("--dtype", type=str)

args = parser.parse_args()
dtype = torch.float16 if args.dtype == "fp16" else torch.float32


def load_original_checkpoint(ckpt_path):
@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
    return new_weight


def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
def convert_sd3_transformer_checkpoint_to_diffusers(
    original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
):
    converted_state_dict = {}

    # Positional and patch embeddings.
@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])

        # qk norm
        if has_qk_norm:
            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
            )

        # output projections.
        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
            f"joint_blocks.{i}.context_block.attn.proj.bias"
        )

        # attn2
        if i in dual_attention_layers:
            # Q, K, V
            sample_q2, sample_k2, sample_v2 = torch.chunk(
                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
            )
            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
                original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
            )
            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])

            # qk norm
            if has_qk_norm:
                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
                )
                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
                )

            # output projections.
            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.x_block.attn2.proj.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
                f"joint_blocks.{i}.x_block.attn2.proj.bias"
            )

        # norms.
        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
    )


def get_attn2_layers(state_dict):
    attn2_layers = []
    for key in state_dict.keys():
        if "attn2." in key:
            # Extract the layer number from the key
            layer_num = int(key.split(".")[1])
            attn2_layers.append(layer_num)
    return tuple(sorted(set(attn2_layers)))


def get_pos_embed_max_size(state_dict):
    num_patches = state_dict["pos_embed"].shape[1]
    pos_embed_max_size = int(num_patches**0.5)
    return pos_embed_max_size


def get_caption_projection_dim(state_dict):
    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
    return caption_projection_dim

def main(args):
    original_ckpt = load_original_checkpoint(args.checkpoint_path)
    original_dtype = next(iter(original_ckpt.values())).dtype

    # Initialize dtype with a default value
    dtype = None

    if args.dtype is None:
        dtype = original_dtype
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    if dtype != original_dtype:
        print(
            f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
        )

    num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401
    caption_projection_dim = 1536

    caption_projection_dim = get_caption_projection_dim(original_ckpt)

    # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
    attn2_layers = get_attn2_layers(original_ckpt)

    # sd3.5 uses qk norm ("rms_norm")
    has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())

    # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)

    converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
        original_ckpt, num_layers, caption_projection_dim
        original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
    )

    with CTX():
        transformer = SD3Transformer2DModel(
            sample_size=64,
            sample_size=128,
            patch_size=2,
            in_channels=16,
            joint_attention_dim=4096,
            num_layers=num_layers,
            caption_projection_dim=caption_projection_dim,
            num_attention_heads=24,
            pos_embed_max_size=192,
            num_attention_heads=num_layers,
            pos_embed_max_size=pos_embed_max_size,
            qk_norm="rms_norm" if has_qk_norm else None,
            dual_attention_layers=attn2_layers,
        )
    if is_accelerate_available():
        load_model_dict_into_meta(transformer, converted_transformer_state_dict)
@@ -22,7 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX


logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
            processing of `context` conditions.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
        use_dual_attention: bool = False,
    ):
        super().__init__()

        self.use_dual_attention = use_dual_attention
        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        self.norm1 = AdaLayerNormZero(dim)
        if use_dual_attention:
            self.norm1 = SD35AdaLayerNormZeroX(dim)
        else:
            self.norm1 = AdaLayerNormZero(dim)

        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ class JointTransformerBlock(nn.Module):
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
            )

        if hasattr(F, "scaled_dot_product_attention"):
            processor = JointAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
@@ -134,8 +148,25 @@ class JointTransformerBlock(nn.Module):
            context_pre_only=context_pre_only,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=1e-6,
        )

        if use_dual_attention:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                processor=processor,
                qk_norm=qk_norm,
                eps=1e-6,
            )
        else:
            self.attn2 = None

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
@@ -159,7 +190,12 @@ class JointTransformerBlock(nn.Module):
    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                hidden_states, emb=temb
            )
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -177,6 +213,11 @@ class JointTransformerBlock(nn.Module):
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
            hidden_states = hidden_states + attn_output2

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
@@ -193,7 +193,7 @@ class Attention(nn.Module):
            self.norm_q = RMSNorm(dim_head, eps=eps)
            self.norm_k = RMSNorm(dim_head, eps=eps)
        else:
            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")

        if cross_attention_norm is None:
            self.norm_cross = None
@@ -250,6 +250,10 @@ class Attention(nn.Module):
            elif qk_norm == "rms_norm":
                self.norm_added_q = RMSNorm(dim_head, eps=eps)
                self.norm_added_k = RMSNorm(dim_head, eps=eps)
            else:
                raise ValueError(
                    f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
                )
        else:
            self.norm_added_q = None
            self.norm_added_k = None
@@ -1050,61 +1054,72 @@ class JointAttnProcessor2_0:
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]
        batch_size = hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )
        if encoder_hidden_states is not None:
            # Split the attention outputs.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : residual.shape[1]],
                hidden_states[:, residual.shape[1] :],
            )
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states
        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class PAGJointAttnProcessor2_0:
@@ -97,6 +97,40 @@ class FP32LayerNorm(nn.LayerNorm):
        ).to(origin_dtype)


class SD35AdaLayerNormZeroX(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (AdaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_type (`str`, defaults to `"layer_norm"`): The type of normalization layer to use.
        bias (`bool`, defaults to `True`): Whether to use a bias in the linear projection.
    """

    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
            9, dim=1
        )
        norm_hidden_states = self.norm(hidden_states)
        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2


class AdaLayerNormZero(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -13,7 +13,7 @@
# limitations under the License.


from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -69,6 +69,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[
            int, ...
        ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        default_out_channels = in_channels
@@ -97,6 +101,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=True if i in dual_attention_layers else False,
                )
                for i in range(self.config.num_layers)
            ]
@@ -73,6 +73,65 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
            "joint_attention_dim": 32,
            "pooled_projection_dim": 64,
            "out_channels": 4,
            "pos_embed_max_size": 96,
            "dual_attention_layers": (),
            "qk_norm": None,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
    def test_set_attn_processor_for_determinism(self):
        pass


class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = SD3Transformer2DModel
    main_input_name = "hidden_states"

    @property
    def dummy_input(self):
        batch_size = 2
        num_channels = 4
        height = width = embedding_dim = 32
        pooled_embedding_dim = embedding_dim * 2
        sequence_length = 154

        hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "timestep": timestep,
        }

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "sample_size": 32,
            "patch_size": 1,
            "in_channels": 4,
            "num_layers": 2,
            "attention_head_dim": 8,
            "num_attention_heads": 4,
            "caption_projection_dim": 32,
            "joint_attention_dim": 32,
            "pooled_projection_dim": 64,
            "out_channels": 4,
            "pos_embed_max_size": 96,
            "dual_attention_layers": (0,),
            "qk_norm": "rms_norm",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict