1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix StableDiffusionXLPAGInpaintPipeline (#9128)

This commit is contained in:
Sangwon Lee
2024-08-21 06:54:27 +09:00
committed by GitHub
parent 21682bab7e
commit 16a3dad474
2 changed files with 13 additions and 4 deletions

View File

@@ -955,7 +955,8 @@ class AutoPipelineForInpainting(ConfigMixin):
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline")
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace)
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)

View File

@@ -1471,6 +1471,14 @@ class StableDiffusionXLPAGInpaintPipeline(
generator,
self.do_classifier_free_guidance,
)
if self.do_perturbed_attention_guidance:
if self.do_classifier_free_guidance:
mask, _ = mask.chunk(2)
masked_image_latents, _ = masked_image_latents.chunk(2)
mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance)
masked_image_latents = self._prepare_perturbed_attention_guidance(
masked_image_latents, masked_image_latents, self.do_classifier_free_guidance
)
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
@@ -1659,10 +1667,10 @@ class StableDiffusionXLPAGInpaintPipeline(
if num_channels_unet == 4:
init_latents_proper = image_latents
if self.do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
if self.do_perturbed_attention_guidance:
init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2)
else:
init_mask = mask
init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]