mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fight more tests
This commit is contained in:
@@ -242,7 +242,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, is_causal=False
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
@@ -16,11 +16,7 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
|
||||
|
||||
|
||||
class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
@@ -170,12 +166,10 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
self.assertTrue(
|
||||
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
|
||||
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
|
||||
)
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
@@ -8,11 +8,7 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPToken
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
|
||||
|
||||
|
||||
class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
@@ -140,12 +136,10 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
self.assertTrue(
|
||||
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
|
||||
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
|
||||
)
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
@@ -15,11 +15,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
PipelineTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
check_qkv_fusion_processors_exist,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
|
||||
|
||||
|
||||
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
@@ -134,12 +130,10 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin
|
||||
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
|
||||
# to the pipeline level.
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
assert check_qkv_fusion_processors_exist(pipe.transformer), (
|
||||
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
|
||||
self.assertTrue(
|
||||
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
|
||||
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
|
||||
)
|
||||
assert check_qkv_fusion_matches_attn_procs_length(
|
||||
pipe.transformer, pipe.transformer.original_attn_processors
|
||||
), "Something wrong with the attention processors concerning the fused QKV projections."
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
Reference in New Issue
Block a user