mirror of https://github.com/huggingface/diffusers.git
fix for tests
@@ -418,6 +418,7 @@ class ChromaTransformer2DModel(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
+        pooled_projection_dim: int = 768,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
         approximator_num_channels: int = 64,
         approximator_hidden_dim: int = 5120,
@@ -219,25 +219,23 @@ class ChromaPipeline(
 
         text_inputs = self.tokenizer(
             prompt,
-            padding=False,
+            padding=True,
             max_length=max_sequence_length,
             truncation=True,
             return_length=False,
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
-        pad_token_id = self.tokenizer.pad_token_id
-        text_input_ids = torch.cat(
-            [
-                text_inputs.input_ids,
-                torch.full((text_inputs.input_ids.size(0), 1), pad_token_id, dtype=text_inputs.input_ids.dtype),
-            ],
-            dim=1,
-        )
+        text_input_ids = text_inputs.input_ids
+        attention_mask = text_inputs.attention_mask.clone()
+
+        # Chroma requires the attention mask to include one padding token
+        seq_lengths = attention_mask.sum(dim=1)
+        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
 
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device),
-            output_hidden_states=False,
+            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]
 
         dtype = self.text_encoder.dtype
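The mask manipulation in the hunk above is the substantive change: Chroma wants the T5 attention mask to keep exactly one padding token visible after the real tokens, rather than masking all padding. A minimal standalone sketch of that step with toy tensors (values and variable names are illustrative, not taken from the pipeline):

import torch

# toy batch: two prompts padded to length 6 (1 = real token, 0 = padding)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
batch_size, seq_len = attention_mask.shape

# same logic as the diff: unmask one extra position per row
seq_lengths = attention_mask.sum(dim=1)                                   # tensor([3, 5])
mask_indices = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)  # position index per row
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()

print(attention_mask)
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1]])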
@@ -19,7 +19,6 @@ from ..test_pipelines_common import (
     FasterCacheTesterMixin,
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
-    PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
 )
@@ -29,11 +28,10 @@ class ChromaPipelineFastTests(
     unittest.TestCase,
     PipelineTesterMixin,
     FluxIPAdapterTesterMixin,
-    PyramidAttentionBroadcastTesterMixin,
     FasterCacheTesterMixin,
 ):
     pipeline_class = ChromaPipeline
-    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
     batch_params = frozenset(["prompt"])
 
     # there is no xformers processor for Flux
@@ -46,7 +44,7 @@ class ChromaPipelineFastTests(
         spatial_attention_timestep_skip_range=(-1, 901),
         unconditional_batch_skip_range=2,
         attention_weight_callback=lambda _: 0.5,
-        is_guidance_distilled=True,
+        is_guidance_distilled=False,
     )
 
     def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
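Flipping is_guidance_distilled to False tells the FasterCache test setup that Chroma performs real classifier-free guidance with a separate unconditional batch, instead of being guidance-distilled like Flux, so the unconditional branch must actually run. A hedged sketch of such a config, assuming FasterCacheConfig is importable from the diffusers top level and that fields not shown in the hunk keep their defaults:

from diffusers import FasterCacheConfig  # assumed import path

faster_cache_config = FasterCacheConfig(
    spatial_attention_timestep_skip_range=(-1, 901),
    unconditional_batch_skip_range=2,
    attention_weight_callback=lambda _: 0.5,
    is_guidance_distilled=False,  # Chroma uses true CFG, so do not skip the unconditional pass
)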
@@ -182,3 +180,9 @@ class ChromaPipelineFastTests(
         image = pipe(**inputs).images[0]
         output_height, output_width, _ = image.shape
         assert (output_height, output_width) == (expected_height, expected_width)
+
+    @unittest.skip(
+        "Chroma uses Flux encode_prompt but uses CFG. This test is incompatible with the pipeline since the test case does not use a negative prompt embeds"
+    )
+    def test_encode_prompt_works_in_isolation(self):
+        pass
@@ -307,6 +307,7 @@ class IPAdapterTesterMixin:
 
         # forward pass without ip adapter
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        __import__("ipdb").set_trace()
         if expected_pipe_slice is None:
             output_without_adapter = pipe(**inputs)[0]
         else:
@@ -521,7 +522,8 @@ class FluxIPAdapterTesterMixin:
 
     def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
         inputs["negative_prompt"] = ""
-        inputs["true_cfg_scale"] = 4.0
+        if "true_cfg_scale" in inspect.signature(self.pipeline_class.__call__).parameters:
+            inputs["true_cfg_scale"] = 4.0
         inputs["output_type"] = "np"
         inputs["return_dict"] = False
         return inputs
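The guard added in the last hunk only sets true_cfg_scale when the pipeline's __call__ actually declares that parameter, so the same mixin works for pipelines with and without true CFG. The pattern is plain inspect usage; a self-contained sketch with a made-up pipeline class:

import inspect

class FakePipeline:  # hypothetical stand-in for a diffusers pipeline
    def __call__(self, prompt, true_cfg_scale=1.0, output_type="np"):
        return prompt, true_cfg_scale, output_type

inputs = {"prompt": "a cat"}
# only pass the kwarg if the callable's signature declares it
if "true_cfg_scale" in inspect.signature(FakePipeline.__call__).parameters:
    inputs["true_cfg_scale"] = 4.0

print(inputs)  # {'prompt': 'a cat', 'true_cfg_scale': 4.0}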