
[Modular] Qwen (#12220)

* add qwen modular
This commit is contained in:
YiYi Xu
2025-09-08 00:27:02 -10:00
committed by GitHub
parent fc337d5853
commit f50b18eec7
17 changed files with 4275 additions and 9 deletions

View File

@@ -20,6 +20,12 @@ All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or Nu
[[autodoc]] image_processor.VaeImageProcessor
## InpaintProcessor
The [`InpaintProcessor`] accepts `mask` and `image` inputs and processes them together. Optionally, it can accept `padding_mask_crop` and apply a mask overlay.
[[autodoc]] image_processor.InpaintProcessor
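A minimal usage sketch (the file names, sizes, and the `decoded` tensor below are illustrative placeholders):

from diffusers.image_processor import InpaintProcessor
from PIL import Image

processor = InpaintProcessor(vae_scale_factor=8)
init_image = Image.open("input.png").convert("RGB")
mask_image = Image.open("mask.png").convert("L")
# returns the image tensor, the mask tensor, and kwargs to pass back to postprocess
image, mask, overlay_kwargs = processor.preprocess(
    init_image, mask=mask_image, height=512, width=512, padding_mask_crop=32
)
# ... run the inpainting pipeline on image/mask to produce `decoded` (a torch.Tensor) ...
# images = processor.postprocess(decoded, output_type="pil", **overlay_kwargs)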
## VaeImageProcessorLDM3D
The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.

View File

@@ -385,6 +385,10 @@ else:
[
"FluxAutoBlocks",
"FluxModularPipeline",
"QwenImageAutoBlocks",
"QwenImageEditAutoBlocks",
"QwenImageEditModularPipeline",
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"WanAutoBlocks",
@@ -1038,6 +1042,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modular_pipelines import (
FluxAutoBlocks,
FluxModularPipeline,
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
QwenImageEditModularPipeline,
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
WanAutoBlocks,

View File

@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
# AttnProcessor2_0
@@ -140,6 +141,14 @@ def _register_attention_processors_metadata():
metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
)
# QwenDoubleStreamAttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=QwenDoubleStreamAttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -298,4 +307,5 @@ _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___h
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
# fmt: on

View File

@@ -523,6 +523,7 @@ class VaeImageProcessor(ConfigMixin):
size=(height, width),
)
image = self.pt_to_numpy(image)
return image
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
@@ -838,6 +839,137 @@ class VaeImageProcessor(ConfigMixin):
return image
class InpaintProcessor(ConfigMixin):
"""
Image processor for inpainting; prepares the image and mask inputs together.
"""
config_name = CONFIG_NAME
@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
vae_latent_channels: int = 4,
resample: str = "lanczos",
reducing_gap: Optional[int] = None,
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_grayscale: bool = False,
mask_do_normalize: bool = False,
mask_do_binarize: bool = True,
mask_do_convert_grayscale: bool = True,
):
super().__init__()
self._image_processor = VaeImageProcessor(
do_resize=do_resize,
vae_scale_factor=vae_scale_factor,
vae_latent_channels=vae_latent_channels,
resample=resample,
reducing_gap=reducing_gap,
do_normalize=do_normalize,
do_binarize=do_binarize,
do_convert_grayscale=do_convert_grayscale,
)
self._mask_processor = VaeImageProcessor(
do_resize=do_resize,
vae_scale_factor=vae_scale_factor,
vae_latent_channels=vae_latent_channels,
resample=resample,
reducing_gap=reducing_gap,
do_normalize=mask_do_normalize,
do_binarize=mask_do_binarize,
do_convert_grayscale=mask_do_convert_grayscale,
)
def preprocess(
self,
image: PIL.Image.Image,
mask: Optional[PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, dict]]:
"""
Preprocess the image and mask.
"""
if mask is None and padding_mask_crop is not None:
raise ValueError("mask must be provided if padding_mask_crop is provided")
# if mask is None, same behavior as regular image processor
if mask is None:
return self._image_processor.preprocess(image, height=height, width=width)
if padding_mask_crop is not None:
crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
processed_image = self._image_processor.preprocess(
image,
height=height,
width=width,
crops_coords=crops_coords,
resize_mode=resize_mode,
)
processed_mask = self._mask_processor.preprocess(
mask,
height=height,
width=width,
resize_mode=resize_mode,
crops_coords=crops_coords,
)
if crops_coords is not None:
postprocessing_kwargs = {
"crops_coords": crops_coords,
"original_image": image,
"original_mask": mask,
}
else:
postprocessing_kwargs = {
"crops_coords": None,
"original_image": None,
"original_mask": None,
}
return processed_image, processed_mask, postprocessing_kwargs
def postprocess(
self,
image: torch.Tensor,
output_type: str = "pil",
original_image: Optional[PIL.Image.Image] = None,
original_mask: Optional[PIL.Image.Image] = None,
crops_coords: Optional[Tuple[int, int, int, int]] = None,
) -> Union[List[PIL.Image.Image], torch.Tensor, np.ndarray]:
"""
Postprocess the image and optionally apply the mask overlay.
"""
image = self._image_processor.postprocess(
image,
output_type=output_type,
)
# optionally apply the mask overlay
if crops_coords is not None and (original_image is None or original_mask is None):
raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
elif crops_coords is not None and output_type != "pil":
raise ValueError("output_type must be 'pil' if crops_coords is provided")
elif crops_coords is not None:
image = [
self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
]
return image
class VaeImageProcessorLDM3D(VaeImageProcessor):
"""
Image processor for VAE LDM3D.

View File

@@ -47,6 +47,12 @@ else:
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
"QwenImageModularPipeline",
"QwenImageEditModularPipeline",
"QwenImageEditAutoBlocks",
]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -68,6 +74,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
from .qwenimage import (
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
QwenImageEditModularPipeline,
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import WanAutoBlocks, WanModularPipeline
else:

View File

@@ -56,6 +56,8 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
]
)
@@ -64,6 +66,8 @@ MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
("WanModularPipeline", "WanAutoBlocks"),
("FluxModularPipeline", "FluxAutoBlocks"),
("QwenImageModularPipeline", "QwenImageAutoBlocks"),
("QwenImageEditModularPipeline", "QwenImageEditAutoBlocks"),
]
)
@@ -133,8 +137,8 @@ class PipelineState:
Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
intermediates dict.
"""
if name in self.intermediates:
return self.intermediates[name]
if name in self.values:
return self.values[name]
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __repr__(self):
@@ -548,8 +552,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
def __init__(self):
sub_blocks = InsertableDict()
for block_name, block_cls in zip(self.block_names, self.block_classes):
sub_blocks[block_name] = block_cls()
for block_name, block in zip(self.block_names, self.block_classes):
if inspect.isclass(block):
sub_blocks[block_name] = block()
else:
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(
@@ -830,7 +837,9 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
return expected_configs
@classmethod
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks":
def from_blocks_dict(
cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
) -> "SequentialPipelineBlocks":
"""Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
Args:
@@ -852,12 +861,19 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
instance.block_classes = [block.__class__ for block in sub_blocks.values()]
instance.block_names = list(sub_blocks.keys())
instance.sub_blocks = sub_blocks
if description is not None:
instance.description = description
return instance
def __init__(self):
sub_blocks = InsertableDict()
for block_name, block_cls in zip(self.block_names, self.block_classes):
sub_blocks[block_name] = block_cls()
for block_name, block in zip(self.block_names, self.block_classes):
if inspect.isclass(block):
sub_blocks[block_name] = block()
else:
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
def _get_inputs(self):
@@ -1280,8 +1296,11 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def __init__(self):
sub_blocks = InsertableDict()
for block_name, block_cls in zip(self.block_names, self.block_classes):
sub_blocks[block_name] = block_cls()
for block_name, block in zip(self.block_names, self.block_classes):
if inspect.isclass(block):
sub_blocks[block_name] = block()
else:
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
@classmethod

View File

@@ -0,0 +1,75 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["QwenImageTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"CONTROLNET_BLOCKS",
"EDIT_AUTO_BLOCKS",
"EDIT_BLOCKS",
"EDIT_INPAINT_BLOCKS",
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"QwenImageAutoBlocks",
"QwenImageEditAutoBlocks",
]
_import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import (
QwenImageTextEncoderStep,
)
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
EDIT_AUTO_BLOCKS,
EDIT_BLOCKS,
EDIT_INPAINT_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
TEXT2IMAGE_BLOCKS,
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
)
from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,727 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
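# Worked example with the defaults above (illustrative): m = (1.15 - 0.5) / (4096 - 256)
# ≈ 1.69e-4 and b = 0.5 - m * 256 ≈ 0.457, so image_seq_len = 256 gives mu = 0.5
# (base_shift), image_seq_len = 4096 gives mu = 1.15 (max_shift), and values in
# between interpolate linearly, e.g. image_seq_len = 2176 gives mu ≈ 0.825.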
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(scheduler, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
if hasattr(scheduler, "set_begin_index"):
scheduler.set_begin_index(t_start * scheduler.order)
return timesteps, num_inference_steps - t_start
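# Worked example (illustrative): with num_inference_steps=50 and strength=0.9,
# init_timestep = min(50 * 0.9, 50) = 45 and t_start = 5, so the last 45 timesteps
# are kept and 45 is returned as the effective step count; strength=1.0 keeps the
# full schedule.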
# Prepare Latents steps
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Prepare initial random noise for the generation process"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="generator"),
InputParam(
name="batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam(
name="dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process",
),
]
@staticmethod
def check_inputs(height, width, vae_scale_factor):
if height is not None and height % (vae_scale_factor * 2) != 0:
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
if width is not None and width % (vae_scale_factor * 2) != 0:
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(
height=block_state.height,
width=block_state.width,
vae_scale_factor=components.vae_scale_factor,
)
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
# we can update the height and width here since they are used to generate the initial latents
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width)
if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.latents = randn_tensor(
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
)
block_state.latents = components.pachifier.pack_latents(block_state.latents)
self.set_block_state(state, block_state)
return components, state
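# Worked example (illustrative): height = width = 1024 with vae_scale_factor = 8
# gives latent_height = latent_width = 2 * (1024 // 16) = 128, i.e. an unpacked
# shape of (batch_size, num_channels_latents, 1, 128, 128). Packing then groups
# 2x2 latent patches into tokens, producing (128 // 2) * (128 // 2) = 4096 tokens
# per image, which is the image_seq_len later consumed by calculate_shift.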
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The initial random noised, can be generated in prepare latent step.",
),
InputParam(
name="image_latents",
required=True,
type_hint=torch.Tensor,
description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
),
InputParam(
name="timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="initial_noise",
type_hint=torch.Tensor,
description="The initial random noised used for inpainting denoising.",
),
]
@staticmethod
def check_inputs(image_latents, latents):
if image_latents.shape[0] != latents.shape[0]:
raise ValueError(
f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
)
if image_latents.ndim != 3:
raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(
image_latents=block_state.image_latents,
latents=block_state.latents,
)
# prepare latent timestep
latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
# make copy of initial_noise
block_state.initial_noise = block_state.latents
# scale noise
block_state.latents = components.scheduler.scale_noise(
block_state.image_latents, latent_timestep, block_state.latents
)
self.set_block_state(state, block_state)
return components, state
class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
name="processed_mask_image",
required=True,
type_hint=torch.Tensor,
description="The processed mask to use for the inpainting process.",
),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="dtype", required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
block_state.mask = torch.nn.functional.interpolate(
block_state.processed_mask_image,
size=(height_latents, width_latents),
)
block_state.mask = block_state.mask.unsqueeze(2)
block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1)
block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype)
block_state.mask = components.pachifier.pack_latents(block_state.mask)
self.set_block_state(state, block_state)
return components, state
# Set Timesteps steps
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_inference_steps", default=50),
InputParam(name="sigmas"),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process, used to calculate the image sequence length.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
sigmas = (
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
if block_state.sigmas is None
else block_state.sigmas
)
mu = calculate_shift(
image_seq_len=block_state.latents.shape[1],
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
base_shift=components.scheduler.config.get("base_shift", 0.5),
max_shift=components.scheduler.config.get("max_shift", 1.15),
)
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
scheduler=components.scheduler,
num_inference_steps=block_state.num_inference_steps,
device=device,
sigmas=sigmas,
mu=mu,
)
components.scheduler.set_begin_index(0)
self.set_block_state(state, block_state)
return components, state
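# Worked example (illustrative): with num_inference_steps=4 and no custom sigmas,
# np.linspace(1.0, 1 / 4, 4) yields [1.0, 0.75, 0.5, 0.25]; the scheduler then
# warps this schedule with the resolution-dependent `mu` computed above.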
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_inference_steps", default=50),
InputParam(name="sigmas"),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process, used to calculate the image sequence length.",
),
InputParam(name="strength", default=0.9),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="timesteps",
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
sigmas = (
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
if block_state.sigmas is None
else block_state.sigmas
)
mu = calculate_shift(
image_seq_len=block_state.latents.shape[1],
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
base_shift=components.scheduler.config.get("base_shift", 0.5),
max_shift=components.scheduler.config.get("max_shift", 1.15),
)
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
scheduler=components.scheduler,
num_inference_steps=block_state.num_inference_steps,
device=device,
sigmas=sigmas,
mu=mu,
)
block_state.timesteps, block_state.num_inference_steps = get_timesteps(
scheduler=components.scheduler,
num_inference_steps=block_state.num_inference_steps,
strength=block_state.strength,
)
self.set_block_state(state, block_state)
return components, state
# other inputs for denoiser
## RoPE inputs for denoiser
class QwenImageRoPEInputsStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return (
"Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="img_shapes",
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.img_shapes = [
[
(
1,
block_state.height // components.vae_scale_factor // 2,
block_state.width // components.vae_scale_factor // 2,
)
]
* block_state.batch_size
]
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)
self.set_block_state(state, block_state)
return components, state
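# Worked example (illustrative): for a 1024x1024 generation with vae_scale_factor=8,
# each shape tuple is (1, 1024 // 8 // 2, 1024 // 8 // 2) = (1, 64, 64), so with
# batch_size=2 the code above yields img_shapes = [[(1, 64, 64), (1, 64, 64)]].
# txt_seq_lens is the per-sample count of unmasked prompt tokens, e.g. [27, 31].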
class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(
name="resized_image", required=True, type_hint=torch.Tensor, description="The resized image input"
),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="img_shapes",
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# for edit, image size can be different from the target size (height/width)
image = (
block_state.resized_image[0] if isinstance(block_state.resized_image, list) else block_state.resized_image
)
image_width, image_height = image.size
block_state.img_shapes = [
[
(
1,
block_state.height // components.vae_scale_factor // 2,
block_state.width // components.vae_scale_factor // 2,
),
(1, image_height // components.vae_scale_factor // 2, image_width // components.vae_scale_factor // 2),
]
] * block_state.batch_size
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)
self.set_block_state(state, block_state)
return components, state
## ControlNet inputs for denoiser
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("controlnet", QwenImageControlNetModel),
]
@property
def description(self) -> str:
return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("control_guidance_start", default=0.0),
InputParam("control_guidance_end", default=1.0),
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("control_image_latents", required=True),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
controlnet = unwrap_module(components.controlnet)
# number of controlnets; used to broadcast scalar configs into lists below
mult = len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1
# control_guidance_start/control_guidance_end (align format)
if not isinstance(block_state.control_guidance_start, list) and isinstance(
block_state.control_guidance_end, list
):
block_state.control_guidance_start = len(block_state.control_guidance_end) * [
block_state.control_guidance_start
]
elif not isinstance(block_state.control_guidance_end, list) and isinstance(
block_state.control_guidance_start, list
):
block_state.control_guidance_end = len(block_state.control_guidance_start) * [
block_state.control_guidance_end
]
elif not isinstance(block_state.control_guidance_start, list) and not isinstance(
block_state.control_guidance_end, list
):
block_state.control_guidance_start, block_state.control_guidance_end = (
mult * [block_state.control_guidance_start],
mult * [block_state.control_guidance_end],
)
# controlnet_conditioning_scale (align format)
if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance(
block_state.controlnet_conditioning_scale, float
):
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult
# controlnet_keep
block_state.controlnet_keep = []
for i in range(len(block_state.timesteps)):
keeps = [
1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
]
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps)
self.set_block_state(state, block_state)
return components, state
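# Worked example (illustrative): with a single controlnet, 4 timesteps,
# control_guidance_start=0.0 and control_guidance_end=0.5, controlnet_keep is
# [1.0, 1.0, 0.0, 0.0]: i=2 is dropped because (2 + 1) / 4 = 0.75 > 0.5, so the
# controlnet only contributes during the first half of sampling.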

View File

@@ -0,0 +1,203 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...image_processor import InpaintProcessor, VaeImageProcessor
from ...models import AutoencoderKLQwenImage
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
logger = logging.get_logger(__name__)
class QwenImageDecoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Step that decodes the latents to images"
@property
def expected_components(self) -> List[ComponentSpec]:
components = [
ComponentSpec("vae", AutoencoderKLQwenImage),
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
return components
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to decode, can be generated in the denoise step",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# YiYi Notes: remove support for output_type="latents"; we can just skip the decode/encode step in modular
block_state.latents = components.pachifier.unpack_latents(
block_state.latents, block_state.height, block_state.width
)
block_state.latents = block_state.latents.to(components.vae.dtype)
latents_mean = (
torch.tensor(components.vae.config.latents_mean)
.view(1, components.vae.config.z_dim, 1, 1, 1)
.to(block_state.latents.device, block_state.latents.dtype)
)
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
1, components.vae.config.z_dim, 1, 1, 1
).to(block_state.latents.device, block_state.latents.dtype)
block_state.latents = block_state.latents / latents_std + latents_mean
block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0]
self.set_block_state(state, block_state)
return components, state
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "postprocess the generated image"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
),
]
@staticmethod
def check_inputs(output_type):
if output_type not in ["pil", "np", "pt"]:
raise ValueError(f"Invalid output_type: {output_type}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
self.check_inputs(block_state.output_type)
block_state.images = components.image_processor.postprocess(
image=block_state.images,
output_type=block_state.output_type,
)
self.set_block_state(state, block_state)
return components, state
class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "postprocess the generated image, optional apply the mask overally to the original image.."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_mask_processor",
InpaintProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
),
InputParam("mask_overlay_kwargs"),
]
@staticmethod
def check_inputs(output_type, mask_overlay_kwargs):
if output_type not in ["pil", "np", "pt"]:
raise ValueError(f"Invalid output_type: {output_type}")
if mask_overlay_kwargs and output_type != "pil":
raise ValueError("only support output_type 'pil' for mask overlay")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs)
if block_state.mask_overlay_kwargs is None:
mask_overlay_kwargs = {}
else:
mask_overlay_kwargs = block_state.mask_overlay_kwargs
block_state.images = components.image_mask_processor.postprocess(
image=block_state.images,
**mask_overlay_kwargs,
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -0,0 +1,668 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline
logger = logging.get_logger(__name__)
class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
# one timestep
block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
block_state.latent_model_input = block_state.latents
return components, block_state
class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
# one timestep
block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1)
block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
return components, block_state
class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
ComponentSpec("controlnet", QwenImageControlNetModel),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that runs the controlnet before the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"control_image_latents",
required=True,
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"controlnet_conditioning_scale",
type_hint=float,
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs for the denoiser. "
"It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
),
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int):
# cond_scale for the timestep (controlnet input)
if isinstance(block_state.controlnet_keep[i], list):
block_state.cond_scale = [
c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])
]
else:
controlnet_cond_scale = block_state.controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
# run controlnet for the guidance batch
controlnet_block_samples = components.controlnet(
hidden_states=block_state.latent_model_input,
controlnet_cond=block_state.control_image_latents,
conditioning_scale=block_state.cond_scale,
timestep=block_state.timestep / 1000,
img_shapes=block_state.img_shapes,
encoder_hidden_states=block_state.prompt_embeds,
encoder_hidden_states_mask=block_state.prompt_embeds_mask,
txt_seq_lens=block_state.txt_seq_lens,
return_dict=False,
)
block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples
return components, block_state
class QwenImageLoopDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return (
"step within the denoising loop that denoise the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", QwenImageTransformer2DModel),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam(
"img_shapes",
required=True,
type_hint=List[Tuple[int, int]],
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
guider_input_fields = {
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
# YiYi TODO: add cache context
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input,
timestep=block_state.timestep / 1000,
img_shapes=block_state.img_shapes,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
**block_state.additional_cond_kwargs,
)[0]
components.guider.cleanup_models(components.transformer)
guider_output = components.guider(guider_state)
# apply guidance rescale
pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True)
pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True)
block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm)
return components, block_state
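# The rescale above keeps the direction of the guided prediction `pred` but
# restores the per-token norm of the conditional branch `pred_cond`, i.e.
# pred_hat = pred * ||pred_cond|| / ||pred|| (norms over the last dim), which
# counteracts the norm inflation that CFG-style extrapolation can introduce.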
class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return (
"step within the denoising loop that denoise the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", QwenImageTransformer2DModel),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam(
"img_shapes",
required=True,
type_hint=List[Tuple[int, int]],
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
guider_input_fields = {
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
# YiYi TODO: add cache context
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input,
timestep=block_state.timestep / 1000,
img_shapes=block_state.img_shapes,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
**block_state.additional_cond_kwargs,
)[0]
components.guider.cleanup_models(components.transformer)
guider_output = components.guider(guider_state)
pred = guider_output.pred[:, : block_state.latents.size(1)]
pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)]
# apply guidance rescale
pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True)
pred_norm = torch.norm(pred, dim=-1, keepdim=True)
block_state.noise_pred = pred * (pred_cond_norm / pred_norm)
return components, block_state
class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return (
"step within the denoising loop that updates the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return (
"step within the denoising loop that updates the latents using mask and image_latents for inpainting. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"mask",
required=True,
type_hint=torch.Tensor,
description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam(
"initial_noise",
required=True,
type_hint=torch.Tensor,
description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.init_latents_proper = block_state.image_latents
if i < len(block_state.timesteps) - 1:
block_state.noise_timestep = block_state.timesteps[i + 1]
block_state.init_latents_proper = components.scheduler.scale_noise(
block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise
)
block_state.latents = (
1 - block_state.mask
) * block_state.init_latents_proper + block_state.mask * block_state.latents
return components, block_state
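# Worked example (illustrative): with a binary mask, regions where mask == 1 keep
# the freshly denoised latents, while regions where mask == 0 are reset to the
# original image latents re-noised to the *next* timestep, so the unmasked
# background follows the scheduler's forward process exactly at every step.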
class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoise the latents over `timesteps`. "
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
block_state.additional_cond_kwargs = {}
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
self.set_block_state(state, block_state)
return components, state
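# Note on the loop above (illustrative): `num_warmup_steps` accounts for
# schedulers whose `order` > 1, where several timesteps map to one user-visible
# inference step. With a first-order scheduler such as
# FlowMatchEulerDiscreteScheduler (order == 1), it is typically 0 and the
# progress bar advances once per timestep.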
# composing the denoising loops
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopDenoiser,
QwenImageLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageLoopBeforeDenoiser`\n"
" - `QwenImageLoopDenoiser`\n"
" - `QwenImageLoopAfterDenoiser`\n"
"This block supports text2image and image2image tasks for QwenImage."
)
# composing the inpainting denoising loops
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopDenoiser,
QwenImageLoopAfterDenoiser,
QwenImageLoopAfterDenoiserInpaint,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageLoopBeforeDenoiser`\n"
" - `QwenImageLoopDenoiser`\n"
" - `QwenImageLoopAfterDenoiser`\n"
" - `QwenImageLoopAfterDenoiserInpaint`\n"
"This block supports inpainting tasks for QwenImage."
)
# composing the controlnet denoising loops
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopBeforeDenoiserControlNet,
QwenImageLoopDenoiser,
QwenImageLoopAfterDenoiser,
]
block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageLoopBeforeDenoiser`\n"
" - `QwenImageLoopBeforeDenoiserControlNet`\n"
" - `QwenImageLoopDenoiser`\n"
" - `QwenImageLoopAfterDenoiser`\n"
"This block supports text2img/img2img tasks with controlnet for QwenImage."
)
# composing the controlnet denoising loops
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopBeforeDenoiserControlNet,
QwenImageLoopDenoiser,
QwenImageLoopAfterDenoiser,
QwenImageLoopAfterDenoiserInpaint,
]
block_names = [
"before_denoiser",
"before_denoiser_controlnet",
"denoiser",
"after_denoiser",
"after_denoiser_inpaint",
]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageLoopBeforeDenoiser`\n"
" - `QwenImageLoopBeforeDenoiserControlNet`\n"
" - `QwenImageLoopDenoiser`\n"
" - `QwenImageLoopAfterDenoiser`\n"
" - `QwenImageLoopAfterDenoiserInpaint`\n"
"This block supports inpainting tasks with controlnet for QwenImage."
)
# composing the denoising loops
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
block_classes = [
QwenImageEditLoopBeforeDenoiser,
QwenImageEditLoopDenoiser,
QwenImageLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageEditLoopBeforeDenoiser`\n"
" - `QwenImageEditLoopDenoiser`\n"
" - `QwenImageLoopAfterDenoiser`\n"
"This block supports QwenImage Edit."
)
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
block_classes = [
QwenImageEditLoopBeforeDenoiser,
QwenImageEditLoopDenoiser,
QwenImageLoopAfterDenoiser,
QwenImageLoopAfterDenoiserInpaint,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageEditLoopBeforeDenoiser`\n"
" - `QwenImageEditLoopDenoiser`\n"
" - `QwenImageLoopAfterDenoiser`\n"
" - `QwenImageLoopAfterDenoiserInpaint`\n"
"This block supports inpainting tasks for QwenImage Edit."
)
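# Illustrative sketch (not from this file): new denoise variants follow the
# same composition pattern -- subclass `QwenImageDenoiseLoopWrapper` and list
# the loop blocks to run at each timestep, e.g.:
#
#   class MyDenoiseStep(QwenImageDenoiseLoopWrapper):
#       block_classes = [QwenImageLoopBeforeDenoiser, QwenImageLoopDenoiser, QwenImageLoopAfterDenoiser]
#       block_names = ["before_denoiser", "denoiser", "after_denoiser"]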


@@ -0,0 +1,857 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
import PIL
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist
from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel
from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
from ...utils import logging
from ...utils.torch_utils import unwrap_module
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline
logger = logging.get_logger(__name__)
def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
return split_result
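# Illustrative example: with an attention mask [[1, 1, 0], [1, 0, 0]] and
# hidden states of shape [2, 3, D], `_extract_masked_hidden` returns a tuple of
# two tensors with shapes [2, D] and [1, D] -- the padded positions are dropped
# per sample before the prompts are re-padded to a common length below.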
def get_qwen_prompt_embeds(
text_encoder,
tokenizer,
prompt: Union[str, List[str]] = None,
prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
prompt_template_encode_start_idx: int = 34,
tokenizer_max_length: int = 1024,
device: Optional[torch.device] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
template = prompt_template_encode
drop_idx = prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
txt_tokens = tokenizer(
txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds = prompt_embeds.to(device=device)
return prompt_embeds, encoder_attention_mask
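# Illustrative usage sketch (the checkpoint id is a placeholder, not pinned by
# this file):
#
#   from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
#
#   text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained("<repo>", subfolder="text_encoder")
#   tokenizer = Qwen2Tokenizer.from_pretrained("<repo>", subfolder="tokenizer")
#   prompt_embeds, mask = get_qwen_prompt_embeds(text_encoder, tokenizer, prompt="a cat", device="cuda")
#   # prompt_embeds: [batch, seq_len, hidden_dim]; mask: [batch, seq_len]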
def get_qwen_prompt_embeds_edit(
text_encoder,
processor,
prompt: Union[str, List[str]] = None,
image: Optional[torch.Tensor] = None,
prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
prompt_template_encode_start_idx: int = 64,
device: Optional[torch.device] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
template = prompt_template_encode
drop_idx = prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
model_inputs = processor(
text=txt,
images=image,
padding=True,
return_tensors="pt",
).to(device)
outputs = text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
output_hidden_states=True,
)
hidden_states = outputs.hidden_states[-1]
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds = prompt_embeds.to(device=device)
return prompt_embeds, encoder_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image
def encode_vae_image(
image: torch.Tensor,
vae: AutoencoderKLQwenImage,
generator: torch.Generator,
device: torch.device,
dtype: torch.dtype,
latent_channels: int = 16,
sample_mode: str = "argmax",
):
if not isinstance(image, torch.Tensor):
raise ValueError(f"Expected image to be a tensor, got {type(image)}.")
# preprocessed image should be a 4D tensor: batch_size, num_channels, height, width
if image.dim() == 4:
image = image.unsqueeze(2)
elif image.dim() != 5:
raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
image = image.to(device=device, dtype=dtype)
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
latents_mean = (
torch.tensor(vae.config.latents_mean)
.view(1, latent_channels, 1, 1, 1)
.to(image_latents.device, image_latents.dtype)
)
latents_std = (
torch.tensor(vae.config.latents_std)
.view(1, latent_channels, 1, 1, 1)
.to(image_latents.device, image_latents.dtype)
)
image_latents = (image_latents - latents_mean) / latents_std
return image_latents
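# Shape sketch (illustrative): a preprocessed 4D image batch picks up a frame
# axis before encoding, and the latents are normalized with the VAE's
# configured per-channel `latents_mean` / `latents_std`:
#
#   image:         [B, 3, H, W] -> unsqueeze(2) -> [B, 3, 1, H, W]
#   image_latents: [B, latent_channels, 1, H / s, W / s]  (s = the VAE's spatial compression)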
class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
model_name = "qwenimage"
def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
This block resizes an input image tensor and exposes the resized result under configurable input and output
names. Use this when you need to wire the resize step to different image fields (e.g., "image",
"control_image")
Args:
input_name (str, optional): Name of the image field to read from the
pipeline state. Defaults to "image".
output_name (str, optional): Name of the resized image field to write
back to the pipeline state. Defaults to "resized_image".
"""
if not isinstance(input_name, str) or not isinstance(output_name, str):
raise ValueError(
f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
)
self._image_input_name = input_name
self._resized_image_output_name = output_name
super().__init__()
@property
def description(self) -> str:
return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_resize_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
images = getattr(block_state, self._image_input_name)
if not is_valid_image_imagelist(images):
raise ValueError(f"Images must be image or list of images but are {type(images)}")
if is_valid_image(images):
images = [images]
image_width, image_height = images[0].size
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
resized_images = [
components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
for image in images
]
setattr(block_state, self._resized_image_output_name, resized_images)
self.set_block_state(state, block_state)
return components, state
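# Illustrative sketch: the same block can be rewired to other image fields,
# e.g. to resize a control image (the names below are placeholders):
#
#   resize_control = QwenImageEditResizeDynamicStep(
#       input_name="control_image", output_name="resized_control_image"
#   )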
class QwenImageTextEncoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"),
ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(
name="prompt_template_encode",
default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
),
ConfigSpec(name="prompt_template_encode_start_idx", default=34),
ConfigSpec(name="tokenizer_max_length", default=1024),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
InputParam(
name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The prompt embeddings",
),
OutputParam(
name="prompt_embeds_mask",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The encoder attention mask",
),
OutputParam(
name="negative_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The negative prompt embeddings",
),
OutputParam(
name="negative_prompt_embeds_mask",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The negative prompt embeddings mask",
),
]
@staticmethod
def check_inputs(prompt, negative_prompt, max_sequence_length):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if (
negative_prompt is not None
and not isinstance(negative_prompt, str)
and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
device = components._execution_device
self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
components.text_encoder,
components.tokenizer,
prompt=block_state.prompt,
prompt_template_encode=components.config.prompt_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
tokenizer_max_length=components.config.tokenizer_max_length,
device=device,
)
block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]
if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt or ""
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
components.text_encoder,
components.tokenizer,
prompt=negative_prompt,
prompt_template_encode=components.config.prompt_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
tokenizer_max_length=components.config.tokenizer_max_length,
device=device,
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[
:, : block_state.max_sequence_length
]
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[
:, : block_state.max_sequence_length
]
self.set_block_state(state, block_state)
return components, state
class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
ComponentSpec("processor", Qwen2VLProcessor),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(
name="prompt_template_encode",
default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
),
ConfigSpec(name="prompt_template_encode_start_idx", default=64),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
InputParam(
name="resized_image",
required=True,
type_hint=torch.Tensor,
description="The image prompt to encode, should be resized using resize step",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The prompt embeddings",
),
OutputParam(
name="prompt_embeds_mask",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The encoder attention mask",
),
OutputParam(
name="negative_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The negative prompt embeddings",
),
OutputParam(
name="negative_prompt_embeds_mask",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="The negative prompt embeddings mask",
),
]
@staticmethod
def check_inputs(prompt, negative_prompt):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if (
negative_prompt is not None
and not isinstance(negative_prompt, str)
and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
self.check_inputs(block_state.prompt, block_state.negative_prompt)
device = components._execution_device
block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit(
components.text_encoder,
components.processor,
prompt=block_state.prompt,
image=block_state.resized_image,
prompt_template_encode=components.config.prompt_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
device=device,
)
if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt or ""
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
components.text_encoder,
components.processor,
prompt=negative_prompt,
image=block_state.resized_image,
prompt_template_encode=components.config.prompt_template_encode,
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
device=device,
)
self.set_block_state(state, block_state)
return components, state
class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_mask_processor",
InpaintProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("mask_image", required=True),
InputParam("resized_image"),
InputParam("image"),
InputParam("height"),
InputParam("width"),
InputParam("padding_mask_crop"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="processed_image"),
OutputParam(name="processed_mask_image"),
OutputParam(
name="mask_overlay_kwargs",
type_hint=Dict,
description="The kwargs for the postprocess step to apply the mask overlay",
),
]
@staticmethod
def check_inputs(height, width, vae_scale_factor):
if height is not None and height % (vae_scale_factor * 2) != 0:
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
if width is not None and width % (vae_scale_factor * 2) != 0:
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
if block_state.resized_image is None and block_state.image is None:
raise ValueError("resized_image and image cannot be None at the same time")
if block_state.resized_image is None:
image = block_state.image
self.check_inputs(
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
)
height = block_state.height or components.default_height
width = block_state.width or components.default_width
else:
width, height = block_state.resized_image[0].size
image = block_state.resized_image
block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
components.image_mask_processor.preprocess(
image=image,
mask=block_state.mask_image,
height=height,
width=width,
padding_mask_crop=block_state.padding_mask_crop,
)
)
self.set_block_state(state, block_state)
return components, state
class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("resized_image"),
InputParam("image"),
InputParam("height"),
InputParam("width"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="processed_image"),
]
@staticmethod
def check_inputs(height, width, vae_scale_factor):
if height is not None and height % (vae_scale_factor * 2) != 0:
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
if width is not None and width % (vae_scale_factor * 2) != 0:
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
if block_state.resized_image is None and block_state.image is None:
raise ValueError("resized_image and image cannot be None at the same time")
if block_state.resized_image is None:
image = block_state.image
self.check_inputs(
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
)
height = block_state.height or components.default_height
width = block_state.width or components.default_width
else:
width, height = block_state.resized_image[0].size
image = block_state.resized_image
block_state.processed_image = components.image_processor.preprocess(
image=image,
height=height,
width=width,
)
self.set_block_state(state, block_state)
return components, state
class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
model_name = "qwenimage"
def __init__(
self,
input_name: str = "processed_image",
output_name: str = "image_latents",
):
"""Initialize a VAE encoder step for converting images to latent representations.
Both the input and output names are configurable so this block can be configured to process different image
inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
Args:
input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
Examples: "processed_image" or "processed_control_image"
output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
Examples: "image_latents" or "control_image_latents"
Examples:
    # Basic usage with default settings (includes image processor)
    QwenImageVaeEncoderDynamicStep()

    # Custom input/output names for control image
    QwenImageVaeEncoderDynamicStep(
        input_name="processed_control_image", output_name="control_image_latents"
    )
"""
self._image_input_name = input_name
self._image_latents_output_name = output_name
super().__init__()
@property
def description(self) -> str:
return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
@property
def expected_components(self) -> List[ComponentSpec]:
components = [
ComponentSpec("vae", AutoencoderKLQwenImage),
]
return components
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(self._image_input_name, required=True),
InputParam("generator"),
]
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
self._image_latents_output_name,
type_hint=torch.Tensor,
description="The latents representing the reference image",
)
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
dtype = components.vae.dtype
image = getattr(block_state, self._image_input_name)
# Encode image into latents
image_latents = encode_vae_image(
image=image,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
latent_channels=components.num_channels_latents,
)
setattr(block_state, self._image_latents_output_name, image_latents)
self.set_block_state(state, block_state)
return components, state
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n"
@property
def expected_components(self) -> List[ComponentSpec]:
components = [
ComponentSpec("vae", AutoencoderKLQwenImage),
ComponentSpec("controlnet", QwenImageControlNetModel),
ComponentSpec(
"control_image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
return components
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam("control_image", required=True),
InputParam("height"),
InputParam("width"),
InputParam("generator"),
]
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"control_image_latents",
type_hint=torch.Tensor,
description="The latents representing the control image",
)
]
@staticmethod
def check_inputs(height, width, vae_scale_factor):
if height is not None and height % (vae_scale_factor * 2) != 0:
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
if width is not None and width % (vae_scale_factor * 2) != 0:
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor)
device = components._execution_device
dtype = components.vae.dtype
height = block_state.height or components.default_height
width = block_state.width or components.default_width
controlnet = unwrap_module(components.controlnet)
if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list):
block_state.control_image = [block_state.control_image]
if isinstance(controlnet, QwenImageMultiControlNetModel):
block_state.control_image_latents = []
for control_image_ in block_state.control_image:
control_image_ = components.control_image_processor.preprocess(
image=control_image_,
height=height,
width=width,
)
control_image_latents_ = encode_vae_image(
image=control_image_,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
latent_channels=components.num_channels_latents,
sample_mode="sample",
)
block_state.control_image_latents.append(control_image_latents_)
elif isinstance(controlnet, QwenImageControlNetModel):
control_image = components.control_image_processor.preprocess(
image=block_state.control_image,
height=height,
width=width,
)
block_state.control_image_latents = encode_vae_image(
image=control_image,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
latent_channels=components.num_channels_latents,
sample_mode="sample",
)
else:
raise ValueError(
f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}"
)
self.set_block_state(state, block_state)
return components, state


@@ -0,0 +1,431 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import torch
from ...models import QwenImageMultiControlNetModel
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
def repeat_tensor_to_batch_size(
input_name: str,
input_tensor: torch.Tensor,
batch_size: int,
num_images_per_prompt: int = 1,
) -> torch.Tensor:
"""Repeat tensor elements to match the final batch size.
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
by repeating each element along dimension 0.
The input tensor must have batch size 1 or batch_size. The function will:
- If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
- If batch size equals batch_size: repeat each element num_images_per_prompt times
Args:
input_name (str): Name of the input tensor (used for error messages)
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
batch_size (int): The base batch size (number of prompts)
num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
Returns:
torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
Raises:
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
Examples:
    tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
    repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
    repeated  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: [4, 3]

    tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
    repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
    repeated  # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) - shape: [4, 3]
"""
# make sure input is a tensor
if not isinstance(input_tensor, torch.Tensor):
raise ValueError(f"`{input_name}` must be a tensor")
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
if input_tensor.shape[0] == 1:
repeat_by = batch_size * num_images_per_prompt
elif input_tensor.shape[0] == batch_size:
repeat_by = num_images_per_prompt
else:
raise ValueError(
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
)
# expand the tensor to match the batch_size * num_images_per_prompt
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
return input_tensor
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]:
"""Calculate image dimensions from latent tensor dimensions.
This function converts latent space dimensions to image space dimensions by multiplying the latent height and width
by the VAE scale factor.
Args:
latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
vae_scale_factor (int): The scale factor used by the VAE to compress images.
Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
Returns:
Tuple[int, int]: The calculated image dimensions as (height, width)
Raises:
ValueError: If latents tensor doesn't have 4 or 5 dimensions
"""
# make sure the latents are not packed
if latents.ndim != 4 and latents.ndim != 5:
raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")
latent_height, latent_width = latents.shape[-2:]
height = latent_height * vae_scale_factor
width = latent_width * vae_scale_factor
return height, width
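# Illustrative example (pure arithmetic; the actual scale factor comes from the
# pipeline config): with vae_scale_factor == 8, latents of shape
# [1, 16, 1, 128, 128] map back to a 1024 x 1024 image:
#
#   calculate_dimension_from_latents(torch.zeros(1, 16, 1, 128, 128), 8)  # -> (1024, 1024)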
class QwenImageTextInputsStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
summary_section = (
"Text input processing step that standardizes text embeddings for the pipeline.\n"
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)
# Placement guidance
placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps."
return summary_section + placement_section
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
]
@staticmethod
def check_inputs(
prompt_embeds,
prompt_embeds_mask,
negative_prompt_embeds,
negative_prompt_embeds_mask,
):
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")
if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None:
raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`")
if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]:
raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]:
raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`")
elif (
negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]
):
raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(
prompt_embeds=block_state.prompt_embeds,
prompt_embeds_mask=block_state.prompt_embeds_mask,
negative_prompt_embeds=block_state.negative_prompt_embeds,
negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len
)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_images_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat(
1, block_state.num_images_per_prompt, 1
)
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len
)
self.set_block_state(state, block_state)
return components, state
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
model_name = "qwenimage"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
This step handles multiple common tasks to prepare inputs for the denoising step:
1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
This is a dynamic block that allows you to configure which inputs to process.
Args:
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
additional_batch_inputs (List[str], optional):
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to []. Examples: ["processed_mask_image"]
Examples:
    # Configure to process image_latents (default behavior)
    QwenImageInputsDynamicStep()

    # Configure to process multiple image latent inputs
    QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])

    # Configure to process image latents and additional batch inputs
    QwenImageInputsDynamicStep(
        image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
    )
"""
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
super().__init__()
@property
def description(self) -> str:
# Functionality section
summary_section = (
"Input processing step that:\n"
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
# Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
# Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
]
# Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
# Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
return inputs
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate height/width from latents
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
# 2. Patchify the image latent tensor
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
# 3. Expand batch size
image_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_tensor=image_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, image_latent_input_name, image_latent_tensor)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, input_name, input_tensor)
self.set_block_state(state, block_state)
return components, state
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
model_name = "qwenimage"
@property
def description(self) -> str:
return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="control_image_latents", required=True),
InputParam(name="batch_size", required=True),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="height"),
InputParam(name="width"),
]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if isinstance(components.controlnet, QwenImageMultiControlNetModel):
control_image_latents = []
# loop through each control_image_latents
for i, control_image_latents_ in enumerate(block_state.control_image_latents):
# 1. update height/width if not provided
height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
# 2. pack
control_image_latents_ = components.pachifier.pack_latents(control_image_latents_)
# 3. repeat to match the batch size
control_image_latents_ = repeat_tensor_to_batch_size(
input_name=f"control_image_latents[{i}]",
input_tensor=control_image_latents_,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
control_image_latents.append(control_image_latents_)
block_state.control_image_latents = control_image_latents
else:
# 1. update height/width if not provided
height, width = calculate_dimension_from_latents(
block_state.control_image_latents, components.vae_scale_factor
)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
# 2. pack
block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents)
# 3. repeat to match the batch size
block_state.control_image_latents = repeat_tensor_to_batch_size(
input_name="control_image_latents",
input_tensor=block_state.control_image_latents,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
self.set_block_state(state, block_state)
return components, state
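# Shape note (illustrative): with QwenImageMultiControlNetModel, the denoiser
# receives a list of packed tensors, one per controlnet, each of shape
# [batch_size * num_images_per_prompt, num_patches, channels]; a single
# QwenImageControlNetModel receives one such tensor.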


@@ -0,0 +1,841 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageControlNetBeforeDenoiserStep,
QwenImageCreateMaskLatentsStep,
QwenImageEditRoPEInputsStep,
QwenImagePrepareLatentsStep,
QwenImagePrepareLatentsWithStrengthStep,
QwenImageRoPEInputsStep,
QwenImageSetTimestepsStep,
QwenImageSetTimestepsWithStrengthStep,
)
from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
from .denoise import (
QwenImageControlNetDenoiseStep,
QwenImageDenoiseStep,
QwenImageEditDenoiseStep,
QwenImageEditInpaintDenoiseStep,
QwenImageInpaintControlNetDenoiseStep,
QwenImageInpaintDenoiseStep,
QwenImageLoopBeforeDenoiserControlNet,
)
from .encoders import (
QwenImageControlNetVaeEncoderStep,
QwenImageEditResizeDynamicStep,
QwenImageEditTextEncoderStep,
QwenImageInpaintProcessImagesInputStep,
QwenImageProcessImagesInputStep,
QwenImageTextEncoderStep,
QwenImageVaeEncoderDynamicStep,
)
from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
logger = logging.get_logger(__name__)
# 1. QwenImage
## 1.1 QwenImage/text2image
#### QwenImage/decode
#### (standard decode step works for most tasks except for inpaint)
QwenImageDecodeBlocks = InsertableDict(
[
("decode", QwenImageDecoderStep()),
("postprocess", QwenImageProcessImagesOutputStep()),
]
)
class QwenImageDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageDecodeBlocks.values()
block_names = QwenImageDecodeBlocks.keys()
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image."
#### QwenImage/text2image presets
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("input", QwenImageTextInputsStep()),
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
("denoise", QwenImageDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)
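# Illustrative usage sketch (assumption: `SequentialPipelineBlocks.from_blocks_dict`,
# `init_pipeline`, and `load_default_components` from the modular pipelines API;
# the repo id is a placeholder):
#
#   blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
#   pipe = blocks.init_pipeline("<qwen-image-modular-repo>")
#   pipe.load_default_components(torch_dtype=torch.bfloat16)
#   image = pipe(prompt="a cat wearing a hat", num_inference_steps=30, output="images")[0]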
## 1.2 QwenImage/inpaint
#### QwenImage/inpaint vae encoder
QwenImageInpaintVaeEncoderBlocks = InsertableDict(
[
(
"preprocess",
QwenImageInpaintProcessImagesInputStep(),
), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
]
)
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageInpaintVaeEncoderBlocks.values()
block_names = QwenImageInpaintVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return (
"This step is used for processing image and mask inputs for inpainting tasks. It:\n"
" - Resizes the image to the target size, based on `height` and `width`.\n"
" - Processes and updates `image` and `mask_image`.\n"
" - Creates `image_latents`."
)
#### QwenImage/inpaint inputs
QwenImageInpaintInputBlocks = InsertableDict(
[
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
(
"additional_inputs",
QwenImageInputsDynamicStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
),
),
]
)
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageInpaintInputBlocks.values()
block_names = QwenImageInpaintInputBlocks.keys()
@property
def description(self):
return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
#### QwenImage/inpaint prepare latents
QwenImageInpaintPrepareLatentsBlocks = InsertableDict(
[
("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()),
("create_mask_latents", QwenImageCreateMaskLatentsStep()),
]
)
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageInpaintPrepareLatentsBlocks.values()
block_names = QwenImageInpaintPrepareLatentsBlocks.keys()
@property
def description(self) -> str:
return (
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
" - Add noise to the image latents to create the latents input for the denoiser.\n"
" - Create the pachified latents `mask` based on the processedmask image.\n"
)
#### QwenImage/inpaint decode
QwenImageInpaintDecodeBlocks = InsertableDict(
[
("decode", QwenImageDecoderStep()),
("postprocess", QwenImageInpaintProcessImagesOutputStep()),
]
)
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageInpaintDecodeBlocks.values()
block_names = QwenImageInpaintDecodeBlocks.keys()
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
#### QwenImage/inpaint presets
INPAINT_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("vae_encoder", QwenImageInpaintVaeEncoderStep()),
("input", QwenImageInpaintInputStep()),
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
("denoise", QwenImageInpaintDenoiseStep()),
("decode", QwenImageInpaintDecodeStep()),
]
)
## 1.3 QwenImage/img2img
#### QwenImage/img2img vae encoder
QwenImageImg2ImgVaeEncoderBlocks = InsertableDict(
[
("preprocess", QwenImageProcessImagesInputStep()),
("encode", QwenImageVaeEncoderDynamicStep()),
]
)
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageImg2ImgVaeEncoderBlocks.values()
block_names = QwenImageImg2ImgVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
#### QwenImage/img2img inputs
QwenImageImg2ImgInputBlocks = InsertableDict(
[
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
]
)
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageImg2ImgInputBlocks.values()
block_names = QwenImageImg2ImgInputBlocks.keys()
@property
def description(self):
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
#### QwenImage/img2img presets
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("vae_encoder", QwenImageImg2ImgVaeEncoderStep()),
("input", QwenImageImg2ImgInputStep()),
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
("denoise", QwenImageDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)
## 1.4 QwenImage/controlnet
#### QwenImage/controlnet presets
CONTROLNET_BLOCKS = InsertableDict(
[
("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image
("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet
(
"controlnet_before_denoise",
QwenImageControlNetBeforeDenoiserStep(),
), # before denoise step (after set_timesteps step)
(
"controlnet_denoise_loop_before",
QwenImageLoopBeforeDenoiserControlNet(),
), # controlnet loop step (insert before the denoiseloop_denoiser)
]
)
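# CONTROLNET_BLOCKS is not a standalone preset: its steps are spliced into another
# workflow at the positions hinted by the comments above (AUTO_BLOCKS below does this
# wiring for you). A sketch, assuming InsertableDict supports dict-style `copy()` and
# keyed `insert(key, value, index)`:
def _example_text2image_with_controlnet():
    blocks_dict = TEXT2IMAGE_BLOCKS.copy()
    # run the control-image vae encoder right after the text encoder
    blocks_dict.insert("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep(), 1)
    # prepare controlnet inputs right after the default input step
    blocks_dict.insert("controlnet_inputs", QwenImageControlNetInputsStep(), 3)
    # the remaining controlnet steps are inserted around set_timesteps/denoise the same way
    return blocks_dict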
## 1.5 QwenImage/auto encoders
#### for inpaint and img2img tasks
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+ " - if `mask_image` or `image` is not provided, step will be skipped."
)
# for controlnet tasks
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageControlNetVaeEncoderStep]
block_names = ["controlnet"]
block_trigger_inputs = ["control_image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
+ " - if `control_image` is not provided, step will be skipped."
)
## 1.6 QwenImage/auto inputs
# text2image/inpaint/img2img
class QwenImageAutoInputStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep]
block_names = ["inpaint", "img2img", "text2image"]
block_trigger_inputs = ["processed_mask_image", "image_latents", None]
@property
def description(self):
return (
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
" This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n"
+ " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
)
# controlnet
class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks):
block_classes = [QwenImageControlNetInputsStep]
block_names = ["controlnet"]
block_trigger_inputs = ["control_image_latents"]
@property
def description(self):
return (
"Controlnet input step that prepare the control_image_latents input.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ " - if `control_image_latents` is not provided, step will be skipped."
)
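# A standalone sketch (an illustration, not library code) of the dispatch rule an
# AutoPipelineBlocks applies: sub-blocks are checked in order against the provided
# inputs, the first whose trigger input is present wins, and a `None` trigger marks the
# default branch.
def _select_block(block_names, block_trigger_inputs, call_kwargs):
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None or call_kwargs.get(trigger) is not None:
            return name
    return None  # no trigger matched and no default: the step is skipped

# e.g. for QwenImageAutoInputStep:
# _select_block(["inpaint", "img2img", "text2image"],
#               ["processed_mask_image", "image_latents", None],
#               {"image_latents": latents}) -> "img2img"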
## 1.7 QwenImage/auto before denoise step
# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step
# QwenImage/text2image before denoise
QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
]
)
class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values()
block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for text2image task."
# QwenImage/inpaint before denoise
QwenImageInpaintBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
]
)
class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageInpaintBeforeDenoiseBlocks.values()
block_names = QwenImageInpaintBeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
# QwenImage/img2img before denoise
QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
]
)
class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values()
block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
# auto before_denoise step for text2image, inpaint, img2img tasks
class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [
QwenImageInpaintBeforeDenoiseStep,
QwenImageImg2ImgBeforeDenoiseStep,
QwenImageText2ImageBeforeDenoiseStep,
]
block_names = ["inpaint", "img2img", "text2image"]
block_trigger_inputs = ["processed_mask_image", "image_latents", None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n"
+ " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
+ " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
)
# auto before_denoise step for controlnet tasks
class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [QwenImageControlNetBeforeDenoiserStep]
block_names = ["controlnet"]
block_trigger_inputs = ["control_image_latents"]
@property
def description(self):
return (
"Controlnet before denoise step that prepare the controlnet input.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ " - if `control_image_latents` is not provided, step will be skipped."
)
## 1.8 QwenImage/auto denoise
# auto denoise step for controlnet tasks: works for all tasks with controlnet
class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep]
block_names = ["inpaint_denoise", "denoise"]
block_trigger_inputs = ["mask", None]
@property
def description(self):
return (
"Controlnet step during the denoising process. \n"
" This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n"
+ " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n"
+ " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n"
)
# auto denoise step for everything: works for all tasks with or without controlnet
class QwenImageAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
QwenImageControlNetAutoDenoiseStep,
QwenImageInpaintDenoiseStep,
QwenImageDenoiseStep,
]
block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
block_trigger_inputs = ["control_image_latents", "mask", None]
@property
def description(self):
return (
"Denoise step that iteratively denoise the latents. \n"
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks. It also works with controlnet\n"
+ " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n"
+ " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n"
)
## 1.9 QwenImage/auto decode
# auto decode step for inpaint and text2image tasks
class QwenImageAutoDecodeStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
block_names = ["inpaint_decode", "decode"]
block_trigger_inputs = ["mask", None]
@property
def description(self):
return (
"Decode step that decode the latents into images. \n"
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+ " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+ " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
)
## 1.10 QwenImage/auto block & presets
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("vae_encoder", QwenImageAutoVaeEncoderStep()),
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
("input", QwenImageAutoInputStep()),
("controlnet_input", QwenImageOptionalControlNetInputStep()),
("before_denoise", QwenImageAutoBeforeDenoiseStep()),
("controlnet_before_denoise", QwenImageOptionalControlNetBeforeDenoiseStep()),
("denoise", QwenImageAutoDenoiseStep()),
("decode", QwenImageAutoDecodeStep()),
]
)
class QwenImageAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ "- for image-to-image generation, you need to provide `image`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
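# A minimal usage sketch for the auto blocks (the repo id is a placeholder assumption,
# and `load_default_components` reflects the modular-pipelines API at the time of this
# commit). The inputs you pass select the sub-blocks, per the description above.
def _example_auto_blocks(prompt="a cat wearing a hat"):
    import torch

    pipe = QwenImageAutoBlocks().init_pipeline("<modular-repo-id>")  # placeholder
    pipe.load_default_components(torch_dtype=torch.bfloat16)
    # text2image: prompt only; img2img: + image; inpaint: + image and mask_image;
    # controlnet: + control_image on top of any of the above.
    return pipe(prompt=prompt, output="images")[0]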
# 2. QwenImage-Edit
## 2.1 QwenImage-Edit/edit
#### QwenImage-Edit/edit vl encoder: take both image and text prompts
QwenImageEditVLEncoderBlocks = InsertableDict(
[
("resize", QwenImageEditResizeDynamicStep()),
("encode", QwenImageEditTextEncoderStep()),
]
)
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageEditVLEncoderBlocks.values()
block_names = QwenImageEditVLEncoderBlocks.keys()
@property
def description(self) -> str:
return "QwenImage-Edit VL encoder step that encode the image an text prompts together."
#### QwenImage-Edit/edit vae encoder
QwenImageEditVaeEncoderBlocks = InsertableDict(
[
("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step
("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
]
)
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageEditVaeEncoderBlocks.values()
block_names = QwenImageEditVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "Vae encoder step that encode the image inputs into their latent representations."
#### QwenImage-Edit/edit input
QwenImageEditInputBlocks = InsertableDict(
[
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
]
)
class QwenImageEditInputStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageEditInputBlocks.values()
block_names = QwenImageEditInputBlocks.keys()
@property
def description(self):
return "Input step that prepares the inputs for the edit denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs: \n"
" - `image_latents`.\n"
" - update height/width based `image_latents`, patchify `image_latents`."
#### QwenImage/edit presets
EDIT_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditVLEncoderStep()),
("vae_encoder", QwenImageEditVaeEncoderStep()),
("input", QwenImageEditInputStep()),
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
("denoise", QwenImageEditDenoiseStep()),
("decode", QwenImageDecodeStep()),
]
)
## 2.2 QwenImage-Edit/edit inpaint
#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step
QwenImageEditInpaintVaeEncoderBlocks = InsertableDict(
[
("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image
(
"preprocess",
QwenImageInpaintProcessImagesInputStep(),
), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
(
"encode",
QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
), # processed_image -> image_latents
]
)
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageEditInpaintVaeEncoderBlocks.values()
block_names = QwenImageEditInpaintVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return (
"This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
" - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
" - process the resized image and mask image.\n"
" - create image latents."
)
#### QwenImage-Edit/edit inpaint presets
EDIT_INPAINT_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditVLEncoderStep()),
("vae_encoder", QwenImageEditInpaintVaeEncoderStep()),
("input", QwenImageInpaintInputStep()),
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
("denoise", QwenImageEditInpaintDenoiseStep()),
("decode", QwenImageInpaintDecodeStep()),
]
)
## 2.3 QwenImage-Edit/auto encoders
class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [
QwenImageEditInpaintVaeEncoderStep,
QwenImageEditVaeEncoderStep,
]
block_names = ["edit_inpaint", "edit"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations. \n"
" This is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
+ " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
+ " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
+ " - if `mask_image` or `image` is not provided, step will be skipped."
)
## 2.4 QwenImage-Edit/auto inputs
class QwenImageEditAutoInputStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
block_names = ["edit_inpaint", "edit"]
block_trigger_inputs = ["processed_mask_image", "image"]
@property
def description(self):
return (
"Input step that prepares the inputs for the edit denoising step.\n"
+ " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
+ " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
+ " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n"
+ " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
)
## 2.5 QwenImage-Edit/auto before denoise
# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step
#### QwenImage-Edit/edit before denoise
QwenImageEditBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsStep()),
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
]
)
class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageEditBeforeDenoiseBlocks.values()
block_names = QwenImageEditBeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
#### QwenImage-Edit/edit inpaint before denoise
QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", QwenImagePrepareLatentsStep()),
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
]
)
class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values()
block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit inpaint task."
# auto before_denoise step for edit and edit_inpaint tasks
class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditInpaintBeforeDenoiseStep,
QwenImageEditBeforeDenoiseStep,
]
block_names = ["edit_inpaint", "edit"]
block_trigger_inputs = ["processed_mask_image", "image_latents"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n"
+ " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
+ " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped."
)
## 2.6 QwenImage-Edit/auto denoise
class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep]
block_names = ["inpaint_denoise", "denoise"]
block_trigger_inputs = ["processed_mask_image", "image_latents"]
@property
def description(self):
return (
"Denoise step that iteratively denoise the latents. \n"
+ "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n"
+ " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
+ " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
)
## 2.7 QwenImage-Edit/auto blocks & presets
EDIT_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditVLEncoderStep()),
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
("input", QwenImageEditAutoInputStep()),
("before_denoise", QwenImageEditAutoBeforeDenoiseStep()),
("denoise", QwenImageEditAutoDenoiseStep()),
("decode", QwenImageAutoDecodeStep()),
]
)
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = EDIT_AUTO_BLOCKS.values()
block_names = EDIT_AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
+ "- for edit (img2img) generation, you need to provide `image`\n"
+ "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
)
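# Usage mirrors QwenImageAutoBlocks above (same placeholder assumptions for the repo id
# and component loading): pass `image` for edit (img2img), and additionally `mask_image`
# for edit inpainting.
def _example_edit_auto_blocks(image, prompt="replace the hat with a crown"):
    import torch

    pipe = QwenImageEditAutoBlocks().init_pipeline("<modular-repo-id>")  # placeholder
    pipe.load_default_components(torch_dtype=torch.bfloat16)
    return pipe(prompt=prompt, image=image, output="images")[0]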
# 3. all block presets supported in QwenImage & QwenImage-Edit
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"edit": EDIT_BLOCKS,
"edit_inpaint": EDIT_INPAINT_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"auto": AUTO_BLOCKS,
"edit_auto": EDIT_AUTO_BLOCKS,
}
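# A small sketch: ALL_BLOCKS is the task-keyed entry point, so a preset can be looked up
# by name and assembled exactly like the inpaint example earlier in this file.
def _example_blocks_for_task(task="img2img"):
    from diffusers.modular_pipelines import SequentialPipelineBlocks

    return SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS[task])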

View File

@@ -0,0 +1,202 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import QwenImageLoraLoaderMixin
from ..modular_pipeline import ModularPipeline
class QwenImagePachifier(ConfigMixin):
"""
A class to pack and unpack latents for QwenImage.
"""
config_name = "config.json"
@register_to_config
def __init__(
self,
patch_size: int = 2,
):
super().__init__()
def pack_latents(self, latents):
if latents.ndim != 4 and latents.ndim != 5:
raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}")
if latents.ndim == 4:
latents = latents.unsqueeze(2)
batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape
patch_size = self.config.patch_size
if latent_height % patch_size != 0 or latent_width % patch_size != 0:
raise ValueError(
f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
)
latents = latents.view(
batch_size,
num_channels_latents,
latent_height // patch_size,
patch_size,
latent_width // patch_size,
patch_size,
)
latents = latents.permute(
0, 2, 4, 1, 3, 5
) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size
latents = latents.reshape(
batch_size,
(latent_height // patch_size) * (latent_width // patch_size),
num_channels_latents * patch_size * patch_size,
)
return latents
def unpack_latents(self, latents, height, width, vae_scale_factor=8):
if latents.ndim != 3:
raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
batch_size, num_patches, channels = latents.shape
patch_size = self.config.patch_size
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
latents = latents.view(
batch_size,
height // patch_size,
width // patch_size,
channels // (patch_size * patch_size),
patch_size,
patch_size,
)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
return latents
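# A self-contained shape check (sketch) of the pack/unpack round trip: a 1024x1024 image
# with vae_scale_factor=8 gives 128x128 latents, which pack into (128/2)*(128/2) = 4096
# tokens of 16*2*2 = 64 channels each; unpacking restores them (with a frame dim added).
def _example_pachifier_roundtrip():
    import torch

    pachifier = QwenImagePachifier()  # patch_size defaults to 2
    latents = torch.randn(1, 16, 128, 128)  # (batch, channels, latent_height, latent_width)
    packed = pachifier.pack_latents(latents)
    assert packed.shape == (1, 4096, 64)
    unpacked = pachifier.unpack_latents(packed, height=1024, width=1024, vae_scale_factor=8)
    assert unpacked.shape == (1, 16, 1, 128, 128)
    return unpacked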
class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
"""
A ModularPipeline for QwenImage.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 16
if hasattr(self, "transformer") and self.transformer is not None:
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
@property
def is_guidance_distilled(self):
is_guidance_distilled = False
if hasattr(self, "transformer") and self.transformer is not None:
is_guidance_distilled = self.transformer.config.guidance_embeds
return is_guidance_distilled
@property
def requires_unconditional_embeds(self):
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
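# For reference (an assumption about the stock QwenImage VAE config): its
# `temperal_downsample` list (the attribute spelling is the library's) has 3 entries, so
# vae_scale_factor = 2**3 = 8 and the default resolution is 128 * 8 = 1024.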
class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
"""
A ModularPipeline for QwenImage-Edit.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
"""
# YiYi TODO: qwen edit should not provide a default height/width; they should be derived from the resized input image (after adjustment) produced by the resize step.
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 16
if hasattr(self, "transformer") and self.transformer is not None:
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
@property
def is_guidance_distilled(self):
is_guidance_distilled = False
if hasattr(self, "transformer") and self.transformer is not None:
is_guidance_distilled = self.transformer.config.guidance_embeds
return is_guidance_distilled
@property
def requires_unconditional_embeds(self):
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds

View File

@@ -76,6 +76,7 @@ class StableDiffusionXLModularPipeline(
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
# YiYi TODO: change to num_channels_latents
@property
def num_channels_unet(self):
num_channels_unet = 4

View File

@@ -91,6 +91,14 @@ from .pag import (
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .qwenimage import (
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImagePipeline,
)
from .sana import SanaPipeline
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
from .stable_diffusion import (
@@ -150,6 +158,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),
]
)
@@ -174,6 +184,8 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
("flux-kontext", FluxKontextPipeline),
("qwenimage", QwenImageImg2ImgPipeline),
("qwenimage-edit", QwenImageEditPipeline),
]
)
@@ -195,6 +207,8 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("flux-controlnet", FluxControlNetInpaintPipeline),
("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
("qwenimage", QwenImageInpaintPipeline),
("qwenimage-edit", QwenImageEditInpaintPipeline),
]
)

View File

@@ -32,6 +32,66 @@ class FluxModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class QwenImageAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class QwenImageEditAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class QwenImageEditModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class QwenImageModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]