mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[bug fix] Inpainting for MultiAdapter (#5922)
* bug in MultiAdapter for Inpainting * adapter_input is a list for MultiAdapter --------- Co-authored-by: andres <andres@hax.ai> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -1470,7 +1470,15 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(DiffusionPipeline, FromS
|
||||
height, width = self._default_height_width(height, width, adapter_image)
|
||||
device = self._execution_device
|
||||
|
||||
adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device)
|
||||
if isinstance(adapter, MultiAdapter):
|
||||
adapter_input = []
|
||||
for one_image in adapter_image:
|
||||
one_image = _preprocess_adapter_image(one_image, height, width)
|
||||
one_image = one_image.to(device=device, dtype=adapter.dtype)
|
||||
adapter_input.append(one_image)
|
||||
else:
|
||||
adapter_input = _preprocess_adapter_image(adapter_image, height, width)
|
||||
adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
@@ -1643,10 +1651,14 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(DiffusionPipeline, FromS
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 10. Prepare added time ids & embeddings & adapter features
|
||||
adapter_input = adapter_input.type(latents.dtype)
|
||||
adapter_state = adapter(adapter_input)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v * adapter_conditioning_scale
|
||||
if isinstance(adapter, MultiAdapter):
|
||||
adapter_state = adapter(adapter_input, adapter_conditioning_scale)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v
|
||||
else:
|
||||
adapter_state = adapter(adapter_input)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v * adapter_conditioning_scale
|
||||
if num_images_per_prompt > 1:
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
|
||||
|
||||
Reference in New Issue
Block a user