Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Get chroma to a functioning state
@@ -2137,9 +2137,18 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())

+    variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+        if variant == "chroma" and "distilled_guidance_layer." in k:
+            new_key = k
+            if k.startswith("distilled_guidance_layer.norms"):
+                new_key = k.replace(".scale", ".weight")
+            elif k.startswith("distilled_guidance_layer.layer"):
+                new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
+            converted_state_dict[new_key] = checkpoint.pop(k)

     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
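The chroma branch added above only renames keys under distilled_guidance_layer.; a minimal standalone sketch of that renaming rule follows (the two sample keys are illustrative instances of the matched patterns, not taken from a specific checkpoint):

def rename_chroma_key(k: str) -> str:
    # Mirrors the renaming branch in the loop above.
    new_key = k
    if k.startswith("distilled_guidance_layer.norms"):
        new_key = k.replace(".scale", ".weight")
    elif k.startswith("distilled_guidance_layer.layer"):
        new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
    return new_key

print(rename_chroma_key("distilled_guidance_layer.norms.0.scale"))
# -> distilled_guidance_layer.norms.0.weight
print(rename_chroma_key("distilled_guidance_layer.layers.0.in_layer.weight"))
# -> distilled_guidance_layer.layers.0.linear_1.weight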
@@ -2153,39 +2162,48 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         new_weight = torch.cat([scale, shift], dim=0)
         return new_weight

-    ## time_text_embed.timestep_embedder <- time_in
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
-        "time_in.in_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
-        "time_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+    if variant == "flux":
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+            "time_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
+            "time_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+            "time_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop(
+            "time_in.out_layer.bias"
+        )

-    ## time_text_embed.text_embedder <- vector_in
-    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
-    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
-    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
-        "vector_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
+            "vector_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+        converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+            "vector_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop(
+            "vector_in.out_layer.bias"
+        )

-    # guidance
-    has_guidance = any("guidance" in k for k in checkpoint)
-    if has_guidance:
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
-            "guidance_in.in_layer.weight"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
-            "guidance_in.in_layer.bias"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
-            "guidance_in.out_layer.weight"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
-            "guidance_in.out_layer.bias"
-        )
+        # guidance
+        has_guidance = any("guidance" in k for k in checkpoint)
+        if has_guidance:
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+                "guidance_in.in_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+                "guidance_in.in_layer.bias"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+                "guidance_in.out_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+                "guidance_in.out_layer.bias"
+            )

     # context_embedder
     converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
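The two context lines at the top of this hunk are the tail of the swap_scale_shift helper defined earlier in the function. For reference, a minimal sketch of the whole helper, reconstructed under the assumption that its first line is the usual chunk of the weight into (shift, scale) halves:

import torch

def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor:
    # FLUX checkpoints store adaLN parameters as (shift, scale); diffusers expects
    # (scale, shift), so the two halves are swapped along dim 0.
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight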
@@ -2199,20 +2217,21 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     for i in range(num_layers):
         block_prefix = f"transformer_blocks.{i}."
         # norms.
-        ## norm1
-        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.bias"
-        )
-        ## norm1_context
-        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.bias"
-        )
+        if variant == "flux":
+            ## norm1
+            converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.bias"
+            )
+            ## norm1_context
+            converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.bias"
+            )
         # Q, K, V
         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
         context_q, context_k, context_v = torch.chunk(

@@ -2285,13 +2304,15 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     # single transformer blocks
     for i in range(num_single_layers):
         block_prefix = f"single_transformer_blocks.{i}."
-        # norm.linear <- single_blocks.0.modulation.lin
-        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.bias"
-        )
+
+        if variant == "flux":
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.bias"
+            )
         # Q, K, V, mlp
         mlp_hidden_dim = int(inner_dim * mlp_ratio)
         split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

@@ -2320,12 +2341,14 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):

     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
-    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
-    )
-    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
-    )
+
+    if variant == "flux":
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+        )
+        converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+        )

     return converted_state_dict

@@ -1643,17 +1643,10 @@ class CombinedTimestepTextProjChromaEmbeddings(nn.Module):

         self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.embedder = ChromaApproximator(
-            in_dim=factor * 4,
-            out_dim=out_dim,
-            hidden_dim=hidden_dim,
-            n_layers=n_layers,
-        )
-        self.embedding_dim = embedding_dim

         self.register_buffer(
             "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            get_timestep_embedding(torch.arange(out_dim)*1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0, ),
             persistent=False,
         )

@@ -1661,24 +1654,16 @@ class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
         self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
     ) -> torch.Tensor:
         mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep) + self.time_proj(pooled_projections)
-        if guidance is not None:
-            guidance_proj = self.guidance_proj(guidance)
-        else:
-            guidance_proj = torch.zeros(
-                (self.embedding_dim, self.guidance_proj.num_channels),
-                dtype=timesteps_proj.dtype,
-                device=timesteps_proj.device,
-            )
+        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)

         mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
         timestep_guidance = (
             torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
         )
-        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
-        conditioning = self.embedder(input_vec)
+        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)

-        return conditioning
+        return input_vec


 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):

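For orientation, a standalone shape sketch of the input_vec the rewritten forward now returns; factor and out_dim below are illustrative assumptions, chosen so that 4 * factor matches the in_dim=64 of the ChromaApproximator instantiated in the transformer further down:

import torch

batch, factor, out_dim = 1, 16, 344   # illustrative sizes; 4 * factor == 64 == approximator in_dim

timesteps_proj = torch.randn(batch, factor)   # stands in for Timesteps(num_channels=factor)(timestep)
guidance_proj = torch.randn(1, factor)        # stands in for Timesteps(num_channels=factor)(torch.tensor([0]))
mod_proj = torch.randn(out_dim, 2 * factor)   # stands in for the registered "mod_proj" buffer

timestep_guidance = torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, out_dim, 1)
input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
print(input_vec.shape)  # torch.Size([1, 344, 64])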
@@ -206,7 +206,7 @@ class AdaLayerNormZeroPruned(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         if self.emb is not None:
             emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
-        scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

@@ -267,7 +267,7 @@ class AdaLayerNormZeroSinglePruned(nn.Module):
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
+        shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa

@@ -413,7 +413,7 @@ class AdaLayerNormContinuousPruned(nn.Module):

     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
-        shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
+        shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
         return x

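The pruned AdaLayerNorm classes now receive a pre-computed per-block modulation slice rather than projecting an embedding themselves, which is why the chunking moves from dim=1 to dim=0 after a squeeze(0). A minimal broadcasting sketch, assuming a batch of one and an affine-free LayerNorm standing in for self.norm:

import torch
import torch.nn.functional as F

dim, seq = 3072, 128
x = torch.randn(1, seq, dim)    # hidden states; the squeeze(0) pattern assumes batch size 1
emb = torch.randn(1, 6, dim)    # per-block slice taken from the pre-computed modulation table

# Each chunk comes out with shape (1, dim); [:, None] then broadcasts over the sequence dim.
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
out = F.layer_norm(x, (dim,)) * (1 + scale_msa[:, None]) + shift_msa[:, None]
print(out.shape)  # torch.Size([1, 128, 3072])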
@@ -37,6 +37,7 @@ from ..embeddings import (
     CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjChromaEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    ChromaApproximator,
     FluxPosEmbed,
 )
 from ..modeling_outputs import Transformer2DModelOutput

@@ -308,6 +309,7 @@ class FluxTransformer2DModel(
                 embedding_dim=self.inner_dim,
                 n_layers=approximator_layers,
             )
+            self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
         else:
             raise ValueError(INVALID_VARIANT_ERRMSG)

@@ -518,7 +520,8 @@ class FluxTransformer2DModel(
                 else self.time_text_embed(timestep, guidance, pooled_projections)
             )
         else:
-            pooled_temb = self.time_text_embed(timestep, guidance, pooled_projections)
+            input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+            pooled_temb = self.distilled_guidance_layer(input_vec)

         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

@@ -545,10 +548,16 @@ class FluxTransformer2DModel(

         for index_block, block in enumerate(self.transformer_blocks):
             if is_chroma:
-                start_idx1 = 3 * len(self.single_transformer_blocks) + 6 * index_block
-                start_idx2 = start_idx1 + 6 * len(self.transformer_blocks)
+                img_offset = 3 * len(self.single_transformer_blocks)
+                txt_offset = img_offset + 6 * len(self.transformer_blocks)
+                img_modulation = img_offset + 6 * index_block
+                text_modulation = txt_offset + 6 * index_block
                 temb = torch.cat(
-                    (pooled_temb[:, start_idx1 : start_idx1 + 6], pooled_temb[:, start_idx2 : start_idx2 + 6]), dim=1
+                    (
+                        pooled_temb[:, img_modulation : img_modulation + 6],
+                        pooled_temb[:, text_modulation : text_modulation + 6],
+                    ),
+                    dim=1,
                 )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
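A small sketch of the modulation-table layout this indexing assumes: three entries per single block first, then six image entries per double block, then six text entries per double block. The block counts below are an assumption matching the Flux.1-style architecture Chroma is derived from:

num_single, num_double = 38, 19   # assumed block counts, used only for illustration

def chroma_mod_slices(index_block: int):
    # Same arithmetic as the is_chroma branch above, returning the two slices
    # taken from pooled_temb for one double block.
    img_offset = 3 * num_single
    txt_offset = img_offset + 6 * num_double
    img_modulation = img_offset + 6 * index_block
    text_modulation = txt_offset + 6 * index_block
    return (slice(img_modulation, img_modulation + 6),
            slice(text_modulation, text_modulation + 6))

print(chroma_mod_slices(0))  # (slice(114, 120, None), slice(228, 234, None))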