1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix Qwen Edit Plus modular for multi-image input (#12601)

* try to fix qwen edit plus multi images (modular)

* up

* up

* test

* up

* up
This commit is contained in:
Sayak Paul
2025-12-10 04:08:30 +08:00
committed by GitHub
parent 07ea0786e8
commit 8b4722de57
5 changed files with 247 additions and 35 deletions

View File

@@ -610,7 +610,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
block_state = self.get_block_state(state)
# for edit, image size can be different from the target size (height/width)
block_state.img_shapes = [
[
(
@@ -640,6 +639,37 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
return components, state
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
model_name = "qwenimage-edit-plus"
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae_scale_factor = components.vae_scale_factor
block_state.img_shapes = [
[
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
*[
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
],
]
] * block_state.batch_size
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)
self.set_block_state(state, block_state)
return components, state
## ControlNet inputs for denoiser
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
model_name = "qwenimage"

View File

@@ -330,7 +330,7 @@ class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
output_name: str = "resized_image",
vae_image_output_name: str = "vae_image",
):
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
This block resizes an input image or a list input images and exposes the resized result under configurable
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
@@ -809,9 +809,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="processed_image"),
]
return [OutputParam(name="processed_image")]
@staticmethod
def check_inputs(height, width, vae_scale_factor):
@@ -851,7 +849,10 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
model_name = "qwenimage-edit-plus"
vae_image_size = 1024 * 1024
def __init__(self):
self.vae_image_size = 1024 * 1024
super().__init__()
@property
def description(self) -> str:
@@ -868,6 +869,7 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
if block_state.vae_image is None and block_state.image is None:
raise ValueError("`vae_image` and `image` cannot be None at the same time")
vae_image_sizes = None
if block_state.vae_image is None:
image = block_state.image
self.check_inputs(
@@ -879,12 +881,19 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
image=image, height=height, width=width
)
else:
width, height = block_state.vae_image[0].size
image = block_state.vae_image
# QwenImage Edit Plus can allow multiple input images with varied resolutions
processed_images = []
vae_image_sizes = []
for img in block_state.vae_image:
width, height = img.size
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
vae_image_sizes.append((vae_width, vae_height))
processed_images.append(
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
)
block_state.processed_image = processed_images
block_state.processed_image = components.image_processor.preprocess(
image=image, height=height, width=width
)
block_state.vae_image_sizes = vae_image_sizes
self.set_block_state(state, block_state)
return components, state
@@ -926,17 +935,12 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
@property
def expected_components(self) -> List[ComponentSpec]:
components = [
ComponentSpec("vae", AutoencoderKLQwenImage),
]
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
return components
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(self._image_input_name, required=True),
InputParam("generator"),
]
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
return inputs
@property
@@ -974,6 +978,50 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
return components, state
class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
model_name = "qwenimage-edit-plus"
@property
def intermediate_outputs(self) -> List[OutputParam]:
# Each reference image latent can have varied resolutions hence we return this as a list.
return [
OutputParam(
self._image_latents_output_name,
type_hint=List[torch.Tensor],
description="The latents representing the reference image(s).",
)
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
dtype = components.vae.dtype
image = getattr(block_state, self._image_input_name)
# Encode image into latents
image_latents = []
for img in image:
image_latents.append(
encode_vae_image(
image=img,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
latent_channels=components.num_channels_latents,
)
)
setattr(block_state, self._image_latents_output_name, image_latents)
self.set_block_state(state, block_state)
return components, state
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
model_name = "qwenimage"

View File

@@ -224,11 +224,7 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
model_name = "qwenimage"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
This step handles multiple common tasks to prepare inputs for the denoising step:
@@ -372,6 +368,76 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
return components, state
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
model_name = "qwenimage-edit-plus"
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# Each image latent can have different size in QwenImage Edit Plus.
image_heights = []
image_widths = []
packed_image_latent_tensors = []
for img_latent_tensor in image_latent_tensor:
# 1. Calculate height/width from latents
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
image_heights.append(height)
image_widths.append(width)
# 2. Patchify the image latent tensor
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
# 3. Expand batch size
img_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_tensor=img_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
packed_image_latent_tensors.append(img_latent_tensor)
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
block_state.image_height = image_heights
block_state.image_width = image_widths
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
block_state.height = block_state.height or image_heights[-1]
block_state.width = block_state.width or image_widths[-1]
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, input_name, input_tensor)
self.set_block_state(state, block_state)
return components, state
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
model_name = "qwenimage"

