# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import ChromaTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


def create_chroma_ip_adapter_state_dict(model):
    # "ip_adapter" (cross-attention weights)
    ip_cross_attn_state_dict = {}
    key_id = 0

    for name in model.attn_processors.keys():
        if name.startswith("single_transformer_blocks"):
            continue

        joint_attention_dim = model.config["joint_attention_dim"]
        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
        sd = FluxIPAdapterJointAttnProcessor2_0(
            hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
        ).state_dict()
        ip_cross_attn_state_dict.update(
            {
                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
                f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
                f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
            }
        )
        key_id += 1

    # "image_proj" (ImageProjection layer weights)
    image_projection = ImageProjection(
        cross_attention_dim=model.config["joint_attention_dim"],
        image_embed_dim=model.config["pooled_projection_dim"],
        num_image_text_embeds=4,
    )

    ip_image_projection_state_dict = {}
    sd = image_projection.state_dict()
    ip_image_projection_state_dict.update(
        {
            "proj.weight": sd["image_embeds.weight"],
            "proj.bias": sd["image_embeds.bias"],
            "norm.weight": sd["norm.weight"],
            "norm.bias": sd["norm.bias"],
        }
    )

    del sd
    ip_state_dict = {}
    ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
    return ip_state_dict
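
# A minimal sketch of how the fake state dict above might be consumed, assuming
# `ChromaTransformer2DModel` exposes the same `_load_ip_adapter_weights` hook as
# the Flux transformer (an assumption; Chroma's IP-Adapter support may differ):
#
#     model = ChromaTransformer2DModel(**init_dict)
#     ip_state_dict = create_chroma_ip_adapter_state_dict(model)
#     model._load_ip_adapter_weights(ip_state_dict)
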
class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = ChromaTransformer2DModel
    main_input_name = "hidden_states"

    # We override the items here because the transformer under consideration is small.
    model_split_percents = [0.8, 0.7, 0.7]

    # Skip testing with the default AttnProcessor; this model uses a custom attention processor.
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "img_ids": image_ids,
            "txt_ids": text_ids,
            "timestep": timestep,
        }

    @property
    def input_shape(self):
        return (16, 4)

    @property
    def output_shape(self):
        return (16, 4)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "patch_size": 1,
            "in_channels": 4,
            "num_layers": 1,
            "num_single_layers": 1,
            "attention_head_dim": 16,
            "num_attention_heads": 2,
            "joint_attention_dim": 32,
            "axes_dims_rope": [4, 4, 8],
            "approximator_num_channels": 8,
            "approximator_hidden_dim": 16,
            "approximator_layers": 1,
        }

        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_deprecated_inputs_img_txt_ids_3d(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output_1 = model(**inputs_dict).to_tuple()[0]

        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"

        inputs_dict["txt_ids"] = text_ids_3d
        inputs_dict["img_ids"] = image_ids_3d

        with torch.no_grad():
            output_2 = model(**inputs_dict).to_tuple()[0]

        self.assertEqual(output_1.shape, output_2.shape)
        self.assertTrue(
            torch.allclose(output_1, output_2, atol=1e-5),
            msg="outputs with deprecated inputs (img_ids and txt_ids as 3d torch tensors) do not match outputs with the equivalent 2d inputs",
        )

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"ChromaTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class ChromaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = ChromaTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()


class ChromaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = ChromaTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
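

# For reference, the two wrapper classes above reuse the tiny config through a
# throwaway ChromaTransformerTests instance. A standalone forward pass with the
# same fixtures would look roughly like this (sketch only; the mixins already
# exercise this path):
#
#     init_dict, inputs_dict = ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
#     model = ChromaTransformer2DModel(**init_dict).to(torch_device)
#     with torch.no_grad():
#         sample = model(**inputs_dict).to_tuple()[0]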