mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix for tests

Dhruv Nair
2025-06-13 14:51:42 +02:00
parent 89faa71f04
commit 6735507705
4 changed files with 21 additions and 16 deletions


@@ -418,6 +418,7 @@ class ChromaTransformer2DModel(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
+        pooled_projection_dim: int = 768,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
         approximator_num_channels: int = 64,
         approximator_hidden_dim: int = 5120,
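Note on the hunk above: because diffusers model constructors are decorated with @register_to_config, a newly added __init__ argument such as pooled_projection_dim is captured into the serialized config automatically, which is typically what a test expects to read back. A minimal sketch of that mechanism, using a hypothetical ToyTransformer stand-in rather than the real ChromaTransformer2DModel:

# Sketch only: ToyTransformer is a hypothetical stand-in, not the real model class.
from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyTransformer(ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(self, joint_attention_dim: int = 4096, pooled_projection_dim: int = 768):
        # register_to_config records every keyword above into self.config,
        # so from_config/save_config (and tests) can read it back.
        pass


model = ToyTransformer()
print(model.config.pooled_projection_dim)  # 768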


@@ -219,25 +219,23 @@ class ChromaPipeline(
         text_inputs = self.tokenizer(
             prompt,
-            padding=False,
+            padding=True,
             max_length=max_sequence_length,
             truncation=True,
             return_length=False,
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
-        pad_token_id = self.tokenizer.pad_token_id
-        text_input_ids = torch.cat(
-            [
-                text_inputs.input_ids,
-                torch.full((text_inputs.input_ids.size(0), 1), pad_token_id, dtype=text_inputs.input_ids.dtype),
-            ],
-            dim=1,
-        )
+        text_input_ids = text_inputs.input_ids
+        attention_mask = text_inputs.attention_mask.clone()
+
+        # Chroma requires the attention mask to include one padding token
+        seq_lengths = attention_mask.sum(dim=1)
+        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device),
-            output_hidden_states=False,
+            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]
         dtype = self.text_encoder.dtype
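The replacement mask logic is worth unpacking: instead of concatenating an extra pad token onto the input ids, the commit keeps the tokenizer output and widens each row's attention mask by exactly one position. A standalone sketch of the arithmetic, on a dummy mask rather than the pipeline's real inputs:

import torch

# Dummy mask for a batch of one prompt: 3 real tokens, padded to length 6.
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])
batch_size = attention_mask.size(0)

seq_lengths = attention_mask.sum(dim=1)  # tensor([3]): real tokens per row
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
# `<=` (rather than `<`) keeps positions 0..seq_length, i.e. one extra pad token.
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0]])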


@@ -19,7 +19,6 @@ from ..test_pipelines_common import (
     FasterCacheTesterMixin,
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
-    PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
 )
@@ -29,11 +28,10 @@ class ChromaPipelineFastTests(
     unittest.TestCase,
     PipelineTesterMixin,
     FluxIPAdapterTesterMixin,
-    PyramidAttentionBroadcastTesterMixin,
     FasterCacheTesterMixin,
 ):
     pipeline_class = ChromaPipeline
-    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
     batch_params = frozenset(["prompt"])
 
     # there is no xformers processor for Flux
@@ -46,7 +44,7 @@ class ChromaPipelineFastTests(
         spatial_attention_timestep_skip_range=(-1, 901),
         unconditional_batch_skip_range=2,
         attention_weight_callback=lambda _: 0.5,
-        is_guidance_distilled=True,
+        is_guidance_distilled=False,
     )
 
     def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
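The is_guidance_distilled flip matters because FasterCache treats guidance-distilled models (like Flux, which run a single forward pass per step) differently from true-CFG pipelines like Chroma, which also run an unconditional batch. A hedged sketch of the test's config after this commit, built only from the parameter values visible in the hunk above:

from diffusers import FasterCacheConfig

faster_cache_config = FasterCacheConfig(
    spatial_attention_timestep_skip_range=(-1, 901),
    unconditional_batch_skip_range=2,
    attention_weight_callback=lambda _: 0.5,
    is_guidance_distilled=False,  # Chroma runs classifier-free guidance
)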
@@ -182,3 +180,9 @@ class ChromaPipelineFastTests(
         image = pipe(**inputs).images[0]
         output_height, output_width, _ = image.shape
         assert (output_height, output_width) == (expected_height, expected_width)
+
+    @unittest.skip(
+        "Chroma uses the Flux encode_prompt but runs CFG. This test is incompatible with the pipeline since the test case does not use negative prompt embeds."
+    )
+    def test_encode_prompt_works_in_isolation(self):
+        pass


@@ -307,6 +307,7 @@ class IPAdapterTesterMixin:
         # forward pass without ip adapter
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        __import__("ipdb").set_trace()
         if expected_pipe_slice is None:
             output_without_adapter = pipe(**inputs)[0]
         else:
@@ -521,7 +522,8 @@ class FluxIPAdapterTesterMixin:
     def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
         inputs["negative_prompt"] = ""
-        inputs["true_cfg_scale"] = 4.0
+        if "true_cfg_scale" in inspect.signature(self.pipeline_class.__call__).parameters:
+            inputs["true_cfg_scale"] = 4.0
         inputs["output_type"] = "np"
         inputs["return_dict"] = False
         return inputs
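The new guard makes the IP-adapter test helper pipeline-agnostic: true_cfg_scale is only injected when the target pipeline's __call__ actually accepts it. A self-contained sketch of the pattern, with a hypothetical DummyPipeline standing in for self.pipeline_class:

import inspect


class DummyPipeline:  # hypothetical stand-in for self.pipeline_class
    def __call__(self, prompt: str, true_cfg_scale: float = 1.0):
        ...


inputs = {"negative_prompt": ""}
if "true_cfg_scale" in inspect.signature(DummyPipeline.__call__).parameters:
    inputs["true_cfg_scale"] = 4.0
print(inputs)  # {'negative_prompt': '', 'true_cfg_scale': 4.0}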