
Add fast test for CUDA/fp16 model behavior (currently failing).

This commit is contained in:
Daniel Gu
2023-05-11 10:05:04 -07:00
parent f36df41999
commit 1bc2b91dfc
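For context, the diff below changes get_fixed_latents so that the generator is seeded on CPU and the fixed latents are only moved to the target device afterwards, which keeps the same seed producing the same latents whether the test runs on CPU or CUDA. Below is a minimal, self-contained sketch of that pattern; it uses plain torch.randn and a hypothetical make_fixed_latents helper rather than the test's randn_tensor call, so it illustrates the idea only, not the actual test code.

import torch

def make_fixed_latents(device, seed=0, dtype=torch.float32):
    # Hypothetical helper illustrating the pattern; the real test uses
    # diffusers' randn_tensor and builds three separate latent tensors.
    if isinstance(device, str):
        device = torch.device(device)
    # Seed a CPU generator: a CUDA generator yields a different random stream
    # for the same seed, which would change the expected slices in the test.
    cpu_generator = torch.Generator(device="cpu").manual_seed(seed)
    latents = torch.randn((1, 4, 16, 16), generator=cpu_generator, dtype=dtype)
    # Move the latents onto the desired device only after sampling them.
    return latents.to(device)

# Same seed -> identical values regardless of the target device:
cpu_latents = make_fixed_latents("cpu", seed=0)
# cuda_latents = make_fixed_latents("cuda", seed=0)  # matches cpu_latents if CUDA is available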


@@ -111,12 +111,18 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    def get_fixed_latents(self, device, seed=0):
        if type(device) == str:
            device = torch.device(device)
        generator = torch.Generator(device=device).manual_seed(seed)
        latent_device = torch.device("cpu")
        generator = torch.Generator(device=latent_device).manual_seed(seed)
        # Hardcode the shapes for now.
        prompt_latents = randn_tensor((1, 77, 32), generator=generator, device=device, dtype=torch.float32)
        vae_latents = randn_tensor((1, 4, 16, 16), generator=generator, device=device, dtype=torch.float32)
        clip_latents = randn_tensor((1, 1, 32), generator=generator, device=device, dtype=torch.float32)
        # Move latents onto desired device.
        prompt_latents = prompt_latents.to(device)
        vae_latents = vae_latents.to(device)
        clip_latents = clip_latents.to(device)
        latents = {
            "prompt_latents": prompt_latents,
            "vae_latents": vae_latents,
@@ -132,7 +138,6 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg",
)
image = image.resize((32, 32))
latents = self.get_fixed_latents(device, seed=seed)
if str(device).startswith("mps"):
@@ -399,6 +404,34 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        assert len(text) == 3

    @require_torch_gpu
    def test_unidiffuser_default_joint_v1_cuda_fp16(self):
        device = "cuda"
        unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1", torch_dtype=torch.float16)
        unidiffuser_pipe = unidiffuser_pipe.to(device)
        unidiffuser_pipe.set_progress_bar_config(disable=None)

        # Set mode to 'joint'
        unidiffuser_pipe.set_joint_mode()
        assert unidiffuser_pipe.mode == "joint"

        inputs = self.get_dummy_inputs_with_latents(device)
        # Delete prompt and image for joint inference.
        del inputs["prompt"]
        del inputs["image"]
        inputs["data_type"] = 1
        sample = unidiffuser_pipe(**inputs)
        image = sample.images
        text = sample.text
        assert image.shape == (1, 32, 32, 3)

        image_slice = image[0, -3:, -3:, -1]
        expected_img_slice = np.array([0.5762, 0.6270, 0.6572, 0.4966, 0.4639, 0.5664, 0.5254, 0.5068, 0.5713])
        assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3

        expected_text_prefix = " no no no "
        assert text[0][:10] == expected_text_prefix

    @slow
    @require_torch_gpu