From c366b5a817c0671ef9967a58cf6fcec10952f3ae Mon Sep 17 00:00:00 2001
From: DN6
Date: Thu, 11 Dec 2025 13:37:06 +0530
Subject: [PATCH] update

---
 .../test_models_transformer_flux.py          | 98 ++++++++++---------
 1 file changed, 50 insertions(+), 48 deletions(-)

diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 43e02db448..3019308831 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -45,6 +45,55 @@ from ..testing_utils import (
 enable_full_determinism()
 
 
+# TODO: This standalone function maintains backward compatibility with pipeline tests
+# (tests/pipelines/test_pipelines_common.py) and will be refactored.
+def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
+    """Create a dummy IP Adapter state dict for Flux transformer testing."""
+    ip_cross_attn_state_dict = {}
+    key_id = 0
+
+    for name in model.attn_processors.keys():
+        if name.startswith("single_transformer_blocks"):
+            continue
+
+        joint_attention_dim = model.config["joint_attention_dim"]
+        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+        sd = FluxIPAdapterAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
+        ).state_dict()
+        ip_cross_attn_state_dict.update(
+            {
+                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+                f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
+                f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
+            }
+        )
+        key_id += 1
+
+    image_projection = ImageProjection(
+        cross_attention_dim=model.config["joint_attention_dim"],
+        image_embed_dim=(
+            model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
+        ),
+        num_image_text_embeds=4,
+    )
+
+    ip_image_projection_state_dict = {}
+    sd = image_projection.state_dict()
+    ip_image_projection_state_dict.update(
+        {
+            "proj.weight": sd["image_embeds.weight"],
+            "proj.bias": sd["image_embeds.bias"],
+            "norm.weight": sd["norm.weight"],
+            "norm.bias": sd["norm.bias"],
+        }
+    )
+
+    del sd
+    return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
+
+
 class FluxTransformerTesterConfig:
     model_class = FluxTransformer2DModel
     pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
@@ -169,54 +218,7 @@ class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterM
         return inputs_dict
 
     def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
-        from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
-
-        ip_cross_attn_state_dict = {}
-        key_id = 0
-
-        for name in model.attn_processors.keys():
-            if name.startswith("single_transformer_blocks"):
-                continue
-
-            joint_attention_dim = model.config["joint_attention_dim"]
-            hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
-            sd = FluxIPAdapterAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
-            ).state_dict()
-            ip_cross_attn_state_dict.update(
-                {
-                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
-                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
-                    f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
-                    f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
-                }
-            )
-
-            key_id += 1
-
-        image_projection = ImageProjection(
-            cross_attention_dim=model.config["joint_attention_dim"],
-            image_embed_dim=(
-                model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
-            ),
-            num_image_text_embeds=4,
-        )
-
-        ip_image_projection_state_dict = {}
-        sd = image_projection.state_dict()
-        ip_image_projection_state_dict.update(
-            {
-                "proj.weight": sd["image_embeds.weight"],
-                "proj.bias": sd["image_embeds.bias"],
-                "norm.weight": sd["norm.weight"],
-                "norm.bias": sd["norm.bias"],
-            }
-        )
-
-        del sd
-        ip_state_dict = {}
-        ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
-        return ip_state_dict
+        return create_flux_ip_adapter_state_dict(model)
 
 
 class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):