1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Fix the mixed precision issue and add additional tests of the pipeline cuda/fp16 functionality.

This commit is contained in:
Daniel Gu
2023-05-11 10:53:45 -07:00
parent 1bc2b91dfc
commit 5341450e14
3 changed files with 51 additions and 5 deletions

View File

@@ -160,6 +160,9 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
def encode(self, prefix):
return self.encode_prefix(prefix)
@torch.no_grad()
def generate_captions(self, tokenizer, features, device):

View File

@@ -694,7 +694,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
else:
# latents is assumed to have shace (B, L, D)
latents = latents.repeat(num_images_per_prompt, 1, 1)
latents = latents.to(device)
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
@@ -731,7 +731,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
else:
# latents is assumed to have shape (B, C, H, W)
latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
latents = latents.to(device)
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
@@ -753,7 +753,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
else:
# latents is assumed to have shape (B, L, D)
latents = latents.repeat(num_prompts_per_image, 1, 1)
latents = latents.to(device)
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
@@ -1239,14 +1239,14 @@ class UniDiffuserPipeline(DiffusionPipeline):
num_images_per_prompt=multiplier,
seq_len=self.text_encoder_seq_len,
hidden_size=self.text_encoder_hidden_size,
dtype=torch.float32, # TODO: Placeholder, need to determine correct thing to do for dtype
dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision
device=device,
generator=generator,
latents=prompt_latents,
)
if reduce_text_emb_dim:
prompt_embeds = self.text_decoder.encode_prefix(prompt_embeds)
prompt_embeds = self.text_decoder.encode(prompt_embeds)
# 4. Encode image, if available; otherwise prepare image latents
if mode in ["img2text"]:

View File

@@ -432,6 +432,49 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
@require_torch_gpu
def test_unidiffuser_default_text2img_v1_cuda_fp16(self):
device = "cuda"
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1", torch_dtype=torch.float16)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text2img'
unidiffuser_pipe.set_text_to_image_mode()
assert unidiffuser_pipe.mode == "text2img"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["image"]
inputs["data_type"] = 1
sample = unidiffuser_pipe(**inputs)
image = sample.images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.5757, 0.6270, 0.6567, 0.4966, 0.4639, 0.5664, 0.5259, 0.5068, 0.5713])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
@require_torch_gpu
def test_unidiffuser_default_img2text_v1_cuda_fp16(self):
device = "cuda"
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1", torch_dtype=torch.float16)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img2text'
unidiffuser_pipe.set_image_to_text_mode()
assert unidiffuser_pipe.mode == "img2text"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
inputs["data_type"] = 1
text = unidiffuser_pipe(**inputs).text
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
@slow
@require_torch_gpu