mirror of https://github.com/huggingface/diffusers.git

take out variant stuff
@@ -43,40 +43,27 @@ from ..embeddings import (
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import (
-    AdaLayerNormContinuous,
     AdaLayerNormContinuousPruned,
-    AdaLayerNormZero,
     AdaLayerNormZeroPruned,
-    AdaLayerNormZeroSingle,
     AdaLayerNormZeroSinglePruned,
 )


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-INVALID_VARIANT_ERRMSG = "`variant` must be `'flux' or `'chroma'`."
-
-
 @maybe_allow_in_graph
-class FluxSingleTransformerBlock(nn.Module):
+class ChromaSingleTransformerBlock(nn.Module):
     def __init__(
         self,
         dim: int,
         num_attention_heads: int,
         attention_head_dim: int,
         mlp_ratio: float = 4.0,
-        variant: str = "flux",
     ):
         super().__init__()
         self.mlp_hidden_dim = int(dim * mlp_ratio)

-        if variant == "flux":
-            self.norm = AdaLayerNormZeroSingle(dim)
-        elif variant == "chroma":
-            self.norm = AdaLayerNormZeroSinglePruned(dim)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
-
+        self.norm = AdaLayerNormZeroSinglePruned(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
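In the hunk above only the "Pruned" normalization classes remain in use, so the non-pruned imports and the variant check can go. As a rough conceptual sketch of what the pruned naming refers to (an illustration under stated assumptions, not the actual diffusers AdaLayerNorm* classes): the standard adaptive norm derives shift/scale from the conditioning embedding with its own projection layer, while the pruned variant drops that projection and expects the modulation values to be handed in directly, which is what the Chroma approximator output provides.

import torch
import torch.nn as nn

# Conceptual sketch, not the diffusers classes: a "pruned" adaptive norm drops
# the internal modulation projection and takes shift/scale as precomputed input.
class AdaNormSketch(nn.Module):
    def __init__(self, dim: int, pruned: bool):
        super().__init__()
        self.pruned = pruned
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        if not pruned:
            # non-pruned: learn to map the conditioning embedding to shift/scale
            self.linear = nn.Linear(dim, 2 * dim)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        if self.pruned:
            shift, scale = emb.chunk(2, dim=-1)  # modulation supplied externally
        else:
            shift, scale = self.linear(nn.functional.silu(emb)).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift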
@@ -132,7 +119,7 @@ class FluxSingleTransformerBlock(nn.Module):


 @maybe_allow_in_graph
-class FluxTransformerBlock(nn.Module):
+class ChromaTransformerBlock(nn.Module):
     def __init__(
         self,
         dim: int,
@@ -140,18 +127,10 @@ class FluxTransformerBlock(nn.Module):
         attention_head_dim: int,
         qk_norm: str = "rms_norm",
         eps: float = 1e-6,
-        variant: str = "flux",
     ):
         super().__init__()
-
-        if variant == "flux":
-            self.norm1 = AdaLayerNormZero(dim)
-            self.norm1_context = AdaLayerNormZero(dim)
-        elif variant == "chroma":
-            self.norm1 = AdaLayerNormZeroPruned(dim)
-            self.norm1_context = AdaLayerNormZeroPruned(dim)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
+        self.norm1 = AdaLayerNormZeroPruned(dim)
+        self.norm1_context = AdaLayerNormZeroPruned(dim)

         self.attn = Attention(
             query_dim=dim,
@@ -231,13 +210,13 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states


-class FluxTransformer2DModel(
+class ChromaTransformer2DModel(
     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
 ):
     """
-    The Transformer model introduced in Flux.
+    The Transformer model introduced in Flux, modified for Chroma.

-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+    Reference: https://huggingface.co/lodestones/Chroma

     Args:
         patch_size (`int`, defaults to `1`):
@@ -266,7 +245,7 @@ class FluxTransformer2DModel(
     """

     _supports_gradient_checkpointing = True
-    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+    _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

     @register_to_config
@@ -283,7 +262,6 @@ class FluxTransformer2DModel(
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
-        variant: str = "flux",
         approximator_in_factor: int = 16,
         approximator_hidden_dim: int = 5120,
         approximator_layers: int = 5,
@@ -294,31 +272,21 @@ class FluxTransformer2DModel(

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

-        if variant == "flux":
-            text_time_guidance_cls = (
-                CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
-            )
-            self.time_text_embed = text_time_guidance_cls(
-                embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-            )
-        elif variant == "chroma":
-            self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
-                factor=approximator_in_factor,
-                hidden_dim=approximator_hidden_dim,
-                out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
-                embedding_dim=self.inner_dim,
-                n_layers=approximator_layers,
-            )
-            self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
+        self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
+            factor=approximator_in_factor,
+            hidden_dim=approximator_hidden_dim,
+            out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
+            embedding_dim=self.inner_dim,
+            n_layers=approximator_layers,
+        )
+        self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)

         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
         self.x_embedder = nn.Linear(in_channels, self.inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [
-                FluxTransformerBlock(
+                ChromaTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
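For reference, the out_dim passed to CombinedTimestepTextProjChromaEmbeddings above budgets one modulation vector per norm site: three per single-stream block, six each for the image and text streams of every double-stream block, and two for the final output norm. A quick sketch of that arithmetic follows; num_layers=19 and num_single_layers=38 are assumed Flux-style defaults, not values fixed by this diff.

# Assumed block counts (Flux-style defaults, not pinned down in this diff).
num_layers = 19          # double-stream ChromaTransformerBlock count
num_single_layers = 38   # single-stream ChromaSingleTransformerBlock count

single_block_vectors = 3 * num_single_layers  # shift, scale, gate per single block
double_block_vectors = 2 * 6 * num_layers     # 6 image-stream + 6 text-stream per double block
final_norm_vectors = 2                        # shift, scale for norm_out

out_dim = 3 * num_single_layers + 2 * 6 * num_layers + 2  # formula from the constructor above
assert out_dim == single_block_vectors + double_block_vectors + final_norm_vectors == 344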
@@ -330,7 +298,7 @@ class FluxTransformer2DModel(

         self.single_transformer_blocks = nn.ModuleList(
             [
-                FluxSingleTransformerBlock(
+                ChromaSingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
@@ -340,16 +308,12 @@ class FluxTransformer2DModel(
             ]
         )

-        norm_out_cls = AdaLayerNormContinuous if variant != "chroma" else AdaLayerNormContinuousPruned
+        norm_out_cls = AdaLayerNormContinuousPruned
         self.norm_out = norm_out_cls(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

         self.gradient_checkpointing = False

-    @property
-    def is_chroma(self) -> bool:
-        return isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -506,22 +470,14 @@ class FluxTransformer2DModel(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        is_chroma = self.is_chroma
         hidden_states = self.x_embedder(hidden_states)

         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000

-        if not is_chroma:
-            temb = (
-                self.time_text_embed(timestep, pooled_projections)
-                if guidance is None
-                else self.time_text_embed(timestep, guidance, pooled_projections)
-            )
-        else:
-            input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
-            pooled_temb = self.distilled_guidance_layer(input_vec)
+        input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+        pooled_temb = self.distilled_guidance_layer(input_vec)

         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

@@ -547,18 +503,17 @@ class FluxTransformer2DModel(
             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

         for index_block, block in enumerate(self.transformer_blocks):
-            if is_chroma:
-                img_offset = 3 * len(self.single_transformer_blocks)
-                txt_offset = img_offset + 6 * len(self.transformer_blocks)
-                img_modulation = img_offset + 6 * index_block
-                text_modulation = txt_offset + 6 * index_block
-                temb = torch.cat(
-                    (
-                        pooled_temb[:, img_modulation : img_modulation + 6],
-                        pooled_temb[:, text_modulation : text_modulation + 6],
-                    ),
-                    dim=1,
-                )
+            img_offset = 3 * len(self.single_transformer_blocks)
+            txt_offset = img_offset + 6 * len(self.transformer_blocks)
+            img_modulation = img_offset + 6 * index_block
+            text_modulation = txt_offset + 6 * index_block
+            temb = torch.cat(
+                (
+                    pooled_temb[:, img_modulation : img_modulation + 6],
+                    pooled_temb[:, text_modulation : text_modulation + 6],
+                ),
+                dim=1,
+            )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
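With the variant branch removed, the offset arithmetic above runs for every double-stream block. It reads the flattened pooled_temb as: all single-block vectors first, then the image-stream vectors for each double block, then the text-stream vectors. A small sketch of the resulting indices, under the same assumed block counts as the earlier sketch:

# Index positions along dim 1 of pooled_temb, assuming 19 double and 38 single blocks.
num_layers, num_single_layers = 19, 38

img_offset = 3 * num_single_layers        # single-block vectors occupy indices 0..113
txt_offset = img_offset + 6 * num_layers  # image-stream vectors occupy indices 114..227

for index_block in range(num_layers):
    img_modulation = img_offset + 6 * index_block   # block 0 reads 114..119
    text_modulation = txt_offset + 6 * index_block  # block 0 reads 228..233
    # The block's temb is pooled_temb[:, img_modulation : img_modulation + 6]
    # concatenated with pooled_temb[:, text_modulation : text_modulation + 6] along dim=1.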
@@ -591,9 +546,8 @@ class FluxTransformer2DModel(
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
-            if is_chroma:
-                start_idx = 3 * index_block
-                temb = pooled_temb[:, start_idx : start_idx + 3]
+            start_idx = 3 * index_block
+            temb = pooled_temb[:, start_idx : start_idx + 3]
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -621,8 +575,7 @@ class FluxTransformer2DModel(

         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

-        if is_chroma:
-            temb = pooled_temb[:, -2:]
+        temb = pooled_temb[:, -2:]
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
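Taken together, the three slicing sites (the single-block slices, the double-block slices, and the final pooled_temb[:, -2:] used by norm_out) consume the approximator output exactly once. A short sanity check under the same assumed block counts:

# Check that the per-block slices tile the approximator output exactly,
# using the assumed 19/38 block counts from the earlier sketches.
num_layers, num_single_layers = 19, 38
out_dim = 3 * num_single_layers + 2 * 6 * num_layers + 2  # constructor formula

used = (
    3 * num_single_layers  # start_idx = 3 * index_block slices for single blocks
    + 6 * num_layers       # image-stream slices for double blocks
    + 6 * num_layers       # text-stream slices for double blocks
    + 2                    # pooled_temb[:, -2:] for norm_out
)
assert used == out_dim == 344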