diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 5cdc381918..542f38b6ec 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -2137,9 +2137,18 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())

+    variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+        if variant == "chroma" and "distilled_guidance_layer." in k:
+            new_key = k
+            if k.startswith("distilled_guidance_layer.norms"):
+                new_key = k.replace(".scale", ".weight")
+            elif k.startswith("distilled_guidance_layer.layer"):
+                new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
+            converted_state_dict[new_key] = checkpoint.pop(k)

     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
@@ -2153,39 +2162,48 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         new_weight = torch.cat([scale, shift], dim=0)
         return new_weight

-    ## time_text_embed.timestep_embedder <- time_in
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
-        "time_in.in_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
-        "time_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+    if variant == "flux":
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+            "time_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
+            "time_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+            "time_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop(
+            "time_in.out_layer.bias"
+        )

-    ## time_text_embed.text_embedder <- vector_in
-    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
-    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
-    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
-        "vector_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
+            "vector_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+        converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+            "vector_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop(
+            "vector_in.out_layer.bias"
+        )

-    # guidance
-    has_guidance = any("guidance" in k for k in checkpoint)
-    if has_guidance:
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
-            "guidance_in.in_layer.weight"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
-            "guidance_in.in_layer.bias"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
-            "guidance_in.out_layer.weight"
-        )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
-            "guidance_in.out_layer.bias"
-        )
+        # guidance
+        has_guidance = any("guidance" in k for k in checkpoint)
+        if has_guidance:
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+                "guidance_in.in_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+                "guidance_in.in_layer.bias"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+                "guidance_in.out_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+                "guidance_in.out_layer.bias"
+            )

     # context_embedder
     converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
@@ -2199,20 +2217,21 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     for i in range(num_layers):
         block_prefix = f"transformer_blocks.{i}."
         # norms.
-        ## norm1
-        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.bias"
-        )
-        ## norm1_context
-        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.bias"
-        )
+        if variant == "flux":
+            ## norm1
+            converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.bias"
+            )
+            ## norm1_context
+            converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.bias"
+            )
         # Q, K, V
         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
         context_q, context_k, context_v = torch.chunk(
@@ -2285,13 +2304,15 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     # single transformer blocks
     for i in range(num_single_layers):
         block_prefix = f"single_transformer_blocks.{i}."
-        # norm.linear <- single_blocks.0.modulation.lin
-        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.bias"
-        )
+
+        if variant == "flux":
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.bias"
+            )
         # Q, K, V, mlp
         mlp_hidden_dim = int(inner_dim * mlp_ratio)
         split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
@@ -2320,12 +2341,14 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):

     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
-    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
-    )
-    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
-    )
+
+    if variant == "flux":
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+        )
+        converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+        )

     return converted_state_dict

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 29209fc9f1..8aa2ea5841 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1643,17 +1643,10 @@ class CombinedTimestepTextProjChromaEmbeddings(nn.Module):

         self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.embedder = ChromaApproximator(
-            in_dim=factor * 4,
-            out_dim=out_dim,
-            hidden_dim=hidden_dim,
-            n_layers=n_layers,
-        )
-        self.embedding_dim = embedding_dim

         self.register_buffer(
             "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
             persistent=False,
         )

@@ -1661,24 +1654,16 @@ class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
         self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
     ) -> torch.Tensor:
         mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep) + self.time_proj(pooled_projections)
-        if guidance is not None:
-            guidance_proj = self.guidance_proj(guidance)
-        else:
-            guidance_proj = torch.zeros(
-                (self.embedding_dim, self.guidance_proj.num_channels),
-                dtype=timesteps_proj.dtype,
-                device=timesteps_proj.device,
-            )
+        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)

         mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
         timestep_guidance = (
             torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
         )
-        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
-        conditioning = self.embedder(input_vec)
+        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)

-        return conditioning
+        return input_vec


 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 93e80cd4da..f2b71bb688 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -206,7 +206,7 @@ class AdaLayerNormZeroPruned(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         if self.emb is not None:
             emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
-        scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

@@ -267,7 +267,7 @@ class AdaLayerNormZeroSinglePruned(nn.Module):
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
+        shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa

@@ -413,7 +413,7 @@ class AdaLayerNormContinuousPruned(nn.Module):

     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
-        shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
+        shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
         return x

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 21bf42b57d..c542bcaacc 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -37,6 +37,7 @@ from ..embeddings import (
     CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjChromaEmbeddings,
     CombinedTimestepTextProjEmbeddings,
+    ChromaApproximator,
     FluxPosEmbed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
@@ -308,6 +309,7 @@ class FluxTransformer2DModel(
                 embedding_dim=self.inner_dim,
                 n_layers=approximator_layers,
             )
+            self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
         else:
             raise ValueError(INVALID_VARIANT_ERRMSG)

@@ -518,7 +520,8 @@ class FluxTransformer2DModel(
                 else self.time_text_embed(timestep, guidance, pooled_projections)
             )
         else:
-            pooled_temb = self.time_text_embed(timestep, guidance, pooled_projections)
+            input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+            pooled_temb = self.distilled_guidance_layer(input_vec)

         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

@@ -545,10 +548,16 @@ class FluxTransformer2DModel(

         for index_block, block in enumerate(self.transformer_blocks):
             if is_chroma:
-                start_idx1 = 3 * len(self.single_transformer_blocks) + 6 * index_block
-                start_idx2 = start_idx1 + 6 * len(self.transformer_blocks)
+                img_offset = 3 * len(self.single_transformer_blocks)
+                txt_offset = img_offset + 6 * len(self.transformer_blocks)
+                img_modulation = img_offset + 6 * index_block
+                text_modulation = txt_offset + 6 * index_block
                 temb = torch.cat(
-                    (pooled_temb[:, start_idx1 : start_idx1 + 6], pooled_temb[:, start_idx2 : start_idx2 + 6]), dim=1
+                    (
+                        pooled_temb[:, img_modulation : img_modulation + 6],
+                        pooled_temb[:, text_modulation : text_modulation + 6],
+                    ),
+                    dim=1,
                 )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
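
Reviewer note (not part of the patch): the index arithmetic added in transformer_flux.py assumes the pruned Chroma layout in which pooled_temb packs 3 modulation vectors per single block first, then 6 per double block for the image stream, then 6 per double block for the text stream. Below is a minimal sketch of that layout, assuming the usual Flux/Chroma depth of 19 double blocks and 38 single blocks; the block counts and the final-norm remark are assumptions for illustration, not taken from this diff.

    # Sketch only: mirrors the img_offset / txt_offset / *_modulation arithmetic in the hunk above.
    num_double_blocks = 19  # assumption: default Flux/Chroma depth
    num_single_blocks = 38  # assumption: default Flux/Chroma depth

    img_offset = 3 * num_single_blocks                # single-block modulations occupy [0, 114)
    txt_offset = img_offset + 6 * num_double_blocks   # double-block image modulations occupy [114, 228)

    for index_block in range(num_double_blocks):
        img_modulation = img_offset + 6 * index_block    # 6 vectors for the image-stream norm of this block
        text_modulation = txt_offset + 6 * index_block   # 6 vectors for the text-stream norm of this block

    # Text-stream modulations end at 228 + 6 * 19 = 342; with the pruned final norm consuming a
    # further shift/scale pair, the total matches a 344-entry mod_proj buffer (mod_index_length).
    print(txt_offset + 6 * num_double_blocks)  # 342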
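
Also not part of the patch: with the detection added in single_file_utils.py (a checkpoint containing distilled_guidance_layer.in_proj.weight is treated as the "chroma" variant), single-file loading might look like the sketch below; the checkpoint path is a placeholder, not a file referenced by this diff.

    import torch
    from diffusers import FluxTransformer2DModel

    # Placeholder path: any single-file Chroma checkpoint whose state dict carries the
    # distilled_guidance_layer.* keys should take the "chroma" branch of the converter above.
    transformer = FluxTransformer2DModel.from_single_file(
        "path/to/chroma_checkpoint.safetensors", torch_dtype=torch.bfloat16
    )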