From 16a3dad474dad00f8e4071d699e1562471a2dacd Mon Sep 17 00:00:00 2001 From: Sangwon Lee Date: Wed, 21 Aug 2024 06:54:27 +0900 Subject: [PATCH] Fix StableDiffusionXLPAGInpaintPipeline (#9128) --- src/diffusers/pipelines/auto_pipeline.py | 3 ++- .../pipelines/pag/pipeline_pag_sd_xl_inpaint.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index e4798fb990..e756bad3b0 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -955,7 +955,8 @@ class AutoPipelineForInpainting(ConfigMixin): if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline") + to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" + orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace) inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index 64aff497a5..09c3a7029c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -1471,6 +1471,14 @@ class StableDiffusionXLPAGInpaintPipeline( generator, self.do_classifier_free_guidance, ) + if self.do_perturbed_attention_guidance: + if self.do_classifier_free_guidance: + mask, _ = mask.chunk(2) + masked_image_latents, _ = masked_image_latents.chunk(2) + mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance) + masked_image_latents = self._prepare_perturbed_attention_guidance( + masked_image_latents, masked_image_latents, self.do_classifier_free_guidance + ) # 8. Check that sizes of mask, masked image and latents match if num_channels_unet == 9: @@ -1659,10 +1667,10 @@ class StableDiffusionXLPAGInpaintPipeline( if num_channels_unet == 4: init_latents_proper = image_latents - if self.do_classifier_free_guidance: - init_mask, _ = mask.chunk(2) + if self.do_perturbed_attention_guidance: + init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2) else: - init_mask = mask + init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1]