1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[bug fix] Inpainting for MultiAdapter (#5922)

* bug in MultiAdapter for Inpainting

* adapter_input is a list for MultiAdapter

---------

Co-authored-by: andres <andres@hax.ai>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Andrés Romero
2023-11-29 15:46:26 +01:00
committed by GitHub
parent 6031ecbd23
commit 79dc7df03e

View File

@@ -1470,7 +1470,15 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(DiffusionPipeline, FromS
height, width = self._default_height_width(height, width, adapter_image)
device = self._execution_device
adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device)
if isinstance(adapter, MultiAdapter):
adapter_input = []
for one_image in adapter_image:
one_image = _preprocess_adapter_image(one_image, height, width)
one_image = one_image.to(device=device, dtype=adapter.dtype)
adapter_input.append(one_image)
else:
adapter_input = _preprocess_adapter_image(adapter_image, height, width)
adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)
original_size = original_size or (height, width)
target_size = target_size or (height, width)
@@ -1643,10 +1651,14 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(DiffusionPipeline, FromS
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 10. Prepare added time ids & embeddings & adapter features
adapter_input = adapter_input.type(latents.dtype)
adapter_state = adapter(adapter_input)
for k, v in enumerate(adapter_state):
adapter_state[k] = v * adapter_conditioning_scale
if isinstance(adapter, MultiAdapter):
adapter_state = adapter(adapter_input, adapter_conditioning_scale)
for k, v in enumerate(adapter_state):
adapter_state[k] = v
else:
adapter_state = adapter(adapter_input)
for k, v in enumerate(adapter_state):
adapter_state[k] = v * adapter_conditioning_scale
if num_images_per_prompt > 1:
for k, v in enumerate(adapter_state):
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)