mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix tests
This commit is contained in:
@@ -675,6 +675,13 @@ class ChromaPipeline(
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
if negative_prompt is not None and isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt] * batch_size
|
||||
elif negative_prompt is not None and isinstance(negative_prompt, list):
|
||||
if len(negative_prompt) == 1:
|
||||
negative_prompt = [negative_prompt] * batch_size
|
||||
else:
|
||||
raise ValueError("prompt and negative_prompt are lists of unequal size")
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = ChromaTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
model_split_percents = [0.8, 0.7, 0.7]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@@ -57,7 +57,9 @@ def create_flux_ip_adapter_state_dict(model):
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=model.config["pooled_projection_dim"],
|
||||
image_embed_dim=(
|
||||
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
|
||||
),
|
||||
num_image_text_embeds=4,
|
||||
)
|
||||
|
||||
|
||||
@@ -544,7 +544,11 @@ class FluxIPAdapterTesterMixin:
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
image_embed_dim = pipe.transformer.config.pooled_projection_dim or 768
|
||||
image_embed_dim = (
|
||||
pipe.transformer.config.pooled_projection_dim
|
||||
if hasattr(pipe.transformer.config, "pooled_projection_dim")
|
||||
else 768
|
||||
)
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
|
||||
Reference in New Issue
Block a user