1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

make style

This commit is contained in:
Daniel Gu
2023-05-16 13:00:58 -07:00
parent e56fab2def
commit ecaf07f673

View File

@@ -540,7 +540,7 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
for latent_name, latent_tensor in latents.items():
inputs[latent_name] = latent_tensor
return inputs
def get_fixed_latents(self, device, seed=0):
if type(device) == str:
device = torch.device(device)
@@ -584,7 +584,7 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
expected_text_prefix = "A living room"
assert text[0][:len(expected_text_prefix)] == expected_text_prefix
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
def test_unidiffuser_default_text2img_v1(self):
pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-diffusers")
@@ -614,7 +614,7 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
text = sample.text
expected_text_prefix = "T CL CL CL "
assert text[0][:len(expected_text_prefix)] == expected_text_prefix
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
def test_unidiffuser_default_joint_v1_fp16(self):
pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-diffusers", torch_dtype=torch.float16)
@@ -639,7 +639,7 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
expected_text_prefix = "A living room"
assert text[0][:len(expected_text_prefix)] == expected_text_prefix
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
def test_unidiffuser_default_text2img_v1_fp16(self):
pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-diffusers", torch_dtype=torch.float16)
@@ -671,4 +671,4 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
print(f"Text: {text}")
expected_text_prefix = "T CL CL CL "
assert text[0][:len(expected_text_prefix)] == expected_text_prefix
assert text[0][: len(expected_text_prefix)] == expected_text_prefix