mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* update
* update
* update
* update
* update
* merge main
* Revert "merge main"
This reverts commit 65efbcead5.
184 lines
6.4 KiB
Python
184 lines
6.4 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 HuggingFace Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from diffusers import ChromaTransformer2DModel
|
|
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
|
from diffusers.models.embeddings import ImageProjection
|
|
|
|
from ...testing_utils import enable_full_determinism, torch_device
|
|
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
|
|
|
|
|
# Runs once at import time, before any of the test classes in this module are defined.
# NOTE(review): presumably configures seeds / deterministic kernels so that the output
# comparisons in the tests are reproducible -- confirm against `testing_utils`.
enable_full_determinism()
|
|
|
|
|
|
def create_chroma_ip_adapter_state_dict(model):
    """Assemble an IP-Adapter style state dict from freshly constructed layers.

    The returned dict has two entries:

    * ``"ip_adapter"``: the ``to_k_ip`` / ``to_v_ip`` weights and biases of a new
      ``FluxIPAdapterJointAttnProcessor2_0`` for every attention processor of ``model``
      whose name does not start with ``"single_transformer_blocks"``. Entries are
      prefixed with a running integer index (0, 1, ...) in iteration order.
    * ``"image_proj"``: the parameters of a new ``ImageProjection`` layer, renamed from
      ``image_embeds.*`` to ``proj.*`` (``norm.*`` keeps its name).

    Args:
        model: Object exposing ``attn_processors`` (a name-keyed mapping) and a
            ``config`` indexable by ``"joint_attention_dim"``, ``"num_attention_heads"``,
            ``"attention_head_dim"`` and ``"pooled_projection_dim"``.

    Returns:
        dict: ``{"image_proj": {...}, "ip_adapter": {...}}``.

    NOTE(review): this helper is not called anywhere in the visible part of this module,
    and the init kwargs used by the tests here do not set ``pooled_projection_dim`` --
    confirm the model config actually provides that key before relying on this function.
    """
    # "ip_adapter" (cross-attention weights)
    cross_attn_weights = {}
    eligible_names = (
        processor_name
        for processor_name in model.attn_processors.keys()
        if not processor_name.startswith("single_transformer_blocks")
    )
    # Only the position of each eligible processor matters; its name is not used further.
    for key_id, _ in enumerate(eligible_names):
        cross_attention_dim = model.config["joint_attention_dim"]
        inner_dim = model.config["num_attention_heads"] * model.config["attention_head_dim"]
        processor = FluxIPAdapterJointAttnProcessor2_0(
            hidden_size=inner_dim, cross_attention_dim=cross_attention_dim, scale=1.0
        )
        processor_weights = processor.state_dict()

        # The processor stores its projections under index ``0``; the adapter layout
        # built here uses the running processor index as prefix instead.
        cross_attn_weights[f"{key_id}.to_k_ip.weight"] = processor_weights["to_k_ip.0.weight"]
        cross_attn_weights[f"{key_id}.to_v_ip.weight"] = processor_weights["to_v_ip.0.weight"]
        cross_attn_weights[f"{key_id}.to_k_ip.bias"] = processor_weights["to_k_ip.0.bias"]
        cross_attn_weights[f"{key_id}.to_v_ip.bias"] = processor_weights["to_v_ip.0.bias"]

    # "image_proj" (ImageProjection layer weights)
    projection_weights = ImageProjection(
        cross_attention_dim=model.config["joint_attention_dim"],
        image_embed_dim=model.config["pooled_projection_dim"],
        num_image_text_embeds=4,
    ).state_dict()
    image_proj_weights = {
        "proj.weight": projection_weights["image_embeds.weight"],
        "proj.bias": projection_weights["image_embeds.bias"],
        "norm.weight": projection_weights["norm.weight"],
        "norm.bias": projection_weights["norm.bias"],
    }

    return {"image_proj": image_proj_weights, "ip_adapter": cross_attn_weights}
