mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Get chroma to a functioning state

This commit is contained in:
Ivan DiLernia
2025-06-09 11:48:50 -04:00
parent 373106cedb
commit 104e1636b2
4 changed files with 102 additions and 85 deletions

View File

@@ -2137,9 +2137,18 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
+    variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+        if variant == "chroma" and "distilled_guidance_layer." in k:
+            new_key = k
+            if k.startswith("distilled_guidance_layer.norms"):
+                new_key = k.replace(".scale", ".weight")
+            elif k.startswith("distilled_guidance_layer.layer"):
+                new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
+            converted_state_dict[new_key] = checkpoint.pop(k)
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
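Editor's note: the hunk above detects the Chroma variant by the presence of the distilled guidance layer and renames its weights in place (RMSNorm-style "scale" becomes "weight", "in_layer"/"out_layer" become "linear_1"/"linear_2"). A minimal, self-contained sketch of that renaming; the norms/layers key names here are illustrative, not an exhaustive listing of a real checkpoint:

# Sketch of the chroma key remapping above (illustrative keys only).
example_keys = [
    "distilled_guidance_layer.in_proj.weight",
    "distilled_guidance_layer.norms.0.scale",               # ".scale"  -> ".weight"
    "distilled_guidance_layer.layers.0.in_layer.weight",    # in_layer  -> linear_1
    "distilled_guidance_layer.layers.0.out_layer.weight",   # out_layer -> linear_2
]
for k in example_keys:
    new_key = k
    if k.startswith("distilled_guidance_layer.norms"):
        new_key = k.replace(".scale", ".weight")
    elif k.startswith("distilled_guidance_layer.layer"):
        new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
    print(k, "->", new_key)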
@@ -2153,39 +2162,48 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         new_weight = torch.cat([scale, shift], dim=0)
         return new_weight
-    ## time_text_embed.timestep_embedder <- time_in
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
-        "time_in.in_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
-        "time_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+    if variant == "flux":
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+            "time_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
+            "time_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+            "time_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop(
+            "time_in.out_layer.bias"
+        )
-    ## time_text_embed.text_embedder <- vector_in
-    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
-    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
-    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
-        "vector_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
+            "vector_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+        converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+            "vector_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop(
+            "vector_in.out_layer.bias"
+        )
-    # guidance
-    has_guidance = any("guidance" in k for k in checkpoint)
-    if has_guidance:
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
-            "guidance_in.in_layer.weight"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
-            "guidance_in.in_layer.bias"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
-            "guidance_in.out_layer.weight"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
-            "guidance_in.out_layer.bias"
-        )
+        # guidance
+        has_guidance = any("guidance" in k for k in checkpoint)
+        if has_guidance:
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+                "guidance_in.in_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+                "guidance_in.in_layer.bias"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+                "guidance_in.out_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+                "guidance_in.out_layer.bias"
+            )
     # context_embedder
     converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
@@ -2199,20 +2217,21 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     for i in range(num_layers):
         block_prefix = f"transformer_blocks.{i}."
         # norms.
-        ## norm1
-        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.bias"
-        )
-        ## norm1_context
-        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.bias"
-        )
+        if variant == "flux":
+            ## norm1
+            converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.bias"
+            )
+            ## norm1_context
+            converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.bias"
+            )
         # Q, K, V
         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
         context_q, context_k, context_v = torch.chunk(
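Editor's note: each double block in the original checkpoints stores its image-stream (and, analogously, text-stream) attention projections as a single fused qkv tensor, which the loop above chunks into separate Q/K/V entries. A small sketch of that split, with an assumed inner_dim:

import torch

inner_dim = 3072  # assumed Flux/Chroma hidden size
fused_qkv_weight = torch.randn(3 * inner_dim, inner_dim)  # stand-in for double_blocks.{i}.img_attn.qkv.weight

sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0)
print(sample_q.shape, sample_k.shape, sample_v.shape)  # three (3072, 3072) tensors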
@@ -2285,13 +2304,15 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     # single transformer blocks
     for i in range(num_single_layers):
         block_prefix = f"single_transformer_blocks.{i}."
-        # norm.linear <- single_blocks.0.modulation.lin
-        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.bias"
-        )
+        if variant == "flux":
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.bias"
+            )
         # Q, K, V, mlp
         mlp_hidden_dim = int(inner_dim * mlp_ratio)
         split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
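Editor's note: single blocks fuse Q, K, V and the MLP up-projection into one linear, so the converter splits it with the size tuple above. A toy illustration of that split; the exact source key is not shown in this hunk, so treat the tensor name as an assumption:

import torch

inner_dim, mlp_ratio = 3072, 4.0  # assumed Flux/Chroma values
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

fused_linear1 = torch.randn(sum(split_size), inner_dim)  # stand-in for the fused single-block weight
q, k, v, mlp = torch.split(fused_linear1, split_size, dim=0)
print(q.shape, mlp.shape)  # (3072, 3072) and (12288, 3072)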
@@ -2320,12 +2341,14 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.bias")
)
if variant == "flux":
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.bias")
)
return converted_state_dict
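Editor's note: taken together, the converter now leaves all per-block modulation parameters out of the Chroma state dict (at runtime they come from the distilled guidance layer, see the transformer changes below) and only remaps the approximator keys. A hedged sketch of exercising the converter directly; the import path and checkpoint filename are assumptions, and in practice this conversion normally runs indirectly via from_single_file rather than being called by hand:

from safetensors.torch import load_file

# Assumed import location; the converter may live elsewhere or be private.
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers

checkpoint = load_file("chroma.safetensors")  # hypothetical local file
variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"
converted = convert_flux_transformer_checkpoint_to_diffusers(checkpoint)
print(variant, len(converted), "converted tensors")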

View File

@@ -1643,17 +1643,10 @@ class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
         self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.embedder = ChromaApproximator(
-            in_dim=factor * 4,
-            out_dim=out_dim,
-            hidden_dim=hidden_dim,
-            n_layers=n_layers,
-        )
-        self.embedding_dim = embedding_dim
         self.register_buffer(
             "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
             persistent=False,
         )
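Editor's note: the internal ChromaApproximator is removed from this embedding class (it moves into the transformer, see below), and the precomputed mod_proj buffer now embeds the modulation indices scaled by 1000, putting them on the same footing as the timestep values the sibling Timesteps projections see. A standalone sketch with assumed sizes (factor = 16, 344 modulation indices):

import torch
from diffusers.models.embeddings import get_timestep_embedding

factor, num_mod_indices = 16, 344  # assumed values
mod_proj = get_timestep_embedding(
    torch.arange(num_mod_indices) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0
)
print(mod_proj.shape)  # torch.Size([344, 32])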
@@ -1661,24 +1654,16 @@ class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
         self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
     ) -> torch.Tensor:
         mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep) + self.time_proj(pooled_projections)
-        if guidance is not None:
-            guidance_proj = self.guidance_proj(guidance)
-        else:
-            guidance_proj = torch.zeros(
-                (self.embedding_dim, self.guidance_proj.num_channels),
-                dtype=timesteps_proj.dtype,
-                device=timesteps_proj.device,
-            )
+        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
         mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
         timestep_guidance = (
             torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
         )
-        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
-        conditioning = self.embedder(input_vec)
+        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
-        return conditioning
+        return input_vec
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
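Editor's note: the forward pass now only assembles the approximator input instead of running the removed internal embedder; the projection through ChromaApproximator happens inside FluxTransformer2DModel (last file below). A minimal shape sketch of the assembled input, assuming factor = 16 and 344 modulation indices, so the last dimension is 4 * factor = 64, matching the in_dim=64 used when the approximator is constructed in the transformer:

import torch

factor, mod_index_length = 16, 344  # assumed; mod_index_length == self.mod_proj.shape[0]

timesteps_proj = torch.randn(1, factor)               # self.time_proj(timestep), batch of 1
guidance_proj = torch.randn(1, factor)                # self.guidance_proj(torch.tensor([0]))
mod_proj = torch.randn(mod_index_length, 2 * factor)  # precomputed buffer

timestep_guidance = (
    torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
)                                                     # (1, 344, 2 * factor)
input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
print(input_vec.shape)                                # torch.Size([1, 344, 64])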

View File

@@ -206,7 +206,7 @@ class AdaLayerNormZeroPruned(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         if self.emb is not None:
             emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
-        scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
@@ -267,7 +267,7 @@ class AdaLayerNormZeroSinglePruned(nn.Module):
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
+        shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa
@@ -413,7 +413,7 @@ class AdaLayerNormContinuousPruned(nn.Module):
     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
-        shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
+        shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
         return x
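Editor's note: the "Pruned" AdaLayerNorm variants consume their modulation embedding directly instead of passing a pooled embedding through the usual SiLU + Linear; with this commit each block receives its own pre-computed slice of the approximator output, so the chunking moves to dim=0 after squeezing the leading dimension, and the factor order becomes (shift, scale, gate), presumably matching the layout the distilled guidance layer emits. A tiny sketch of the new chunking, assuming batch size 1 and an inner_dim of 3072:

import torch

dim = 3072                    # assumed inner_dim
emb = torch.randn(1, 6, dim)  # one block's modulation slice (batch size 1 assumed)

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
x = torch.randn(1, 4096, dim)                              # (batch, sequence, dim) hidden states
x_mod = x * (1 + scale_msa[:, None]) + shift_msa[:, None]  # scale_msa[:, None] is (1, 1, dim) and broadcasts
print(x_mod.shape)                                         # torch.Size([1, 4096, 3072])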

View File

@@ -37,6 +37,7 @@ from ..embeddings import (
     CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjChromaEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    ChromaApproximator,
     FluxPosEmbed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
@@ -308,6 +309,7 @@ class FluxTransformer2DModel(
                 embedding_dim=self.inner_dim,
                 n_layers=approximator_layers,
             )
+            self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
         else:
             raise ValueError(INVALID_VARIANT_ERRMSG)
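Editor's note: the approximator is now owned by the transformer and instantiated with hard-coded Chroma dimensions (in_dim=64 matches the 4 * factor input assembled in the embeddings change above; out_dim=3072 matches Flux's inner_dim). For orientation only, here is a parameter-layout sketch that is merely consistent with the key names the converter produces (in_proj, layers.N.linear_1/linear_2, norms.N.weight); the real ChromaApproximator in embeddings.py may be structured differently, and its forward pass is omitted:

import torch.nn as nn

class ApproximatorLayoutSketch(nn.Module):
    """Hypothetical parameter layout only; not the actual ChromaApproximator."""

    def __init__(self, in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim)
        self.layers = nn.ModuleList(
            nn.ModuleDict(
                {"linear_1": nn.Linear(hidden_dim, hidden_dim), "linear_2": nn.Linear(hidden_dim, hidden_dim)}
            )
            for _ in range(n_layers)
        )
        self.norms = nn.ModuleList(nn.RMSNorm(hidden_dim) for _ in range(n_layers))  # needs a recent PyTorch
        self.out_proj = nn.Linear(hidden_dim, out_dim)  # assumed; not referenced by the converter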
@@ -518,7 +520,8 @@ class FluxTransformer2DModel(
                 else self.time_text_embed(timestep, guidance, pooled_projections)
             )
         else:
-            pooled_temb = self.time_text_embed(timestep, guidance, pooled_projections)
+            input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+            pooled_temb = self.distilled_guidance_layer(input_vec)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
@@ -545,10 +548,16 @@ class FluxTransformer2DModel(
         for index_block, block in enumerate(self.transformer_blocks):
             if is_chroma:
-                start_idx1 = 3 * len(self.single_transformer_blocks) + 6 * index_block
-                start_idx2 = start_idx1 + 6 * len(self.transformer_blocks)
+                img_offset = 3 * len(self.single_transformer_blocks)
+                txt_offset = img_offset + 6 * len(self.transformer_blocks)
+                img_modulation = img_offset + 6 * index_block
+                text_modulation = txt_offset + 6 * index_block
                 temb = torch.cat(
-                    (pooled_temb[:, start_idx1 : start_idx1 + 6], pooled_temb[:, start_idx2 : start_idx2 + 6]), dim=1
+                    (
+                        pooled_temb[:, img_modulation : img_modulation + 6],
+                        pooled_temb[:, text_modulation : text_modulation + 6],
+                    ),
+                    dim=1,
                 )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
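Editor's note: the rewritten indexing separates the image-stream and text-stream modulation vectors. The first 3 * num_single_blocks rows of pooled_temb feed the single blocks, followed by 6 rows per double block for the image stream and then 6 per double block for the text stream. A quick worked example, assuming the Flux-sized configuration Chroma uses (19 double blocks, 38 single blocks):

num_double_blocks, num_single_blocks = 19, 38    # assumed Flux/Chroma block counts

img_offset = 3 * num_single_blocks               # 114: rows 0..113 go to the single blocks
txt_offset = img_offset + 6 * num_double_blocks  # 228

for index_block in (0, 1, 18):
    img_modulation = img_offset + 6 * index_block
    text_modulation = txt_offset + 6 * index_block
    print(index_block, (img_modulation, img_modulation + 6), (text_modulation, text_modulation + 6))
# block 0  -> image rows [114, 120), text rows [228, 234)
# block 18 -> image rows [222, 228), text rows [336, 342)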