mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

take out variant stuff

Edna
2025-06-09 21:32:52 -06:00
committed by GitHub
parent 33ea0b65a4
commit 22ecd19f91


@@ -43,40 +43,27 @@ from ..embeddings import (
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import (
-    AdaLayerNormContinuous,
     AdaLayerNormContinuousPruned,
-    AdaLayerNormZero,
     AdaLayerNormZeroPruned,
-    AdaLayerNormZeroSingle,
     AdaLayerNormZeroSinglePruned,
 )

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

-INVALID_VARIANT_ERRMSG = "`variant` must be `'flux' or `'chroma'`."
-

 @maybe_allow_in_graph
-class FluxSingleTransformerBlock(nn.Module):
+class ChromaSingleTransformerBlock(nn.Module):
     def __init__(
         self,
         dim: int,
         num_attention_heads: int,
         attention_head_dim: int,
         mlp_ratio: float = 4.0,
-        variant: str = "flux",
     ):
         super().__init__()
         self.mlp_hidden_dim = int(dim * mlp_ratio)

-        if variant == "flux":
-            self.norm = AdaLayerNormZeroSingle(dim)
-        elif variant == "chroma":
-            self.norm = AdaLayerNormZeroSinglePruned(dim)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
+        self.norm = AdaLayerNormZeroSinglePruned(dim)

         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
@@ -132,7 +119,7 @@ class FluxSingleTransformerBlock(nn.Module):


 @maybe_allow_in_graph
-class FluxTransformerBlock(nn.Module):
+class ChromaTransformerBlock(nn.Module):
     def __init__(
         self,
         dim: int,
@@ -140,18 +127,10 @@ class FluxTransformerBlock(nn.Module):
         attention_head_dim: int,
         qk_norm: str = "rms_norm",
         eps: float = 1e-6,
-        variant: str = "flux",
     ):
         super().__init__()

-        if variant == "flux":
-            self.norm1 = AdaLayerNormZero(dim)
-            self.norm1_context = AdaLayerNormZero(dim)
-        elif variant == "chroma":
-            self.norm1 = AdaLayerNormZeroPruned(dim)
-            self.norm1_context = AdaLayerNormZeroPruned(dim)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
+        self.norm1 = AdaLayerNormZeroPruned(dim)
+        self.norm1_context = AdaLayerNormZeroPruned(dim)

         self.attn = Attention(
             query_dim=dim,
@@ -231,13 +210,13 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states


-class FluxTransformer2DModel(
+class ChromaTransformer2DModel(
     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
 ):
     """
-    The Transformer model introduced in Flux.
+    The Transformer model introduced in Flux, modified for Chroma.

-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+    Reference: https://huggingface.co/lodestones/Chroma

     Args:
         patch_size (`int`, defaults to `1`):
@@ -266,7 +245,7 @@ class FluxTransformer2DModel(
     """

     _supports_gradient_checkpointing = True
-    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+    _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

     @register_to_config
@@ -283,7 +262,6 @@ class FluxTransformer2DModel(
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
-        variant: str = "flux",
         approximator_in_factor: int = 16,
         approximator_hidden_dim: int = 5120,
         approximator_layers: int = 5,
@@ -294,31 +272,21 @@ class FluxTransformer2DModel(

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

-        if variant == "flux":
-            text_time_guidance_cls = (
-                CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
-            )
-            self.time_text_embed = text_time_guidance_cls(
-                embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-            )
-        elif variant == "chroma":
-            self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
-                factor=approximator_in_factor,
-                hidden_dim=approximator_hidden_dim,
-                out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
-                embedding_dim=self.inner_dim,
-                n_layers=approximator_layers,
-            )
-            self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
+        self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
+            factor=approximator_in_factor,
+            hidden_dim=approximator_hidden_dim,
+            out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
+            embedding_dim=self.inner_dim,
+            n_layers=approximator_layers,
+        )
+        self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)

         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
         self.x_embedder = nn.Linear(in_channels, self.inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [
-                FluxTransformerBlock(
+                ChromaTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
@@ -330,7 +298,7 @@ class FluxTransformer2DModel(

         self.single_transformer_blocks = nn.ModuleList(
             [
-                FluxSingleTransformerBlock(
+                ChromaSingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
@@ -340,16 +308,12 @@ class FluxTransformer2DModel(
             ]
         )

-        norm_out_cls = AdaLayerNormContinuous if variant != "chroma" else AdaLayerNormContinuousPruned
+        norm_out_cls = AdaLayerNormContinuousPruned
         self.norm_out = norm_out_cls(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

         self.gradient_checkpointing = False

-    @property
-    def is_chroma(self) -> bool:
-        return isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -506,22 +470,14 @@ class FluxTransformer2DModel(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        is_chroma = self.is_chroma
         hidden_states = self.x_embedder(hidden_states)

         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000

-        if not is_chroma:
-            temb = (
-                self.time_text_embed(timestep, pooled_projections)
-                if guidance is None
-                else self.time_text_embed(timestep, guidance, pooled_projections)
-            )
-        else:
-            input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
-            pooled_temb = self.distilled_guidance_layer(input_vec)
+        input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+        pooled_temb = self.distilled_guidance_layer(input_vec)

         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
@@ -547,18 +503,17 @@ class FluxTransformer2DModel(
             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

         for index_block, block in enumerate(self.transformer_blocks):
-            if is_chroma:
-                img_offset = 3 * len(self.single_transformer_blocks)
-                txt_offset = img_offset + 6 * len(self.transformer_blocks)
-                img_modulation = img_offset + 6 * index_block
-                text_modulation = txt_offset + 6 * index_block
-                temb = torch.cat(
-                    (
-                        pooled_temb[:, img_modulation : img_modulation + 6],
-                        pooled_temb[:, text_modulation : text_modulation + 6],
-                    ),
-                    dim=1,
-                )
+            img_offset = 3 * len(self.single_transformer_blocks)
+            txt_offset = img_offset + 6 * len(self.transformer_blocks)
+            img_modulation = img_offset + 6 * index_block
+            text_modulation = txt_offset + 6 * index_block
+            temb = torch.cat(
+                (
+                    pooled_temb[:, img_modulation : img_modulation + 6],
+                    pooled_temb[:, text_modulation : text_modulation + 6],
+                ),
+                dim=1,
+            )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -591,9 +546,8 @@ class FluxTransformer2DModel(
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
-            if is_chroma:
-                start_idx = 3 * index_block
-                temb = pooled_temb[:, start_idx : start_idx + 3]
+            start_idx = 3 * index_block
+            temb = pooled_temb[:, start_idx : start_idx + 3]
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -621,8 +575,7 @@ class FluxTransformer2DModel(
         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

-        if is_chroma:
-            temb = pooled_temb[:, -2:]
+        temb = pooled_temb[:, -2:]

         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
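
For reference, a small standalone sketch (not part of the commit) of the modulation-vector indexing that the forward pass above relies on: the single-stream blocks consume the first 3 * num_single_layers entries of pooled_temb, the double-stream blocks the next 2 * 6 * num_layers (image then text modulations), and the final AdaLayerNormContinuousPruned the last 2. The layer counts used here (19 double blocks, 38 single blocks) are assumptions for illustration only, not values taken from this commit.

# Sketch only: reproduces the slice arithmetic from the forward pass above.
num_layers = 19          # double-stream (ChromaTransformerBlock) count, assumed for illustration
num_single_layers = 38   # single-stream (ChromaSingleTransformerBlock) count, assumed for illustration

# Approximator output width, matching the constructor's out_dim formula.
out_dim = 3 * num_single_layers + 2 * 6 * num_layers + 2  # 114 + 228 + 2 = 344

# Single-stream blocks come first: 3 modulation entries per block.
single_slices = [(3 * i, 3 * i + 3) for i in range(num_single_layers)]

# Double-stream blocks follow: 6 image entries and 6 text entries per block.
img_offset = 3 * num_single_layers
txt_offset = img_offset + 6 * num_layers
double_slices = [
    ((img_offset + 6 * i, img_offset + 6 * i + 6), (txt_offset + 6 * i, txt_offset + 6 * i + 6))
    for i in range(num_layers)
]

# The last 2 entries feed the final norm_out (pooled_temb[:, -2:]).
assert double_slices[-1][1][1] + 2 == out_dim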