1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Tests] fix some fast gpu tests. (#9379)

fix some fast gpu tests.
This commit is contained in:
Sayak Paul
2024-09-11 06:50:02 +05:30
committed by GitHub
parent f28a8c257a
commit adf1f911f0
4 changed files with 5 additions and 2 deletions

View File

@@ -1597,6 +1597,7 @@ def main(args):
tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two],
max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=prompts,
)
else:
@@ -1606,6 +1607,7 @@ def main(args):
tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two],
max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=args.instance_prompt,
)

View File

@@ -465,6 +465,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)

View File

@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxImg2ImgPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)

View File

@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxInpaintPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)