mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix tests

Author: BuildTools
Date: 2025-06-13 09:41:44 -06:00
Committed by: DN6
parent 4e24f26d6f
commit 0978b609c8
4 changed files with 16 additions and 3 deletions


@@ -675,6 +675,13 @@ class ChromaPipeline(
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
+            if negative_prompt is not None and isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * batch_size
+            elif negative_prompt is not None and isinstance(negative_prompt, list):
+                if len(negative_prompt) == 1:
+                    negative_prompt = negative_prompt * batch_size
+                elif len(negative_prompt) != batch_size:
+                    raise ValueError("prompt and negative_prompt are lists of unequal size")
         else:
             batch_size = prompt_embeds.shape[0]
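
For reference, a minimal standalone sketch of the broadcasting rule this hunk adds, assuming a helper name (broadcast_negative_prompt) and example prompts that are purely illustrative and not part of the pipeline:

from typing import List, Optional, Union


def broadcast_negative_prompt(
    prompt: List[str], negative_prompt: Optional[Union[str, List[str]]]
) -> Optional[List[str]]:
    # Expand a lone negative prompt so it matches the prompt batch size.
    batch_size = len(prompt)
    if negative_prompt is None:
        return None
    if isinstance(negative_prompt, str):
        return [negative_prompt] * batch_size
    if len(negative_prompt) == 1:
        return negative_prompt * batch_size
    if len(negative_prompt) != batch_size:
        raise ValueError("prompt and negative_prompt are lists of unequal size")
    return negative_prompt


# A single negative prompt (string or one-element list) is repeated per prompt;
# a list that already matches the batch size passes through unchanged.
assert broadcast_negative_prompt(["a cat", "a dog"], "blurry") == ["blurry", "blurry"]
assert broadcast_negative_prompt(["a cat", "a dog"], ["blurry"]) == ["blurry", "blurry"]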


@@ -82,7 +82,7 @@ class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = ChromaTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.
-    model_split_percents = [0.7, 0.6, 0.6]
+    model_split_percents = [0.8, 0.7, 0.7]
     # Skip setting testing with default: AttnProcessor
     uses_custom_attn_processor = True


@@ -57,7 +57,9 @@ def create_flux_ip_adapter_state_dict(model):
     image_projection = ImageProjection(
         cross_attention_dim=model.config["joint_attention_dim"],
-        image_embed_dim=model.config["pooled_projection_dim"],
+        image_embed_dim=(
+            model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
+        ),
         num_image_text_embeds=4,
     )
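
The fallback above only kicks in when the model config has no pooled_projection_dim entry (presumably the Chroma transformer config omits it, which is what broke this helper). A small sketch of the same key-presence check, with plain dicts standing in for model.config and dummy sizes that are illustrative only:

flux_like_config = {"joint_attention_dim": 32, "pooled_projection_dim": 16}
chroma_like_config = {"joint_attention_dim": 32}  # no pooled_projection_dim entry

for config in (flux_like_config, chroma_like_config):
    image_embed_dim = config["pooled_projection_dim"] if "pooled_projection_dim" in config else 768
    print(image_embed_dim)  # 16 for the first config, then the 768 fallback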


@@ -544,7 +544,11 @@ class FluxIPAdapterTesterMixin:
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        image_embed_dim = pipe.transformer.config.pooled_projection_dim or 768
+        image_embed_dim = (
+            pipe.transformer.config.pooled_projection_dim
+            if hasattr(pipe.transformer.config, "pooled_projection_dim")
+            else 768
+        )
         # forward pass without ip adapter
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
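
Same fallback as in the previous hunk, but via attribute access instead of a key lookup, since the test reads the value off a config object. A short sketch with types.SimpleNamespace standing in for pipe.transformer.config; the dummy sizes are illustrative only:

from types import SimpleNamespace

flux_like = SimpleNamespace(joint_attention_dim=32, pooled_projection_dim=16)
chroma_like = SimpleNamespace(joint_attention_dim=32)  # attribute absent entirely

for config in (flux_like, chroma_like):
    # hasattr guards against the missing attribute; plain access would raise
    # AttributeError on chroma_like before any fallback could apply.
    image_embed_dim = config.pooled_projection_dim if hasattr(config, "pooled_projection_dim") else 768
    print(image_embed_dim)  # 16, then 768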