mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix Qwen Edit Plus modular for multi-image input (#12601)
* try to fix qwen edit plus multi images (modular) * up * up * test * up * up
This commit is contained in:
@@ -610,7 +610,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# for edit, image size can be different from the target size (height/width)
|
||||
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(
|
||||
@@ -640,6 +639,37 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
    """RoPE-inputs step for QwenImage Edit Plus.

    Overrides `__call__` so that `image_height` / `image_width` are consumed as sequences with one entry per
    reference image; each reference image contributes its own shape tuple to `img_shapes`.
    """

    model_name = "qwenimage-edit-plus"

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Write `img_shapes`, `txt_seq_lens` and `negative_txt_seq_lens` onto the block state.

        Returns the `(components, state)` pair after storing the updated block state.
        """
        block_state = self.get_block_state(state)
        scale = components.vae_scale_factor

        def _shape_entry(pixel_height, pixel_width):
            # Same arithmetic as the inline form: floor-divide by the VAE scale factor, then by 2.
            return (1, pixel_height // scale // 2, pixel_width // scale // 2)

        # The target size (height/width) comes first; reference images follow in their input order.
        shapes_per_sample = [_shape_entry(block_state.height, block_state.width)]
        shapes_per_sample.extend(
            _shape_entry(ref_height, ref_width)
            for ref_height, ref_width in zip(block_state.image_height, block_state.image_width)
        )
        # List repetition: every batch entry refers to the same inner list object.
        block_state.img_shapes = [shapes_per_sample] * block_state.batch_size

        # Sequence lengths are the per-row sums of the embedding masks, or None when a mask is absent.
        positive_mask = block_state.prompt_embeds_mask
        if positive_mask is None:
            block_state.txt_seq_lens = None
        else:
            block_state.txt_seq_lens = positive_mask.sum(dim=1).tolist()

        negative_mask = block_state.negative_prompt_embeds_mask
        if negative_mask is None:
            block_state.negative_txt_seq_lens = None
        else:
            block_state.negative_txt_seq_lens = negative_mask.sum(dim=1).tolist()

        self.set_block_state(state, block_state)

        return components, state
|
||||
|
||||
|
||||
## ControlNet inputs for denoiser
|
||||
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -330,7 +330,7 @@ class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
|
||||
output_name: str = "resized_image",
|
||||
vae_image_output_name: str = "vae_image",
|
||||
):
|
||||
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
|
||||
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
|
||||
|
||||
This block resizes an input image or a list input images and exposes the resized result under configurable
|
||||
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
|
||||
@@ -809,9 +809,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="processed_image"),
|
||||
]
|
||||
return [OutputParam(name="processed_image")]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
@@ -851,7 +849,10 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
||||
|
||||
class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
vae_image_size = 1024 * 1024
|
||||
|
||||
def __init__(self):
|
||||
self.vae_image_size = 1024 * 1024
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -868,6 +869,7 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
||||
if block_state.vae_image is None and block_state.image is None:
|
||||
raise ValueError("`vae_image` and `image` cannot be None at the same time")
|
||||
|
||||
vae_image_sizes = None
|
||||
if block_state.vae_image is None:
|
||||
image = block_state.image
|
||||
self.check_inputs(
|
||||
@@ -879,12 +881,19 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
||||
image=image, height=height, width=width
|
||||
)
|
||||
else:
|
||||
width, height = block_state.vae_image[0].size
|
||||
image = block_state.vae_image
|
||||
# QwenImage Edit Plus can allow multiple input images with varied resolutions
|
||||
processed_images = []
|
||||
vae_image_sizes = []
|
||||
for img in block_state.vae_image:
|
||||
width, height = img.size
|
||||
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
|
||||
vae_image_sizes.append((vae_width, vae_height))
|
||||
processed_images.append(
|
||||
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
|
||||
)
|
||||
block_state.processed_image = processed_images
|
||||
|
||||
block_state.processed_image = components.image_processor.preprocess(
|
||||
image=image, height=height, width=width
|
||||
)
|
||||
block_state.vae_image_sizes = vae_image_sizes
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
@@ -926,17 +935,12 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
]
|
||||
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(self._image_input_name, required=True),
|
||||
InputParam("generator"),
|
||||
]
|
||||
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
|
||||
return inputs
|
||||
|
||||
@property
|
||||
@@ -974,6 +978,50 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
    """VAE encoder step for QwenImage Edit Plus that encodes each input image separately.

    The configured image input is iterated item by item and the resulting latents are stored as a list.
    """

    model_name = "qwenimage-edit-plus"

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the single output: a list of latents, one per reference image."""
        # Reference image latents may differ in resolution, so they are exposed as a list.
        latents_output = OutputParam(
            self._image_latents_output_name,
            type_hint=List[torch.Tensor],
            description="The latents representing the reference image(s).",
        )
        return [latents_output]

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Encode every item of the configured image input and store the list of latents.

        Returns the `(components, state)` pair after storing the updated block state.
        """
        block_state = self.get_block_state(state)

        device = components._execution_device
        dtype = components.vae.dtype

        images = getattr(block_state, self._image_input_name)

        # One `encode_vae_image` call per item, in input order; the same generator is passed to each call.
        encoded_latents = [
            encode_vae_image(
                image=single_image,
                vae=components.vae,
                generator=block_state.generator,
                device=device,
                dtype=dtype,
                latent_channels=components.num_channels_latents,
            )
            for single_image in images
        ]

        setattr(block_state, self._image_latents_output_name, encoded_latents)

        self.set_block_state(state, block_state)

        return components, state
|
||||
|
||||
|
||||
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
|
||||
@@ -224,11 +224,7 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
|
||||
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
@@ -372,6 +368,76 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
    """Input-standardization step for QwenImage Edit Plus.

    Each configured image-latent input is iterated item by item: every item is measured, packed and
    batch-expanded on its own, and the packed results are concatenated into a single tensor.
    """

    model_name = "qwenimage-edit-plus"

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare `image_height` and `image_width`, each a list with one entry per image latent."""
        height_output = OutputParam(
            name="image_height", type_hint=List[int], description="The height of the image latents"
        )
        width_output = OutputParam(
            name="image_width", type_hint=List[int], description="The width of the image latents"
        )
        return [height_output, width_output]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Pack and batch-expand the image latents, then batch-expand the additional inputs.

        Returns the `(components, state)` pair after storing the updated block state.
        """
        block_state = self.get_block_state(state)

        # Image latent inputs: height/width calculation, patchify, and batch expansion.
        for latent_field in self._image_latent_inputs:
            latents_per_image = getattr(block_state, latent_field)
            if latents_per_image is None:
                continue

            # In QwenImage Edit Plus every image latent may have its own size.
            heights = []
            widths = []
            packed_latents = []

            for single_latent in latents_per_image:
                # 1. Derive height/width from the latent.
                latent_height, latent_width = calculate_dimension_from_latents(
                    single_latent, components.vae_scale_factor
                )
                heights.append(latent_height)
                widths.append(latent_width)

                # 2. Patchify (the component attribute is spelled `pachifier`; keep as-is).
                packed = components.pachifier.pack_latents(single_latent)

                # 3. Expand to the batch size.
                packed = repeat_tensor_to_batch_size(
                    input_name=latent_field,
                    input_tensor=packed,
                    num_images_per_prompt=block_state.num_images_per_prompt,
                    batch_size=block_state.batch_size,
                )
                packed_latents.append(packed)

            # Concatenate along dim 1 (presumably the packed sequence axis - TODO confirm) before any
            # state is written, so a failing concatenation leaves the block state untouched.
            merged_latents = torch.cat(packed_latents, dim=1)
            block_state.image_height = heights
            block_state.image_width = widths
            setattr(block_state, latent_field, merged_latents)

            # When height/width are unset (falsy), fall back to the size of the last image latent.
            block_state.height = block_state.height or heights[-1]
            block_state.width = block_state.width or widths[-1]

        # Additional batch inputs: batch expansion only.
        for extra_field in self._additional_batch_inputs:
            extra_tensor = getattr(block_state, extra_field)
            if extra_tensor is None:
                continue

            expanded = repeat_tensor_to_batch_size(
                input_name=extra_field,
                input_tensor=extra_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )
            setattr(block_state, extra_field, expanded)

        self.set_block_state(state, block_state)
        return components, state
|
||||
|
||||
|
||||
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageControlNetBeforeDenoiserStep,
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImageEditPlusRoPEInputsStep,
|
||||
QwenImageEditRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
@@ -40,6 +41,7 @@ from .encoders import (
|
||||
QwenImageEditPlusProcessImagesInputStep,
|
||||
QwenImageEditPlusResizeDynamicStep,
|
||||
QwenImageEditPlusTextEncoderStep,
|
||||
QwenImageEditPlusVaeEncoderDynamicStep,
|
||||
QwenImageEditResizeDynamicStep,
|
||||
QwenImageEditTextEncoderStep,
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
@@ -47,7 +49,12 @@ from .encoders import (
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderDynamicStep,
|
||||
)
|
||||
from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
|
||||
from .inputs import (
|
||||
QwenImageControlNetInputsStep,
|
||||
QwenImageEditPlusInputsDynamicStep,
|
||||
QwenImageInputsDynamicStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -904,13 +911,13 @@ QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
|
||||
("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
|
||||
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||
("encode", QwenImageEditPlusVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
|
||||
block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
|
||||
|
||||
@@ -919,25 +926,62 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
#### QwenImage Edit Plus input blocks
|
||||
# Ordered (name, block) pairs that make up the Edit Plus input stage.
QwenImageEditPlusInputBlocks = InsertableDict(
    [
        # default step to process text embeddings
        ("text_inputs", QwenImageTextInputsStep()),
        # Edit Plus variant that handles `image_latents` item by item
        ("additional_inputs", QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"])),
    ]
)
|
||||
|
||||
|
||||
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
    """Sequential input step for QwenImage Edit Plus, assembled from `QwenImageEditPlusInputBlocks`."""

    model_name = "qwenimage-edit-plus"

    # Block instances and their names come from the same InsertableDict, so they stay aligned.
    block_classes = QwenImageEditPlusInputBlocks.values()
    block_names = QwenImageEditPlusInputBlocks.keys()
|
||||
|
||||
|
||||
#### QwenImage Edit Plus presets
|
||||
EDIT_PLUS_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditPlusVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
|
||||
("input", QwenImageEditInputStep()),
|
||||
("input", QwenImageEditPlusInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
||||
("denoise", QwenImageEditDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Ordered (name, block) pairs run before denoising for Edit Plus: latents, timesteps, then RoPE inputs.
QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict(
    [
        ("prepare_latents", QwenImagePrepareLatentsStep()),
        ("set_timesteps", QwenImageSetTimestepsStep()),
        # Edit Plus specific RoPE inputs step
        ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
    ]
)
|
||||
|
||||
|
||||
class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential before-denoise step for QwenImage Edit Plus, assembled from `QwenImageEditPlusBeforeDenoiseBlocks`."""

    model_name = "qwenimage-edit-plus"

    # Block instances and their names come from the same InsertableDict, so they stay aligned.
    block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values()
    block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys()

    @property
    def description(self):
        """Return the human-readable summary of this step."""
        summary = "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
        return summary
|
||||
|
||||
|
||||
# auto before_denoise step for edit tasks
|
||||
class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [QwenImageEditBeforeDenoiseStep]
|
||||
block_classes = [QwenImageEditPlusBeforeDenoiseStep]
|
||||
block_names = ["edit"]
|
||||
block_trigger_inputs = ["image_latents"]
|
||||
|
||||
@@ -946,7 +990,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for edit (img2img) task.\n"
|
||||
+ " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
|
||||
+ " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
|
||||
+ " - if `image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
@@ -955,9 +999,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageEditPlusVaeEncoderStep,
|
||||
]
|
||||
block_classes = [QwenImageEditPlusVaeEncoderStep]
|
||||
block_names = ["edit"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@@ -974,10 +1016,25 @@ class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
## 3.3 QwenImage-Edit/auto blocks & presets
|
||||
|
||||
|
||||
class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks):
    """Auto input step for QwenImage Edit Plus with a single `edit` block triggered by `image_latents`."""

    # The three lists are parallel: entry i of each describes the same candidate block.
    block_classes = [QwenImageEditPlusInputStep]
    block_names = ["edit"]
    block_trigger_inputs = ["image_latents"]

    @property
    def description(self):
        """Return the human-readable summary of this step and its trigger rule."""
        fragments = (
            "Input step that prepares the inputs for the edit denoising step.\n",
            " It is an auto pipeline block that works for edit task.\n",
            " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n",
            " - if `image_latents` is not provided, step will be skipped.",
        )
        return "".join(fragments)
|
||||
|
||||
|
||||
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditAutoInputStep,
|
||||
QwenImageEditPlusAutoInputStep,
|
||||
QwenImageEditPlusAutoBeforeDenoiseStep,
|
||||
QwenImageEditAutoDenoiseStep,
|
||||
]
|
||||
|
||||
@@ -26,6 +26,7 @@ from diffusers.modular_pipelines import (
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin
|
||||
|
||||
|
||||
@@ -104,6 +105,16 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
|
||||
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
|
||||
return inputs
|
||||
|
||||
def test_multi_images_as_input(self):
    """Smoke test: the pipeline runs when `image` is a list holding two images."""
    call_kwargs = self.get_dummy_inputs()
    # Replace the single dummy image with a two-element list containing that same image twice.
    single_image = call_kwargs.pop("image")
    call_kwargs["image"] = [single_image] * 2

    pipeline = self.get_pipeline()
    pipeline = pipeline.to(torch_device)
    # Only successful execution is checked; the output is discarded.
    pipeline(**call_kwargs)
|
||||
|
||||
# strict=True: an unexpected pass of this test is reported as a failure.
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_num_images_per_prompt(self):
    """Run the inherited `test_num_images_per_prompt`, which is currently marked as an expected failure."""
    super().test_num_images_per_prompt()
|
||||
@@ -117,4 +128,4 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
|
||||
super().test_inference_batch_single_identical()
|
||||
|
||||
def test_guider_cfg(self):
|
||||
super().test_guider_cfg(1e-3)
|
||||
super().test_guider_cfg(1e-6)
|
||||
|
||||
Reference in New Issue
Block a user