|
|
|
|
|
|
class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
    """Common model tests from ``ModelTesterMixin`` configured for ``ChromaTransformer2DModel``."""

    model_class = ChromaTransformer2DModel
    main_input_name = "hidden_states"
    # The transformer built by these tests is small, so the default split percentages
    # are replaced with values that suit it.
    model_split_percents = [0.8, 0.7, 0.7]

    # Skip setting testing with default: AttnProcessor
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        """Return a fresh dict of random forward-pass inputs placed on ``torch_device``.

        Shapes: ``hidden_states`` (1, 16, 4), ``encoder_hidden_states`` (1, 48, 32),
        ``img_ids`` (16, 3), ``txt_ids`` (48, 3), ``timestep`` (1,).
        """
        batch_size = 1
        height = width = 4
        latent_tokens = height * width
        latent_channels = 4
        id_channels = 3
        text_tokens = 48
        text_dim = 32

        # Random tensors are drawn in a fixed order so seeded runs stay reproducible.
        hidden_states = torch.randn((batch_size, latent_tokens, latent_channels)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, text_tokens, text_dim)).to(torch_device)
        txt_ids = torch.randn((text_tokens, id_channels)).to(torch_device)
        img_ids = torch.randn((latent_tokens, id_channels)).to(torch_device)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "img_ids": img_ids,
            "txt_ids": txt_ids,
            "timestep": timestep,
        }

    @property
    def input_shape(self):
        """Per-sample input shape; equals ``hidden_states`` in ``dummy_input`` without the batch dim."""
        return (16, 4)

    @property
    def output_shape(self):
        """Per-sample output shape; identical to ``input_shape``."""
        return (16, 4)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: model constructor kwargs and forward inputs."""
        init_dict = {
            "patch_size": 1,
            "in_channels": 4,
            "num_layers": 1,
            "num_single_layers": 1,
            "attention_head_dim": 16,
            "num_attention_heads": 2,
            "joint_attention_dim": 32,
            "axes_dims_rope": [4, 4, 8],
            "approximator_num_channels": 8,
            "approximator_hidden_dim": 16,
            "approximator_layers": 1,
        }
        return init_dict, self.dummy_input

    def test_deprecated_inputs_img_txt_ids_3d(self):
        # `txt_ids` / `img_ids` given with an extra leading dimension (the deprecated 3d
        # form) must yield the same output as the 2d form.
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        def run_model(model_inputs):
            # Forward pass without gradient tracking; returns the first output tensor.
            with torch.no_grad():
                return model(**model_inputs).to_tuple()[0]

        output_from_2d_ids = run_model(inputs_dict)

        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"

        inputs_dict["txt_ids"] = text_ids_3d
        inputs_dict["img_ids"] = image_ids_3d

        output_from_3d_ids = run_model(inputs_dict)

        self.assertEqual(output_from_2d_ids.shape, output_from_3d_ids.shape)
        outputs_match = torch.allclose(output_from_2d_ids, output_from_3d_ids, atol=1e-5)
        self.assertTrue(
            outputs_match,
            msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
        )

    def test_gradient_checkpointing_is_applied(self):
        # Delegate to the mixin's check, naming the model class as the expected entry.
        super().test_gradient_checkpointing_is_applied(expected_set={"ChromaTransformer2DModel"})
|
|
|
|
|
|
class ChromaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    """Applies the checks of ``TorchCompileTesterMixin`` with ``ChromaTransformer2DModel`` as model class."""

    model_class = ChromaTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        """Reuse the constructor kwargs and forward inputs defined by ``ChromaTransformerTests``."""
        reference_case = ChromaTransformerTests()
        init_dict, inputs_dict = reference_case.prepare_init_args_and_inputs_for_common()
        return init_dict, inputs_dict
|
|
|
|
|
|
class ChromaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    """Applies the checks of ``LoraHotSwappingForModelTesterMixin`` with ``ChromaTransformer2DModel`` as model class."""

    model_class = ChromaTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        """Reuse the constructor kwargs and forward inputs defined by ``ChromaTransformerTests``."""
        reference_case = ChromaTransformerTests()
        init_dict, inputs_dict = reference_case.prepare_init_args_and_inputs_for_common()
        return init_dict, inputs_dict
|