View File

@@ -18,6 +18,7 @@ from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageControlNetBeforeDenoiserStep,
QwenImageCreateMaskLatentsStep,
QwenImageEditPlusRoPEInputsStep,
QwenImageEditRoPEInputsStep,
QwenImagePrepareLatentsStep,
QwenImagePrepareLatentsWithStrengthStep,
@@ -40,6 +41,7 @@ from .encoders import (
QwenImageEditPlusProcessImagesInputStep,
QwenImageEditPlusResizeDynamicStep,
QwenImageEditPlusTextEncoderStep,
QwenImageEditPlusVaeEncoderDynamicStep,
QwenImageEditResizeDynamicStep,
QwenImageEditTextEncoderStep,
QwenImageInpaintProcessImagesInputStep,
@@ -47,7 +49,12 @@ from .encoders import (
QwenImageTextEncoderStep,
QwenImageVaeEncoderDynamicStep,
)
from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
from .inputs import (
QwenImageControlNetInputsStep,
QwenImageEditPlusInputsDynamicStep,
QwenImageInputsDynamicStep,
QwenImageTextInputsStep,
)
logger = logging.get_logger(__name__)
@@ -904,13 +911,13 @@ QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
[
("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
("encode", QwenImageEditPlusVaeEncoderDynamicStep()), # processed_image -> image_latents
]
)
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
model_name = "qwenimage-edit-plus"
block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
@@ -919,25 +926,62 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
return "Vae encoder step that encode the image inputs into their latent representations."
#### QwenImage Edit Plus input blocks
QwenImageEditPlusInputBlocks = InsertableDict(
[
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
(
"additional_inputs",
QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]),
),
]
)
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = QwenImageEditPlusInputBlocks.values()
block_names = QwenImageEditPlusInputBlocks.keys()
#### QwenImage Edit Plus presets
EDIT_PLUS_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditPlusVLEncoderStep()),
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
("input", QwenImageEditInputStep()),
("input", QwenImageEditPlusInputStep()),
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
("denoise", QwenImageEditDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)
QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
]
)
class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values()
block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
# auto before_denoise step for edit tasks
class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [QwenImageEditBeforeDenoiseStep]
block_classes = [QwenImageEditPlusBeforeDenoiseStep]
block_names = ["edit"]
block_trigger_inputs = ["image_latents"]
@@ -946,7 +990,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
return (
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ "This is an auto pipeline block that works for edit (img2img) task.\n"
+ " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ " - if `image_latents` is not provided, step will be skipped."
)
@@ -955,9 +999,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [
QwenImageEditPlusVaeEncoderStep,
]
block_classes = [QwenImageEditPlusVaeEncoderStep]
block_names = ["edit"]
block_trigger_inputs = ["image"]
@@ -974,10 +1016,25 @@ class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
## 3.3 QwenImage-Edit/auto blocks & presets
class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks):
block_classes = [QwenImageEditPlusInputStep]
block_names = ["edit"]
block_trigger_inputs = ["image_latents"]
@property
def description(self):
return (
"Input step that prepares the inputs for the edit denoising step.\n"
+ " It is an auto pipeline block that works for edit task.\n"
+ " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n"
+ " - if `image_latents` is not provided, step will be skipped."
)
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditAutoInputStep,
QwenImageEditPlusAutoInputStep,
QwenImageEditPlusAutoBeforeDenoiseStep,
QwenImageEditAutoDenoiseStep,
]

View File

@@ -26,6 +26,7 @@ from diffusers.modular_pipelines import (
QwenImageModularPipeline,
)
from ...testing_utils import torch_device
from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin
@@ -104,6 +105,16 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
return inputs
def test_multi_images_as_input(self):
inputs = self.get_dummy_inputs()
image = inputs.pop("image")
inputs["image"] = [image, image]
pipe = self.get_pipeline().to(torch_device)
_ = pipe(
**inputs,
)
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_num_images_per_prompt(self):
super().test_num_images_per_prompt()
@@ -117,4 +128,4 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
super().test_inference_batch_single_identical()
def test_guider_cfg(self):
super().test_guider_cfg(1e-3)
super().test_guider_cfg(1e-6)