From b288684d2599fb1787bdafe8cc58967fa30f7586 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 24 Jul 2023 18:37:07 +0200
Subject: [PATCH] [SDXL] Fix sd xl encode prompt (#4237)

* [SDXL] Fix sd xl encode prompt

* add tests
---
 .../controlnet/pipeline_controlnet_sd_xl.py   |  4 +--
 .../pipeline_stable_diffusion_xl.py           |  4 +--
 .../pipeline_stable_diffusion_xl_img2img.py   |  4 +--
 .../pipeline_stable_diffusion_xl_inpaint.py   |  4 +--
 .../test_stable_diffusion_xl_img2img.py       | 26 +++++++++++++++---
 .../test_stable_diffusion_xl_inpaint.py       | 27 ++++++++++++++++---
 6 files changed, 53 insertions(+), 16 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 4f515db2a3..36978a652a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -360,7 +360,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -369,7 +369,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 9239b17523..a7a0032fb3 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -375,7 +375,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -384,7 +384,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 13209b2c1b..6e3f8ff240 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -383,7 +383,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -392,7 +392,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index f36f10f98e..497195c92e 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -489,7 +489,7 @@ class StableDiffusionXLInpaintPipeline(
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -498,7 +498,7 @@ class StableDiffusionXLInpaintPipeline(
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index 1f134d01a0..f91b06d750 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -48,7 +48,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
     image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
     image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
 
-    def get_dummy_components(self):
+    def get_dummy_components(self, skip_first_text_encoder=False):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
@@ -65,7 +65,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
             projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
-            cross_attention_dim=64,
+            cross_attention_dim=64 if not skip_first_text_encoder else 32,
         )
         scheduler = EulerDiscreteScheduler(
             beta_start=0.00085,
@@ -109,8 +109,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
             "unet": unet,
             "scheduler": scheduler,
             "vae": vae,
-            "text_encoder": text_encoder,
-            "tokenizer": tokenizer,
+            "text_encoder": text_encoder if not skip_first_text_encoder else None,
+            "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
         }
@@ -151,6 +151,24 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_xl_refiner(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components(skip_first_text_encoder=True)
+
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+
+        expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_attention_slicing_forward_pass(self):
         super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
 
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index e8ee1493cc..eec1e9fe4b 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -50,7 +50,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
     # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
     image_latents_params = frozenset([])
 
-    def get_dummy_components(self):
+    def get_dummy_components(self, skip_first_text_encoder=False):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
@@ -67,7 +67,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
             projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
-            cross_attention_dim=64,
+            cross_attention_dim=64 if not skip_first_text_encoder else 32,
        )
         scheduler = EulerDiscreteScheduler(
             beta_start=0.00085,
@@ -111,8 +111,8 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
             "unet": unet,
             "scheduler": scheduler,
             "vae": vae,
-            "text_encoder": text_encoder,
-            "tokenizer": tokenizer,
+            "text_encoder": text_encoder if not skip_first_text_encoder else None,
+            "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
         }
@@ -238,6 +238,25 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
         assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
         assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
 
+    def test_stable_diffusion_xl_refiner(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components(skip_first_text_encoder=True)
+
+        sd_pipe = self.pipeline_class(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        print(torch.from_numpy(image_slice).flatten())
+        assert image.shape == (1, 64, 64, 3)
+
+        expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
         components = self.get_dummy_components()
         pipe_1 = StableDiffusionXLInpaintPipeline(**components).to(torch_device)
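
Why the one-line dtype change matters: refiner-style SDXL checkpoints are loaded with
text_encoder=None (only the second text encoder is present), so the old cast through
self.text_encoder.dtype raised an AttributeError inside encode_prompt. Below is a minimal
stand-alone sketch of that failure mode and the fix; FakeTextEncoder and TinyRefinerPipeline
are hypothetical stand-ins for illustration, not diffusers APIs:

import torch

class FakeTextEncoder:
    """Stand-in for a transformers CLIP text model, which exposes a `.dtype` attribute."""
    def __init__(self, dtype):
        self.dtype = dtype

class TinyRefinerPipeline:
    """Hypothetical stand-in for a refiner-style SDXL pipeline (not a diffusers class)."""
    def __init__(self, text_encoder, text_encoder_2):
        self.text_encoder = text_encoder      # None for refiner-style checkpoints
        self.text_encoder_2 = text_encoder_2  # always present in SDXL pipelines

    def encode_prompt(self, prompt_embeds, device="cpu"):
        # Before the patch this read `self.text_encoder.dtype`, which raises
        # AttributeError ('NoneType' object has no attribute 'dtype') when the
        # first text encoder is absent. The patch casts with text_encoder_2
        # instead, which every SDXL pipeline variant has.
        return prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

pipe = TinyRefinerPipeline(text_encoder=None, text_encoder_2=FakeTextEncoder(torch.float16))
embeds = pipe.encode_prompt(torch.randn(1, 77, 64))
assert embeds.dtype == torch.float16  # cast succeeds even though text_encoder is None

The new tests above exercise exactly this path: get_dummy_components(skip_first_text_encoder=True)
builds the pipeline with text_encoder=None and tokenizer=None, mirroring how the refiner checkpoint
is loaded in practice.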