From 5341450e149925406e010a84471664e85e9fdcd5 Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Thu, 11 May 2023 10:53:45 -0700
Subject: [PATCH] Fix the mixed precision issue and add additional tests of the
 pipeline cuda/fp16 functionality.

---
 .../unidiffuser/modeling_text_decoder.py      |  3 ++
 .../unidiffuser/pipeline_unidiffuser.py       | 10 ++---
 .../pipelines/unidiffuser/test_unidiffuser.py | 43 +++++++++++++++++++
 3 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
index 6da23abd54..efd1a15be8 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
@@ -160,6 +160,9 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
 
     def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+    def encode(self, prefix):
+        return self.encode_prefix(prefix)
 
     @torch.no_grad()
     def generate_captions(self, tokenizer, features, device):
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index d5f598c311..e4e50053e7 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -694,7 +694,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         else:
             # latents is assumed to have shace (B, L, D)
             latents = latents.repeat(num_images_per_prompt, 1, 1)
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
@@ -731,7 +731,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         else:
             # latents is assumed to have shape (B, C, H, W)
             latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
@@ -753,7 +753,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         else:
             # latents is assumed to have shape (B, L, D)
             latents = latents.repeat(num_prompts_per_image, 1, 1)
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
@@ -1239,14 +1239,14 @@ class UniDiffuserPipeline(DiffusionPipeline):
                 num_images_per_prompt=multiplier,
                 seq_len=self.text_encoder_seq_len,
                 hidden_size=self.text_encoder_hidden_size,
-                dtype=torch.float32,  # TODO: Placeholder, need to determine correct thing to do for dtype
+                dtype=self.text_encoder.dtype,  # Should work with both full precision and mixed precision
                 device=device,
                 generator=generator,
                 latents=prompt_latents,
             )
 
         if reduce_text_emb_dim:
-            prompt_embeds = self.text_decoder.encode_prefix(prompt_embeds)
+            prompt_embeds = self.text_decoder.encode(prompt_embeds)
 
         # 4. Encode image, if available; otherwise prepare image latents
         if mode in ["img2text"]:
diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
index 69d318bee0..967284876a 100644
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ b/tests/pipelines/unidiffuser/test_unidiffuser.py
@@ -432,6 +432,49 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_text_prefix = " no no no "
         assert text[0][:10] == expected_text_prefix
 
+    @require_torch_gpu
+    def test_unidiffuser_default_text2img_v1_cuda_fp16(self):
+        device = "cuda"
+        unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1", torch_dtype=torch.float16)
+        unidiffuser_pipe = unidiffuser_pipe.to(device)
+        unidiffuser_pipe.set_progress_bar_config(disable=None)
+
+        # Set mode to 'text2img'
+        unidiffuser_pipe.set_text_to_image_mode()
+        assert unidiffuser_pipe.mode == "text2img"
+
+        inputs = self.get_dummy_inputs_with_latents(device)
+        # Delete prompt and image for joint inference.
+        del inputs["image"]
+        inputs["data_type"] = 1
+        sample = unidiffuser_pipe(**inputs)
+        image = sample.images
+        assert image.shape == (1, 32, 32, 3)
+
+        image_slice = image[0, -3:, -3:, -1]
+        expected_img_slice = np.array([0.5757, 0.6270, 0.6567, 0.4966, 0.4639, 0.5664, 0.5259, 0.5068, 0.5713])
+        assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
+
+    @require_torch_gpu
+    def test_unidiffuser_default_img2text_v1_cuda_fp16(self):
+        device = "cuda"
+        unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1", torch_dtype=torch.float16)
+        unidiffuser_pipe = unidiffuser_pipe.to(device)
+        unidiffuser_pipe.set_progress_bar_config(disable=None)
+
+        # Set mode to 'img2text'
+        unidiffuser_pipe.set_image_to_text_mode()
+        assert unidiffuser_pipe.mode == "img2text"
+
+        inputs = self.get_dummy_inputs_with_latents(device)
+        # Delete prompt and image for joint inference.
+        del inputs["prompt"]
+        inputs["data_type"] = 1
+        text = unidiffuser_pipe(**inputs).text
+
+        expected_text_prefix = " no no no "
+        assert text[0][:10] == expected_text_prefix
+
 
 @slow
 @require_torch_gpu
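
Note (not part of the patch): a minimal sketch of how the fixed fp16 path can be exercised end to end, mirroring the new tests above. It assumes the same dg845/unidiffuser-test-v1 test checkpoint used in the tests, an available CUDA device, and the diffusers version this patch targets; the prompt string is purely illustrative.

import torch
from diffusers import UniDiffuserPipeline

# Load the pipeline in half precision; after this patch, prompt/image latents are
# prepared in the pipeline's dtype instead of being hard-coded to float32.
pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Text-to-image in fp16 (the mode covered by test_unidiffuser_default_text2img_v1_cuda_fp16).
pipe.set_text_to_image_mode()
image = pipe(prompt="an astronaut riding a horse", num_inference_steps=20).images[0]

# Image-to-text in fp16 (the mode covered by test_unidiffuser_default_img2text_v1_cuda_fp16).
pipe.set_image_to_text_mode()
caption = pipe(image=image, num_inference_steps=20).text[0]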