diff --git a/docs/source/en/api/image_processor.md b/docs/source/en/api/image_processor.md
index 3e75af026d..82d1837b0b 100644
--- a/docs/source/en/api/image_processor.md
+++ b/docs/source/en/api/image_processor.md
@@ -20,6 +20,12 @@ All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or Nu
[[autodoc]] image_processor.VaeImageProcessor
+## InpaintProcessor
+
+The [`InpaintProcessor`] accepts `mask` and `image` inputs and processes them together. Optionally, it can accept `padding_mask_crop` and apply a mask overlay.
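+
+The snippet below is a minimal usage sketch; `image`, `mask`, and `denoised` are placeholders for a PIL input image, a PIL mask, and the decoded image tensor returned by an inpainting pipeline:
+
+```py
+from diffusers.image_processor import InpaintProcessor
+
+processor = InpaintProcessor(vae_scale_factor=8)
+
+# Preprocess the image and mask together; the returned kwargs are later passed to `postprocess`.
+processed_image, processed_mask, postprocess_kwargs = processor.preprocess(
+    image, mask=mask, height=1024, width=1024, padding_mask_crop=32
+)
+
+# ... run the inpainting pipeline on `processed_image` / `processed_mask` ...
+
+# Postprocess and paste the inpainted region back onto the original image.
+output_images = processor.postprocess(denoised, output_type="pil", **postprocess_kwargs)
+```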
+
+[[autodoc]] image_processor.InpaintProcessor
+
## VaeImageProcessorLDM3D
The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.
diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
index 87e0d2c29e..03c05a05e0 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -29,8 +29,9 @@ from pathlib import Path
import numpy as np
import torch
import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
@@ -1222,6 +1223,9 @@ def main(args):
kwargs_handlers=[kwargs],
)
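+    # DeepSpeed reads the per-device micro batch size from its own config, so keep it in sync with args.train_batch_size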
+ if accelerator.distributed_type == DistributedType.DEEPSPEED:
+ AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
@@ -1438,17 +1442,20 @@ def main(args):
text_encoder_one_lora_layers_to_save = None
modules_to_save = {}
for model in models:
- if isinstance(model, type(unwrap_model(transformer))):
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ model = unwrap_model(model)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
- elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ model = unwrap_model(model)
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["text_encoder"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
- weights.pop()
+ if weights:
+ weights.pop()
FluxKontextPipeline.save_lora_weights(
output_dir,
@@ -1461,15 +1468,25 @@ def main(args):
transformer_ = None
text_encoder_one_ = None
- while len(models) > 0:
- model = models.pop()
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_ = model
- elif isinstance(model, type(unwrap_model(text_encoder_one))):
- text_encoder_one_ = model
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ transformer_ = unwrap_model(model)
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = unwrap_model(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ else:
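+                # With DeepSpeed, rebuild the base transformer and text encoder (re-attaching the LoRA adapter to the transformer) instead of taking them from `models`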
+ transformer_ = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer"
+ )
+ transformer_.add_adapter(transformer_lora_config)
+ text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder"
+ )
lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
@@ -2069,7 +2086,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
- if accelerator.is_main_process:
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index fa5dd6482c..d96acc3818 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -385,6 +385,10 @@ else:
[
"FluxAutoBlocks",
"FluxModularPipeline",
+ "QwenImageAutoBlocks",
+ "QwenImageEditAutoBlocks",
+ "QwenImageEditModularPipeline",
+ "QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"WanAutoBlocks",
@@ -506,6 +510,7 @@ else:
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
+ "QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
"QwenImageEditPipeline",
@@ -1038,6 +1043,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modular_pipelines import (
FluxAutoBlocks,
FluxModularPipeline,
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditModularPipeline,
+ QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
WanAutoBlocks,
@@ -1155,6 +1164,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
+ QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
index b7a74be2e5..f6e5bdd52d 100644
--- a/src/diffusers/hooks/_helpers.py
+++ b/src/diffusers/hooks/_helpers.py
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
+ from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
# AttnProcessor2_0
@@ -140,6 +141,14 @@ def _register_attention_processors_metadata():
metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
)
+    # QwenDoubleStreamAttnProcessor2_0
+ AttentionProcessorRegistry.register(
+ model_class=QwenDoubleStreamAttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
+ ),
+ )
+
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -298,4 +307,5 @@ _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___h
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
# fmt: on
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 6a3cf77a7d..0e3082eada 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -523,6 +523,7 @@ class VaeImageProcessor(ConfigMixin):
size=(height, width),
)
image = self.pt_to_numpy(image)
+
return image
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
@@ -838,6 +839,137 @@ class VaeImageProcessor(ConfigMixin):
return image
+class InpaintProcessor(ConfigMixin):
+ """
+ Image processor for inpainting image and mask.
+ """
+
+ config_name = CONFIG_NAME
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 8,
+ vae_latent_channels: int = 4,
+ resample: str = "lanczos",
+        reducing_gap: Optional[int] = None,
+ do_normalize: bool = True,
+ do_binarize: bool = False,
+ do_convert_grayscale: bool = False,
+ mask_do_normalize: bool = False,
+ mask_do_binarize: bool = True,
+ mask_do_convert_grayscale: bool = True,
+ ):
+ super().__init__()
+
+ self._image_processor = VaeImageProcessor(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ resample=resample,
+ reducing_gap=reducing_gap,
+ do_normalize=do_normalize,
+ do_binarize=do_binarize,
+ do_convert_grayscale=do_convert_grayscale,
+ )
+ self._mask_processor = VaeImageProcessor(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ resample=resample,
+ reducing_gap=reducing_gap,
+ do_normalize=mask_do_normalize,
+ do_binarize=mask_do_binarize,
+ do_convert_grayscale=mask_do_convert_grayscale,
+ )
+
+ def preprocess(
+ self,
+ image: PIL.Image.Image,
+        mask: Optional[PIL.Image.Image] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, dict]]:
+ """
+ Preprocess the image and mask.
+ """
+ if mask is None and padding_mask_crop is not None:
+ raise ValueError("mask must be provided if padding_mask_crop is provided")
+
+ # if mask is None, same behavior as regular image processor
+ if mask is None:
+ return self._image_processor.preprocess(image, height=height, width=width)
+
+ if padding_mask_crop is not None:
+ crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ processed_image = self._image_processor.preprocess(
+ image,
+ height=height,
+ width=width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+
+ processed_mask = self._mask_processor.preprocess(
+ mask,
+ height=height,
+ width=width,
+ resize_mode=resize_mode,
+ crops_coords=crops_coords,
+ )
+
+ if crops_coords is not None:
+ postprocessing_kwargs = {
+ "crops_coords": crops_coords,
+ "original_image": image,
+ "original_mask": mask,
+ }
+ else:
+ postprocessing_kwargs = {
+ "crops_coords": None,
+ "original_image": None,
+ "original_mask": None,
+ }
+
+ return processed_image, processed_mask, postprocessing_kwargs
+
+ def postprocess(
+ self,
+ image: torch.Tensor,
+ output_type: str = "pil",
+ original_image: Optional[PIL.Image.Image] = None,
+ original_mask: Optional[PIL.Image.Image] = None,
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> Union[List[PIL.Image.Image], np.ndarray, torch.Tensor]:
+ """
+        Postprocess the image and optionally apply the mask overlay.
+ """
+ image = self._image_processor.postprocess(
+ image,
+ output_type=output_type,
+ )
+ # optionally apply the mask overlay
+ if crops_coords is not None and (original_image is None or original_mask is None):
+ raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
+
+ elif crops_coords is not None and output_type != "pil":
+ raise ValueError("output_type must be 'pil' if crops_coords is provided")
+
+ elif crops_coords is not None:
+ image = [
+ self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
+ ]
+
+ return image
+
+
class VaeImageProcessorLDM3D(VaeImageProcessor):
"""
Image processor for VAE LDM3D.
diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py
index 68d707f9e0..65c22b349b 100644
--- a/src/diffusers/modular_pipelines/__init__.py
+++ b/src/diffusers/modular_pipelines/__init__.py
@@ -47,6 +47,12 @@ else:
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
+ _import_structure["qwenimage"] = [
+ "QwenImageAutoBlocks",
+ "QwenImageModularPipeline",
+ "QwenImageEditModularPipeline",
+ "QwenImageEditAutoBlocks",
+ ]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -68,6 +74,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
+ from .qwenimage import (
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditModularPipeline,
+ QwenImageModularPipeline,
+ )
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import WanAutoBlocks, WanModularPipeline
else:
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index c0524a1f86..78226a49b1 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -56,6 +56,8 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
+ ("qwenimage", "QwenImageModularPipeline"),
+ ("qwenimage-edit", "QwenImageEditModularPipeline"),
]
)
@@ -64,6 +66,8 @@ MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
("WanModularPipeline", "WanAutoBlocks"),
("FluxModularPipeline", "FluxAutoBlocks"),
+ ("QwenImageModularPipeline", "QwenImageAutoBlocks"),
+ ("QwenImageEditModularPipeline", "QwenImageEditAutoBlocks"),
]
)
@@ -133,8 +137,8 @@ class PipelineState:
Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
intermediates dict.
"""
- if name in self.intermediates:
- return self.intermediates[name]
+ if name in self.values:
+ return self.values[name]
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __repr__(self):
@@ -548,8 +552,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
def __init__(self):
sub_blocks = InsertableDict()
- for block_name, block_cls in zip(self.block_names, self.block_classes):
- sub_blocks[block_name] = block_cls()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(
@@ -830,7 +837,9 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
return expected_configs
@classmethod
- def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks":
+ def from_blocks_dict(
+ cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
+ ) -> "SequentialPipelineBlocks":
"""Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
Args:
@@ -852,12 +861,19 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
instance.block_classes = [block.__class__ for block in sub_blocks.values()]
instance.block_names = list(sub_blocks.keys())
instance.sub_blocks = sub_blocks
+
+ if description is not None:
+ instance.description = description
+
return instance
def __init__(self):
sub_blocks = InsertableDict()
- for block_name, block_cls in zip(self.block_names, self.block_classes):
- sub_blocks[block_name] = block_cls()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
def _get_inputs(self):
@@ -1280,8 +1296,11 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def __init__(self):
sub_blocks = InsertableDict()
- for block_name, block_cls in zip(self.block_names, self.block_classes):
- sub_blocks[block_name] = block_cls()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
@classmethod
diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py
new file mode 100644
index 0000000000..81cf515730
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py
@@ -0,0 +1,75 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["encoders"] = ["QwenImageTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "AUTO_BLOCKS",
+ "CONTROLNET_BLOCKS",
+ "EDIT_AUTO_BLOCKS",
+ "EDIT_BLOCKS",
+ "EDIT_INPAINT_BLOCKS",
+ "IMAGE2IMAGE_BLOCKS",
+ "INPAINT_BLOCKS",
+ "TEXT2IMAGE_BLOCKS",
+ "QwenImageAutoBlocks",
+ "QwenImageEditAutoBlocks",
+ ]
+ _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .encoders import (
+ QwenImageTextEncoderStep,
+ )
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ AUTO_BLOCKS,
+ CONTROLNET_BLOCKS,
+ EDIT_AUTO_BLOCKS,
+ EDIT_BLOCKS,
+ EDIT_INPAINT_BLOCKS,
+ IMAGE2IMAGE_BLOCKS,
+ INPAINT_BLOCKS,
+ TEXT2IMAGE_BLOCKS,
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ )
+ from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
new file mode 100644
index 0000000000..738a1e5d15
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -0,0 +1,727 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils.torch_utils import randn_tensor, unwrap_module
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+def get_timesteps(scheduler, num_inference_steps, strength):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
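+    # e.g. num_inference_steps=50, strength=0.9 -> init_timestep=45, t_start=5,
+    # so the first 5 (t_start * scheduler.order) scheduled timesteps are skipped and 45 steps remain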
+ timesteps = scheduler.timesteps[t_start * scheduler.order :]
+ if hasattr(scheduler, "set_begin_index"):
+ scheduler.set_begin_index(t_start * scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+
+# Prepare Latents steps
+
+
+class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Prepare initial random noise for the generation process"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="generator"),
+ InputParam(
+ name="batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ name="dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of the model inputs, can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="latents",
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ height=block_state.height,
+ width=block_state.width,
+ vae_scale_factor=components.vae_scale_factor,
+ )
+
+ device = components._execution_device
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+
+        # we can update the height and width here since they are used to generate the initial latents
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+
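+        # the latent shape has a singleton frame axis (the VAE is video-style); the decoder later selects frame 0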
+ shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width)
+ if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ block_state.latents = randn_tensor(
+ shape, generator=block_state.generator, device=device, dtype=block_state.dtype
+ )
+ block_state.latents = components.pachifier.pack_latents(block_state.latents)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial random noised, can be generated in prepare latent step.",
+ ),
+ InputParam(
+ name="image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
+ ),
+ InputParam(
+ name="timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="initial_noise",
+ type_hint=torch.Tensor,
+ description="The initial random noised used for inpainting denoising.",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(image_latents, latents):
+ if image_latents.shape[0] != latents.shape[0]:
+ raise ValueError(
+ f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
+ )
+
+ if image_latents.ndim != 3:
+ raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ image_latents=block_state.image_latents,
+ latents=block_state.latents,
+ )
+
+ # prepare latent timestep
+ latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
+
+ # make copy of initial_noise
+ block_state.initial_noise = block_state.latents
+
+ # scale noise
+ block_state.latents = components.scheduler.scale_noise(
+ block_state.image_latents, latent_timestep, block_state.latents
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name="processed_mask_image",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The processed mask to use for the inpainting process.",
+ ),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="dtype", required=True),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+
+ height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+
+ block_state.mask = torch.nn.functional.interpolate(
+ block_state.processed_mask_image,
+ size=(height_latents, width_latents),
+ )
+
+ block_state.mask = block_state.mask.unsqueeze(2)
+ block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1)
+ block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype)
+
+ block_state.mask = components.pachifier.pack_latents(block_state.mask)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# Set Timesteps steps
+
+
+class QwenImageSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_inference_steps", default=50),
+ InputParam(name="sigmas"),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process, used to calculate the image sequence length.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ sigmas = (
+ np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
+ if block_state.sigmas is None
+ else block_state.sigmas
+ )
+
+ mu = calculate_shift(
+ image_seq_len=block_state.latents.shape[1],
+ base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
+ max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
+ base_shift=components.scheduler.config.get("base_shift", 0.5),
+ max_shift=components.scheduler.config.get("max_shift", 1.15),
+ )
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ device=device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ components.scheduler.set_begin_index(0)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_inference_steps", default=50),
+ InputParam(name="sigmas"),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process, used to calculate the image sequence length.",
+ ),
+ InputParam(name="strength", default=0.9),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="timesteps",
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ sigmas = (
+ np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
+ if block_state.sigmas is None
+ else block_state.sigmas
+ )
+
+ mu = calculate_shift(
+ image_seq_len=block_state.latents.shape[1],
+ base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
+ max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
+ base_shift=components.scheduler.config.get("base_shift", 0.5),
+ max_shift=components.scheduler.config.get("max_shift", 1.15),
+ )
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ device=device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ block_state.timesteps, block_state.num_inference_steps = get_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ strength=block_state.strength,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# other inputs for denoiser
+
+## RoPE inputs for denoiser
+
+
+class QwenImageRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds_mask"),
+ InputParam(name="negative_prompt_embeds_mask"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="img_shapes",
+ type_hint=List[List[Tuple[int, int, int]]],
+ description="The shapes of the images latents, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="negative_txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.img_shapes = [
+ [
+ (
+ 1,
+ block_state.height // components.vae_scale_factor // 2,
+ block_state.width // components.vae_scale_factor // 2,
+ )
+ ]
+ * block_state.batch_size
+ ]
+ block_state.txt_seq_lens = (
+ block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+ )
+ block_state.negative_txt_seq_lens = (
+ block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+ if block_state.negative_prompt_embeds_mask is not None
+ else None
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="batch_size", required=True),
+ InputParam(
+ name="resized_image", required=True, type_hint=torch.Tensor, description="The resized image input"
+ ),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds_mask"),
+ InputParam(name="negative_prompt_embeds_mask"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="img_shapes",
+ type_hint=List[List[Tuple[int, int, int]]],
+ description="The shapes of the images latents, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="negative_txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # for edit, image size can be different from the target size (height/width)
+ image = (
+ block_state.resized_image[0] if isinstance(block_state.resized_image, list) else block_state.resized_image
+ )
+ image_width, image_height = image.size
+
+ block_state.img_shapes = [
+ [
+ (
+ 1,
+ block_state.height // components.vae_scale_factor // 2,
+ block_state.width // components.vae_scale_factor // 2,
+ ),
+ (1, image_height // components.vae_scale_factor // 2, image_width // components.vae_scale_factor // 2),
+ ]
+ ] * block_state.batch_size
+
+ block_state.txt_seq_lens = (
+ block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+ )
+ block_state.negative_txt_seq_lens = (
+ block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+ if block_state.negative_prompt_embeds_mask is not None
+ else None
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+## ControlNet inputs for denoiser
+class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("control_guidance_start", default=0.0),
+ InputParam("control_guidance_end", default=1.0),
+ InputParam("controlnet_conditioning_scale", default=1.0),
+ InputParam("control_image_latents", required=True),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ controlnet = unwrap_module(components.controlnet)
+
+ # control_guidance_start/control_guidance_end (align format)
+ if not isinstance(block_state.control_guidance_start, list) and isinstance(
+ block_state.control_guidance_end, list
+ ):
+ block_state.control_guidance_start = len(block_state.control_guidance_end) * [
+ block_state.control_guidance_start
+ ]
+ elif not isinstance(block_state.control_guidance_end, list) and isinstance(
+ block_state.control_guidance_start, list
+ ):
+ block_state.control_guidance_end = len(block_state.control_guidance_start) * [
+ block_state.control_guidance_end
+ ]
+ elif not isinstance(block_state.control_guidance_start, list) and not isinstance(
+ block_state.control_guidance_end, list
+ ):
+ mult = (
+ len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1
+ )
+ block_state.control_guidance_start, block_state.control_guidance_end = (
+ mult * [block_state.control_guidance_start],
+ mult * [block_state.control_guidance_end],
+ )
+
+ # controlnet_conditioning_scale (align format)
+ if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance(
+ block_state.controlnet_conditioning_scale, float
+ ):
+ block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult
+
+ # controlnet_keep
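+        # controlnet_keep[i] is 1.0 while step i lies inside the [control_guidance_start, control_guidance_end] window, 0.0 otherwise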
+ block_state.controlnet_keep = []
+ for i in range(len(block_state.timesteps)):
+ keeps = [
+ 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
+ for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
+ ]
+ block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py
new file mode 100644
index 0000000000..6c82fe989e
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py
@@ -0,0 +1,203 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import InpaintProcessor, VaeImageProcessor
+from ...models import AutoencoderKLQwenImage
+from ...utils import logging
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+logger = logging.get_logger(__name__)
+
+
+class QwenImageDecoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the latents to images"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to decode, can be generated in the denoise step",
+ ),
+ ]
+
+ @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "images",
+                type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
+ description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+        # YiYi Notes: remove support for output_type="latents", we can just skip the decode/encode step in modular
+ block_state.latents = components.pachifier.unpack_latents(
+ block_state.latents, block_state.height, block_state.width
+ )
+ block_state.latents = block_state.latents.to(components.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(components.vae.config.latents_mean)
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
+ .to(block_state.latents.device, block_state.latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+ 1, components.vae.config.z_dim, 1, 1, 1
+ ).to(block_state.latents.device, block_state.latents.dtype)
+ block_state.latents = block_state.latents / latents_std + latents_mean
+ block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0]
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "postprocess the generated image"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("images", required=True, description="the generated image from decoders step"),
+ InputParam(
+ name="output_type",
+ default="pil",
+ type_hint=str,
+ description="The type of the output images, can be 'pil', 'np', 'pt'",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(output_type):
+ if output_type not in ["pil", "np", "pt"]:
+ raise ValueError(f"Invalid output_type: {output_type}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.output_type)
+
+ block_state.images = components.image_processor.postprocess(
+ image=block_state.images,
+ output_type=block_state.output_type,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "postprocess the generated image, optional apply the mask overally to the original image.."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_mask_processor",
+ InpaintProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("images", required=True, description="the generated image from decoders step"),
+ InputParam(
+ name="output_type",
+ default="pil",
+ type_hint=str,
+ description="The type of the output images, can be 'pil', 'np', 'pt'",
+ ),
+ InputParam("mask_overlay_kwargs"),
+ ]
+
+ @staticmethod
+ def check_inputs(output_type, mask_overlay_kwargs):
+ if output_type not in ["pil", "np", "pt"]:
+ raise ValueError(f"Invalid output_type: {output_type}")
+
+ if mask_overlay_kwargs and output_type != "pil":
+ raise ValueError("only support output_type 'pil' for mask overlay")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs)
+
+ if block_state.mask_overlay_kwargs is None:
+ mask_overlay_kwargs = {}
+ else:
+ mask_overlay_kwargs = block_state.mask_overlay_kwargs
+
+ block_state.images = components.image_mask_processor.postprocess(
+ image=block_state.images,
+ **mask_overlay_kwargs,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py
new file mode 100644
index 0000000000..d0704ee6e0
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py
@@ -0,0 +1,668 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple
+
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging
+from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # one timestep
+ block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
+ block_state.latent_model_input = block_state.latents
+ return components, block_state
+
+
+class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # one timestep
+
+ block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1)
+ block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
+ return components, block_state
+
+
+class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that runs the controlnet before the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "control_image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "controlnet_conditioning_scale",
+ type_hint=float,
+ description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "controlnet_keep",
+ required=True,
+ type_hint=List[float],
+ description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description=(
+ "All conditional model inputs for the denoiser. "
+ "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
+ ),
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int):
+ # cond_scale for the timestep (controlnet input)
+ if isinstance(block_state.controlnet_keep[i], list):
+ block_state.cond_scale = [
+ c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])
+ ]
+ else:
+ controlnet_cond_scale = block_state.controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
+
+ # run controlnet for the guidance batch
+ controlnet_block_samples = components.controlnet(
+ hidden_states=block_state.latent_model_input,
+ controlnet_cond=block_state.control_image_latents,
+ conditioning_scale=block_state.cond_scale,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ encoder_hidden_states=block_state.prompt_embeds,
+ encoder_hidden_states_mask=block_state.prompt_embeds_mask,
+ txt_seq_lens=block_state.txt_seq_lens,
+ return_dict=False,
+ )
+
+ block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples
+
+ return components, block_state
+
+
+class QwenImageLoopDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that denoise the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", QwenImageTransformer2DModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
+ ),
+ InputParam(
+ "img_shapes",
+ required=True,
+ type_hint=List[Tuple[int, int]],
+ description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ guider_input_fields = {
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
+ "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
+ }
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+ guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.transformer)
+ cond_kwargs = guider_state_batch.as_dict()
+ cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+
+ # YiYi TODO: add cache context
+ guider_state_batch.noise_pred = components.transformer(
+ hidden_states=block_state.latent_model_input,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ **block_state.additional_cond_kwargs,
+ )[0]
+
+ components.guider.cleanup_models(components.transformer)
+
+ guider_output = components.guider(guider_state)
+
+ # apply guidance rescale
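+        # (the combined guider prediction is rescaled so its per-token norm over the channel dimension
+        # matches the norm of the conditional prediction)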
+ pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True)
+ pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True)
+ block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm)
+
+ return components, block_state
+
+
+class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that denoise the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", QwenImageTransformer2DModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
+ ),
+ InputParam(
+ "img_shapes",
+ required=True,
+ type_hint=List[Tuple[int, int]],
+ description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ guider_input_fields = {
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
+ "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
+ }
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+ guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.transformer)
+ cond_kwargs = guider_state_batch.as_dict()
+ cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+
+ # YiYi TODO: add cache context
+ guider_state_batch.noise_pred = components.transformer(
+ hidden_states=block_state.latent_model_input,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ **block_state.additional_cond_kwargs,
+ )[0]
+
+ components.guider.cleanup_models(components.transformer)
+
+ guider_output = components.guider(guider_state)
+
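+        # the Edit denoiser runs on `latents` concatenated with the reference image latents along the
+        # sequence dimension, so keep only the first `latents.size(1)` tokens (the generated-image portion)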
+ pred = guider_output.pred[:, : block_state.latents.size(1)]
+ pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)]
+
+ # apply guidance rescale
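+        # (the combined guider prediction is rescaled so its per-token norm over the channel dimension
+        # matches the norm of the conditional prediction)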
+ pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True)
+ pred_norm = torch.norm(pred, dim=-1, keepdim=True)
+ block_state.noise_pred = pred * (pred_cond_norm / pred_norm)
+
+ return components, block_state
+
+
+class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that updates the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ block_state.latents = block_state.latents.to(latents_dtype)
+
+ return components, block_state
+
+
+class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that updates the latents using mask and image_latents for inpainting. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "mask",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "initial_noise",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ block_state.init_latents_proper = block_state.image_latents
+ if i < len(block_state.timesteps) - 1:
+ block_state.noise_timestep = block_state.timesteps[i + 1]
+ block_state.init_latents_proper = components.scheduler.scale_noise(
+ block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise
+ )
+
+ block_state.latents = (
+ 1 - block_state.mask
+ ) * block_state.init_latents_proper + block_state.mask * block_state.latents
+
+ return components, block_state
+
+
+class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def loop_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+
+ block_state.additional_cond_kwargs = {}
+
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# composing the denoising loops
+class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports text2image and image2image tasks for QwenImage."
+ )
+
+
+# composing the inpainting denoising loops
+class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks for QwenImage."
+ )
+
+
+# composing the controlnet denoising loops
+class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopBeforeDenoiserControlNet,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopBeforeDenoiserControlNet`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports text2img/img2img tasks with controlnet for QwenImage."
+ )
+
+
+# composing the controlnet denoising loops
+class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopBeforeDenoiserControlNet,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = [
+ "before_denoiser",
+ "before_denoiser_controlnet",
+ "denoiser",
+ "after_denoiser",
+ "after_denoiser_inpaint",
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopBeforeDenoiserControlNet`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks with controlnet for QwenImage."
+ )
+
+
+# composing the denoising loops
+class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageEditLoopBeforeDenoiser,
+ QwenImageEditLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageEditLoopBeforeDenoiser`\n"
+ " - `QwenImageEditLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports QwenImage Edit."
+ )
+
+
+class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageEditLoopBeforeDenoiser,
+ QwenImageEditLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageEditLoopBeforeDenoiser`\n"
+ " - `QwenImageEditLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks for QwenImage Edit."
+ )
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
new file mode 100644
index 0000000000..280fa6a152
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -0,0 +1,857 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, Union
+
+import PIL
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist
+from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
+from ...utils import logging
+from ...utils.torch_utils import unwrap_module
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
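+    """Split the padded `hidden_states` into per-sample tensors that keep only the valid (unmasked) tokens.
+
+    Illustrative example: for `hidden_states` of shape [2, 4, 8] and `mask` [[1, 1, 0, 0], [1, 1, 1, 0]],
+    the result is a tuple of two tensors with shapes [2, 8] and [3, 8].
+    """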
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+ return split_result
+
+
+def get_qwen_prompt_embeds(
+ text_encoder,
+ tokenizer,
+ prompt: Union[str, List[str]] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ prompt_template_encode_start_idx: int = 34,
+ tokenizer_max_length: int = 1024,
+ device: Optional[torch.device] = None,
+):
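+    """Encode `prompt` with the Qwen2.5-VL text encoder.
+
+    Each prompt is wrapped in `prompt_template_encode`, tokenized (up to `tokenizer_max_length` plus the template
+    prefix), and passed through the text encoder. The first `prompt_template_encode_start_idx` tokens of the last
+    hidden state (the template prefix) are dropped, and the remaining per-sample sequences are padded to a common
+    length. Returns a `(prompt_embeds, encoder_attention_mask)` tuple.
+    """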
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = prompt_template_encode
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = tokenizer(
+ txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+
+ split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+
+def get_qwen_prompt_embeds_edit(
+ text_encoder,
+ processor,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+ prompt_template_encode_start_idx: int = 64,
+ device: Optional[torch.device] = None,
+):
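+    """Encode `prompt` together with `image` using the Qwen2.5-VL processor and text encoder.
+
+    Works like `get_qwen_prompt_embeds`, except the multimodal processor fills the template's `<|image_pad|>`
+    placeholder with vision tokens and the encoder additionally receives `pixel_values` and `image_grid_thw`.
+    """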
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = prompt_template_encode
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+
+ model_inputs = processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image
+def encode_vae_image(
+ image: torch.Tensor,
+ vae: AutoencoderKLQwenImage,
+ generator: torch.Generator,
+ device: torch.device,
+ dtype: torch.dtype,
+ latent_channels: int = 16,
+ sample_mode: str = "argmax",
+):
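+    """Encode a preprocessed image tensor into normalized VAE latents.
+
+    4D inputs (batch, channels, height, width) get a singleton frame dimension so they match the video-style VAE
+    input. The resulting latents are normalized with the VAE's configured `latents_mean`/`latents_std`. When a list
+    of generators is passed, each image in the batch is encoded with its own generator.
+    """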
+ if not isinstance(image, torch.Tensor):
+ raise ValueError(f"Expected image to be a tensor, got {type(image)}.")
+
+ # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ image = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
+ latents_mean = (
+ torch.tensor(vae.config.latents_mean)
+ .view(1, latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(vae.config.latents_std)
+ .view(1, latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ image_latents = (image_latents - latents_mean) / latents_std
+
+ return image_latents
+
+
+class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
+ """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
+
+ This block resizes an input image tensor and exposes the resized result under configurable input and output
+ names. Use this when you need to wire the resize step to different image fields (e.g., "image",
+ "control_image")
+
+ Args:
+ input_name (str, optional): Name of the image field to read from the
+ pipeline state. Defaults to "image".
+ output_name (str, optional): Name of the resized image field to write
+ back to the pipeline state. Defaults to "resized_image".
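+
+        Examples:
+            # default wiring: reads "image", writes "resized_image"
+            QwenImageEditResizeDynamicStep()
+
+            # wiring for a control image field (the output name here is only illustrative)
+            QwenImageEditResizeDynamicStep(input_name="control_image", output_name="resized_control_image")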
+ """
+ if not isinstance(input_name, str) or not isinstance(output_name, str):
+ raise ValueError(
+ f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+ )
+ self._image_input_name = input_name
+ self._resized_image_output_name = output_name
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_resize_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ images = getattr(block_state, self._image_input_name)
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if is_valid_image(images):
+ images = [images]
+
+ image_width, image_height = images[0].size
+ calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
+
+ resized_images = [
+ components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
+ for image in images
+ ]
+
+ setattr(block_state, self._resized_image_output_name, resized_images)
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageTextEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generate text_embeddings to guide the image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"),
+ ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=34),
+ ConfigSpec(name="tokenizer_max_length", default=1024),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+ InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+ InputParam(
+ name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The prompt embeddings",
+ ),
+ OutputParam(
+ name="prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The encoder attention mask",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings mask",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(prompt, negative_prompt, max_sequence_length):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ negative_prompt is not None
+ and not isinstance(negative_prompt, str)
+ and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
+ components.text_encoder,
+ components.tokenizer,
+ prompt=block_state.prompt,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ tokenizer_max_length=components.config.tokenizer_max_length,
+ device=device,
+ )
+
+ block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]
+
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or ""
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
+ components.text_encoder,
+ components.tokenizer,
+ prompt=negative_prompt,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ tokenizer_max_length=components.config.tokenizer_max_length,
+ device=device,
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[
+ :, : block_state.max_sequence_length
+ ]
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[
+ :, : block_state.max_sequence_length
+ ]
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
+ ComponentSpec("processor", Qwen2VLProcessor),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+ InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+ InputParam(
+ name="resized_image",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image prompt to encode, should be resized using resize step",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The prompt embeddings",
+ ),
+ OutputParam(
+ name="prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The encoder attention mask",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings mask",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(prompt, negative_prompt):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ negative_prompt is not None
+ and not isinstance(negative_prompt, str)
+ and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+ device = components._execution_device
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit(
+ components.text_encoder,
+ components.processor,
+ prompt=block_state.prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or ""
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
+ components.text_encoder,
+ components.processor,
+ prompt=negative_prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_mask_processor",
+ InpaintProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("mask_image", required=True),
+ InputParam("resized_image"),
+ InputParam("image"),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("padding_mask_crop"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="processed_image"),
+ OutputParam(name="processed_mask_image"),
+ OutputParam(
+ name="mask_overlay_kwargs",
+ type_hint=Dict,
+ description="The kwargs for the postprocess step to apply the mask overlay",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("resized_image and image cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
+ components.image_mask_processor.preprocess(
+ image=image,
+ mask=block_state.mask_image,
+ height=height,
+ width=width,
+ padding_mask_crop=block_state.padding_mask_crop,
+ )
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("resized_image"),
+ InputParam("image"),
+ InputParam("height"),
+ InputParam("width"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="processed_image"),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("resized_image and image cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image,
+ height=height,
+ width=width,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ input_name: str = "processed_image",
+ output_name: str = "image_latents",
+ ):
+ """Initialize a VAE encoder step for converting images to latent representations.
+
+        Both the input and output names are configurable, so this block can be reused for different image
+        inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
+
+ Args:
+ input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
+ Examples: "processed_image" or "processed_control_image"
+ output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
+ Examples: "image_latents" or "control_image_latents"
+
+ Examples:
+            # Basic usage with default settings
+            QwenImageVaeEncoderDynamicStep()
+
+            # Custom input/output names for a control image
+            QwenImageVaeEncoderDynamicStep(
+ input_name="processed_control_image", output_name="control_image_latents"
+ )
+ """
+ self._image_input_name = input_name
+ self._image_latents_output_name = output_name
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(self._image_input_name, required=True),
+ InputParam("generator"),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ self._image_latents_output_name,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ dtype = components.vae.dtype
+
+ image = getattr(block_state, self._image_input_name)
+
+ # Encode image into latents
+ image_latents = encode_vae_image(
+ image=image,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ )
+
+ setattr(block_state, self._image_latents_output_name, image_latents)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ComponentSpec(
+ "control_image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam("control_image", required=True),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("generator"),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "control_image_latents",
+ type_hint=torch.Tensor,
+ description="The latents representing the control image",
+ )
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor)
+
+ device = components._execution_device
+ dtype = components.vae.dtype
+
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+
+ controlnet = unwrap_module(components.controlnet)
+ if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list):
+ block_state.control_image = [block_state.control_image]
+
+ if isinstance(controlnet, QwenImageMultiControlNetModel):
+ block_state.control_image_latents = []
+ for control_image_ in block_state.control_image:
+ control_image_ = components.control_image_processor.preprocess(
+ image=control_image_,
+ height=height,
+ width=width,
+ )
+
+ control_image_latents_ = encode_vae_image(
+ image=control_image_,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ sample_mode="sample",
+ )
+ block_state.control_image_latents.append(control_image_latents_)
+
+ elif isinstance(controlnet, QwenImageControlNetModel):
+ control_image = components.control_image_processor.preprocess(
+ image=block_state.control_image,
+ height=height,
+ width=width,
+ )
+ block_state.control_image_latents = encode_vae_image(
+ image=control_image,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ sample_mode="sample",
+ )
+
+ else:
+ raise ValueError(
+ f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}"
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py
new file mode 100644
index 0000000000..2b787c8238
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -0,0 +1,431 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple
+
+import torch
+
+from ...models import QwenImageMultiControlNetModel
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+def repeat_tensor_to_batch_size(
+ input_name: str,
+ input_tensor: torch.Tensor,
+ batch_size: int,
+ num_images_per_prompt: int = 1,
+) -> torch.Tensor:
+ """Repeat tensor elements to match the final batch size.
+
+ This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
+ by repeating each element along dimension 0.
+
+ The input tensor must have batch size 1 or batch_size. The function will:
+ - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
+ - If batch size equals batch_size: repeat each element num_images_per_prompt times
+
+ Args:
+ input_name (str): Name of the input tensor (used for error messages)
+ input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
+ batch_size (int): The base batch size (number of prompts)
+ num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
+
+ Returns:
+ torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
+
+ Raises:
+ ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
+
+ Examples:
+        tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
+        repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
+        repeated  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: [4, 3]
+
+        tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
+        repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
+        repeated  # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) - shape: [4, 3]
+ """
+ # make sure input is a tensor
+ if not isinstance(input_tensor, torch.Tensor):
+ raise ValueError(f"`{input_name}` must be a tensor")
+
+    # make sure the input tensor (e.g. image_latents) has batch size 1 or the same batch size as the prompts
+ if input_tensor.shape[0] == 1:
+ repeat_by = batch_size * num_images_per_prompt
+ elif input_tensor.shape[0] == batch_size:
+ repeat_by = num_images_per_prompt
+ else:
+ raise ValueError(
+ f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
+ )
+
+ # expand the tensor to match the batch_size * num_images_per_prompt
+ input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
+
+ return input_tensor
+
+
+def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]:
+ """Calculate image dimensions from latent tensor dimensions.
+
+ This function converts latent space dimensions to image space dimensions by multiplying the latent height and width
+ by the VAE scale factor.
+
+ Args:
+ latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
+ Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
+ vae_scale_factor (int): The scale factor used by the VAE to compress images.
+ Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
+
+ Returns:
+ Tuple[int, int]: The calculated image dimensions as (height, width)
+
+ Raises:
+ ValueError: If latents tensor doesn't have 4 or 5 dimensions
+
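+    Examples:
+        # illustrative: a 64x64 unpacked latent with a VAE scale factor of 8 maps to a 512x512 image
+        latents = torch.randn(1, 16, 64, 64)
+        height, width = calculate_dimension_from_latents(latents, vae_scale_factor=8)
+        # height, width == (512, 512)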
+ """
+ # make sure the latents are not packed
+ if latents.ndim != 4 and latents.ndim != 5:
+ raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")
+
+ latent_height, latent_width = latents.shape[-2:]
+
+ height = latent_height * vae_scale_factor
+ width = latent_width * vae_scale_factor
+
+ return height, width
+
+
+class QwenImageTextInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ summary_section = (
+ "Text input processing step that standardizes text embeddings for the pipeline.\n"
+ "This step:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
+ )
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps."
+
+ return summary_section + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
+ InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
+ InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
+ InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
+ ]
+
+ @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(
+ prompt_embeds,
+ prompt_embeds_mask,
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ ):
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")
+
+ if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None:
+ raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`")
+
+ if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]:
+ raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
+
+ elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]:
+ raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`")
+
+ elif (
+ negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]
+ ):
+ raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ prompt_embeds=block_state.prompt_embeds,
+ prompt_embeds_mask=block_state.prompt_embeds_mask,
+ negative_prompt_embeds=block_state.negative_prompt_embeds,
+ negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
+ )
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len
+ )
+
+ if block_state.negative_prompt_embeds is not None:
+ _, seq_len, _ = block_state.negative_prompt_embeds.shape
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageInputsDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+ """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
+
+ This step handles multiple common tasks to prepare inputs for the denoising step:
+ 1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
+ 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
+
+ This is a dynamic block that allows you to configure which inputs to process.
+
+ Args:
+ image_latent_inputs (List[str], optional): Names of image latent tensors to process.
+ These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
+ list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
+ additional_batch_inputs (List[str], optional):
+ Names of additional conditional input tensors to expand batch size. These tensors will only have their
+ batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
+ Defaults to []. Examples: ["processed_mask_image"]
+
+ Examples:
+            # Configure to process image_latents (default behavior)
+            QwenImageInputsDynamicStep()
+
+ # Configure to process multiple image latent inputs
+ QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
+
+            # Configure to process image latents and additional batch inputs
+            QwenImageInputsDynamicStep(
+ image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+ )
+ """
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ # Functionality section
+ summary_section = (
+ "Input processing step that:\n"
+ " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
+ )
+
+ # Inputs info
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ # Add image latent inputs
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ # Add additional batch inputs
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ # 2. Patchify the image latent tensor
+ image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageControlNetInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+        return "Prepare the `control_image_latents` input for controlnet. Insert after all the other input steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="control_image_latents", required=True),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if isinstance(components.controlnet, QwenImageMultiControlNetModel):
+ control_image_latents = []
+ # loop through each control_image_latents
+ for i, control_image_latents_ in enumerate(block_state.control_image_latents):
+ # 1. update height/width if not provided
+ height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ # 2. pack
+ control_image_latents_ = components.pachifier.pack_latents(control_image_latents_)
+
+ # 3. repeat to match the batch size
+ control_image_latents_ = repeat_tensor_to_batch_size(
+ input_name=f"control_image_latents[{i}]",
+ input_tensor=control_image_latents_,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ control_image_latents.append(control_image_latents_)
+
+ block_state.control_image_latents = control_image_latents
+
+ else:
+ # 1. update height/width if not provided
+ height, width = calculate_dimension_from_latents(
+ block_state.control_image_latents, components.vae_scale_factor
+ )
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ # 2. pack
+ block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents)
+
+ # 3. repeat to match the batch size
+ block_state.control_image_latents = repeat_tensor_to_batch_size(
+ input_name="control_image_latents",
+ input_tensor=block_state.control_image_latents,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
new file mode 100644
index 0000000000..a01c742fcf
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -0,0 +1,841 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+ QwenImageControlNetBeforeDenoiserStep,
+ QwenImageCreateMaskLatentsStep,
+ QwenImageEditRoPEInputsStep,
+ QwenImagePrepareLatentsStep,
+ QwenImagePrepareLatentsWithStrengthStep,
+ QwenImageRoPEInputsStep,
+ QwenImageSetTimestepsStep,
+ QwenImageSetTimestepsWithStrengthStep,
+)
+from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
+from .denoise import (
+ QwenImageControlNetDenoiseStep,
+ QwenImageDenoiseStep,
+ QwenImageEditDenoiseStep,
+ QwenImageEditInpaintDenoiseStep,
+ QwenImageInpaintControlNetDenoiseStep,
+ QwenImageInpaintDenoiseStep,
+ QwenImageLoopBeforeDenoiserControlNet,
+)
+from .encoders import (
+ QwenImageControlNetVaeEncoderStep,
+ QwenImageEditResizeDynamicStep,
+ QwenImageEditTextEncoderStep,
+ QwenImageInpaintProcessImagesInputStep,
+ QwenImageProcessImagesInputStep,
+ QwenImageTextEncoderStep,
+ QwenImageVaeEncoderDynamicStep,
+)
+from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
+
+
+logger = logging.get_logger(__name__)
+
+# 1. QwenImage
+
+## 1.1 QwenImage/text2image
+
+#### QwenImage/decode
+#### (standard decode step works for most tasks except for inpaint)
+QwenImageDecodeBlocks = InsertableDict(
+ [
+ ("decode", QwenImageDecoderStep()),
+ ("postprocess", QwenImageProcessImagesOutputStep()),
+ ]
+)
+
+
+class QwenImageDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageDecodeBlocks.values()
+ block_names = QwenImageDecodeBlocks.keys()
+
+ @property
+ def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image."
+
+
+#### QwenImage/text2image presets
+TEXT2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("input", QwenImageTextInputsStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
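+# A minimal usage sketch for the preset above. This is illustrative only: `from_blocks_dict`,
+# `init_pipeline`, `load_components`, and the `output="images"` call are assumed to follow the
+# general modular-pipelines API and are not defined in this file.
+#
+#   blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+#   pipe = blocks.init_pipeline("Qwen/Qwen-Image")
+#   pipe.load_components(torch_dtype=torch.bfloat16)
+#   image = pipe(prompt="a cat wearing a hat", output="images")[0]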
+
+## 1.2 QwenImage/inpaint
+
+#### QwenImage/inpaint vae encoder
+QwenImageInpaintVaeEncoderBlocks = InsertableDict(
+ [
+ (
+ "preprocess",
+            QwenImageInpaintProcessImagesInputStep(),
+ ), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintVaeEncoderBlocks.values()
+ block_names = QwenImageInpaintVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step is used for processing image and mask inputs for inpainting tasks. It:\n"
+ " - Resizes the image to the target size, based on `height` and `width`.\n"
+ " - Processes and updates `image` and `mask_image`.\n"
+ " - Creates `image_latents`."
+ )
+
+
+#### QwenImage/inpaint inputs
+QwenImageInpaintInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ (
+ "additional_inputs",
+ QwenImageInputsDynamicStep(
+ image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+ ),
+ ),
+ ]
+)
+
+
+class QwenImageInpaintInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintInputBlocks.values()
+ block_names = QwenImageInpaintInputBlocks.keys()
+
+ @property
+ def description(self):
+        return (
+            "Input step that prepares the inputs for the inpainting denoising step. It:\n"
+            " - ensures the text embeddings and the additional inputs (`image_latents` and `processed_mask_image`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents` and patchifies `image_latents`."
+        )
+
+
+# QwenImage/inpaint prepare latents
+QwenImageInpaintPrepareLatentsBlocks = InsertableDict(
+ [
+ ("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("create_mask_latents", QwenImageCreateMaskLatentsStep()),
+ ]
+)
+
+
+class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintPrepareLatentsBlocks.values()
+ block_names = QwenImageInpaintPrepareLatentsBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
+            " - Adds noise to the image latents to create the latents input for the denoiser.\n"
+            " - Creates the patchified latents `mask` based on the processed mask image.\n"
+ )
+
+
+#### QwenImage/inpaint decode
+QwenImageInpaintDecodeBlocks = InsertableDict(
+ [
+ ("decode", QwenImageDecoderStep()),
+ ("postprocess", QwenImageInpaintProcessImagesOutputStep()),
+ ]
+)
+
+
+class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintDecodeBlocks.values()
+ block_names = QwenImageInpaintDecodeBlocks.keys()
+
+ @property
+ def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the mask overlay to the original image."
+
+
+#### QwenImage/inpaint presets
+INPAINT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageInpaintVaeEncoderStep()),
+ ("input", QwenImageInpaintInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageInpaintDenoiseStep()),
+ ("decode", QwenImageInpaintDecodeStep()),
+ ]
+)
+
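+# A hypothetical call sketch for a pipeline built from INPAINT_BLOCKS (built the same way as the
+# text2image sketch above; the argument names simply mirror the block inputs defined in this file):
+#
+#   image = pipe(
+#       prompt="a red sofa", image=init_image, mask_image=mask, strength=0.9, output="images"
+#   )[0]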
+
+## 1.3 QwenImage/img2img
+
+#### QwenImage/img2img vae encoder
+QwenImageImg2ImgVaeEncoderBlocks = InsertableDict(
+ [
+ ("preprocess", QwenImageProcessImagesInputStep()),
+ ("encode", QwenImageVaeEncoderDynamicStep()),
+ ]
+)
+
+
+class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ block_classes = QwenImageImg2ImgVaeEncoderBlocks.values()
+ block_names = QwenImageImg2ImgVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+        return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."
+
+
+#### QwenImage/img2img inputs
+QwenImageImg2ImgInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
+ ]
+)
+
+
+class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageImg2ImgInputBlocks.values()
+ block_names = QwenImageImg2ImgInputBlocks.keys()
+
+ @property
+ def description(self):
+        return (
+            "Input step that prepares the inputs for the img2img denoising step. It:\n"
+            " - ensures the text embeddings and the additional input (`image_latents`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents` and patchifies `image_latents`."
+        )
+
+
+#### QwenImage/img2img presets
+IMAGE2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageImg2ImgVaeEncoderStep()),
+ ("input", QwenImageImg2ImgInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+## 1.4 QwenImage/controlnet
+
+#### QwenImage/controlnet presets
+CONTROLNET_BLOCKS = InsertableDict(
+ [
+ ("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image
+ ("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet
+ (
+ "controlnet_before_denoise",
+ QwenImageControlNetBeforeDenoiserStep(),
+ ), # before denoise step (after set_timesteps step)
+ (
+ "controlnet_denoise_loop_before",
+ QwenImageLoopBeforeDenoiserControlNet(),
+ ), # controlnet loop step (insert before the denoiseloop_denoiser)
+ ]
+)
+
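+# Note: CONTROLNET_BLOCKS is not a standalone preset. Each entry is meant to be spliced into one of
+# the task presets above at the position noted in its comment; the AUTO_BLOCKS preset below performs
+# this wiring automatically through its optional controlnet steps.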
+
+## 1.5 QwenImage/auto encoders
+
+
+#### for inpaint and img2img tasks
+class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
+ block_names = ["inpaint", "img2img"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+ + " - if `mask_image` or `image` is not provided, step will be skipped."
+ )
+
+
+# for controlnet tasks
+class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetVaeEncoderStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image"]
+
+ @property
+ def description(self):
+ return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
+ + " - if `control_image` is not provided, step will be skipped."
+ )
+
+
+## 1.6 QwenImage/auto inputs
+
+
+# text2image/inpaint/img2img
+class QwenImageAutoInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep]
+ block_names = ["inpaint", "img2img", "text2image"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents", None]
+
+ @property
+ def description(self):
+ return (
+            "Input step that standardizes the inputs for the denoising step, e.g. makes sure inputs have a consistent batch size and are patchified.\n"
+ " This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n"
+ + " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
+ )
+
+
+# controlnet
+class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetInputsStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Controlnet input step that prepares the control_image_latents input.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - if `control_image_latents` is not provided, step will be skipped."
+ )
+
+
+## 1.7 QwenImage/auto before denoise step
+# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step
+
+# QwenImage/text2image before denoise
+QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values()
+ block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs, etc.) for the denoise step for the text2image task."
+
+
+# QwenImage/inpaint before denoise
+QwenImageInpaintBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintBeforeDenoiseBlocks.values()
+ block_names = QwenImageInpaintBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs, etc.) for the denoise step for the inpaint task."
+
+
+# QwenImage/img2img before denoise
+QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values()
+ block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs, etc.) for the denoise step for the img2img task."
+
+
+# auto before_denoise step for text2image, inpaint, img2img tasks
+class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageInpaintBeforeDenoiseStep,
+ QwenImageImg2ImgBeforeDenoiseStep,
+ QwenImageText2ImageBeforeDenoiseStep,
+ ]
+ block_names = ["inpaint", "img2img", "text2image"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents", None]
+
+ @property
+ def description(self):
+ return (
+            "Before denoise step that prepares the inputs (timesteps, latents, rope inputs, etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n"
+ + " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
+ )
+
+
+# auto before_denoise step for controlnet tasks
+class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetBeforeDenoiserStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Controlnet before denoise step that prepares the controlnet input.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - if `control_image_latents` is not provided, step will be skipped."
+ )
+
+
+## 1.8 QwenImage/auto denoise
+
+
+# auto denoise step for controlnet tasks: works for all tasks with controlnet
+class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep]
+ block_names = ["inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["mask", None]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet step during the denoising process. \n"
+ " This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n"
+ + " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n"
+ + " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n"
+ )
+
+
+# auto denoise step for everything: works for all tasks with or without controlnet
+class QwenImageAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageControlNetAutoDenoiseStep,
+ QwenImageInpaintDenoiseStep,
+ QwenImageDenoiseStep,
+ ]
+ block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["control_image_latents", "mask", None]
+
+ @property
+ def description(self):
+ return (
+            "Denoise step that iteratively denoises the latents.\n"
+ " This is an auto pipeline block that works for inpaint/text2image/img2img tasks. It also works with controlnet\n"
+ + " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n"
+ + " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n"
+ )
+
+
+## 1.9 QwenImage/auto decode
+# auto decode step for inpaint and text2image tasks
+
+
+class QwenImageAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
+ block_names = ["inpaint_decode", "decode"]
+ block_trigger_inputs = ["mask", None]
+
+ @property
+ def description(self):
+ return (
+            "Decode step that decodes the latents into images.\n"
+ " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+ + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+ + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
+ )
+
+
+## 1.10 QwenImage/auto block & presets
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageAutoVaeEncoderStep()),
+ ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
+ ("input", QwenImageAutoInputStep()),
+ ("controlnet_input", QwenImageOptionalControlNetInputStep()),
+ ("before_denoise", QwenImageAutoBeforeDenoiseStep()),
+ ("controlnet_before_denoise", QwenImageOptionalControlNetBeforeDenoiseStep()),
+ ("denoise", QwenImageAutoDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ block_classes = AUTO_BLOCKS.values()
+ block_names = AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ + "- for image-to-image generation, you need to provide `image`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`"
+ )
+
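+# A minimal usage sketch for the auto preset (same assumed loading API as the sketches above):
+#
+#   pipe = QwenImageAutoBlocks().init_pipeline("Qwen/Qwen-Image")
+#   pipe.load_components(torch_dtype=torch.bfloat16)
+#   # text2image: prompt only; img2img: add `image`; inpaint: add `image` + `mask_image`;
+#   # controlnet: add `control_image`.
+#   image = pipe(prompt="a cat wearing a hat", output="images")[0]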
+
+# 2. QwenImage-Edit
+
+## 2.1 QwenImage-Edit/edit
+
+#### QwenImage-Edit/edit vl encoder: take both image and text prompts
+QwenImageEditVLEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()),
+ ("encode", QwenImageEditTextEncoderStep()),
+ ]
+)
+
+
+class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditVLEncoderBlocks.values()
+ block_names = QwenImageEditVLEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+        return "QwenImage-Edit VL encoder step that encodes the image and text prompts together."
+
+
+#### QwenImage-Edit/edit vae encoder
+QwenImageEditVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step
+ ("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditVaeEncoderBlocks.values()
+ block_names = QwenImageEditVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+        return "Vae encoder step that encodes the image inputs into their latent representations."
+
+
+#### QwenImage-Edit/edit input
+QwenImageEditInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
+ ]
+)
+
+
+class QwenImageEditInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInputBlocks.values()
+ block_names = QwenImageEditInputBlocks.keys()
+
+ @property
+ def description(self):
+        return (
+            "Input step that prepares the inputs for the edit denoising step. It:\n"
+            " - ensures the text embeddings and the additional input (`image_latents`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents` and patchifies `image_latents`."
+        )
+
+
+#### QwenImage/edit presets
+EDIT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditVaeEncoderStep()),
+ ("input", QwenImageEditInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+## 2.2 QwenImage-Edit/edit inpaint
+
+#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step
+QwenImageEditInpaintVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image
+ (
+ "preprocess",
+            QwenImageInpaintProcessImagesInputStep(),
+ ), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
+ (
+ "encode",
+ QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
+ ), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInpaintVaeEncoderBlocks.values()
+ block_names = QwenImageEditInpaintVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
+            " - resizes the image to the target area (1024 * 1024) while maintaining the aspect ratio.\n"
+            " - processes the resized image and mask image.\n"
+            " - creates the image latents."
+ )
+
+
+#### QwenImage-Edit/edit inpaint presets
+EDIT_INPAINT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditInpaintVaeEncoderStep()),
+ ("input", QwenImageInpaintInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditInpaintDenoiseStep()),
+ ("decode", QwenImageInpaintDecodeStep()),
+ ]
+)
+
+
+## 2.3 QwenImage-Edit/auto encoders
+
+
+class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageEditInpaintVaeEncoderStep,
+ QwenImageEditVaeEncoderStep,
+ ]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+ " This is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
+ + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
+ + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
+ + " - if `mask_image` or `image` is not provided, step will be skipped."
+ )
+
+
+## 2.4 QwenImage-Edit/auto inputs
+class QwenImageEditAutoInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["processed_mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the edit denoising step.\n"
+ + " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
+ + " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n"
+ + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 2.5 QwenImage-Edit/auto before denoise
+# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step
+
+#### QwenImage-Edit/edit before denoise
+QwenImageEditBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditBeforeDenoiseBlocks.values()
+ block_names = QwenImageEditBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs, etc.) for the denoise step for the edit task."
+
+
+#### QwenImage-Edit/edit inpaint before denoise
+QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values()
+ block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs, etc.) for the denoise step for the edit inpaint task."
+
+
+# auto before_denoise step for edit and edit_inpaint tasks
+class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditInpaintBeforeDenoiseStep,
+ QwenImageEditBeforeDenoiseStep,
+ ]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Before denoise step that prepares the inputs (timesteps, latents, rope inputs, etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n"
+ + " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ + " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped."
+ )
+
+
+## 2.6 QwenImage-Edit/auto denoise
+
+
+class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit"
+
+ block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep]
+ block_names = ["inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Denoise step that iteratively denoises the latents.\n"
+ + "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n"
+ + " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 2.7 QwenImage-Edit/auto blocks & presets
+
+EDIT_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
+ ("input", QwenImageEditAutoInputStep()),
+ ("before_denoise", QwenImageEditAutoBeforeDenoiseStep()),
+ ("denoise", QwenImageEditAutoDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = EDIT_AUTO_BLOCKS.values()
+ block_names = EDIT_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
+ + "- for edit (img2img) generation, you need to provide `image`\n"
+ + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ )
+
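+# A hypothetical call sketch for the edit auto preset (loading as in the QwenImage sketches above;
+# the "Qwen/Qwen-Image-Edit" repo id is an assumption here):
+#
+#   pipe = QwenImageEditAutoBlocks().init_pipeline("Qwen/Qwen-Image-Edit")
+#   pipe.load_components(torch_dtype=torch.bfloat16)
+#   edited = pipe(prompt="turn the taxi green", image=input_image, output="images")[0]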
+
+# 3. all block presets supported in QwenImage & QwenImage-Edit
+
+
+ALL_BLOCKS = {
+ "text2image": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "edit": EDIT_BLOCKS,
+ "edit_inpaint": EDIT_INPAINT_BLOCKS,
+ "inpaint": INPAINT_BLOCKS,
+ "controlnet": CONTROLNET_BLOCKS,
+ "auto": AUTO_BLOCKS,
+ "edit_auto": EDIT_AUTO_BLOCKS,
+}
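+
+# ALL_BLOCKS maps a task name to its preset, e.g. `ALL_BLOCKS["img2img"]` returns IMAGE2IMAGE_BLOCKS.
+# Each value is an InsertableDict of (name, block) pairs that can be assembled into a
+# SequentialPipelineBlocks as sketched above.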
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
new file mode 100644
index 0000000000..fe9757f41b
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
@@ -0,0 +1,202 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import QwenImageLoraLoaderMixin
+from ..modular_pipeline import ModularPipeline
+
+
+class QwenImagePachifier(ConfigMixin):
+ """
+ A class to pack and unpack latents for QwenImage.
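+
+    Example (a minimal shape sketch, assuming the default `patch_size=2`):
+
+    ```py
+    >>> import torch
+    >>> pachifier = QwenImagePachifier()
+    >>> latents = torch.randn(1, 16, 1, 64, 64)  # (batch, channels, frames, latent_height, latent_width)
+    >>> packed = pachifier.pack_latents(latents)  # (1, (64 // 2) * (64 // 2), 16 * 2 * 2) == (1, 1024, 64)
+    >>> unpacked = pachifier.unpack_latents(packed, height=512, width=512)  # back to (1, 16, 1, 64, 64)
+    ```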
+ """
+
+ config_name = "config.json"
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ ):
+ super().__init__()
+
+ def pack_latents(self, latents):
+ if latents.ndim != 4 and latents.ndim != 5:
+ raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}")
+
+ if latents.ndim == 4:
+ latents = latents.unsqueeze(2)
+
+ batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape
+ patch_size = self.config.patch_size
+
+ if latent_height % patch_size != 0 or latent_width % patch_size != 0:
+ raise ValueError(
+ f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
+ )
+
+ latents = latents.view(
+ batch_size,
+ num_channels_latents,
+ latent_height // patch_size,
+ patch_size,
+ latent_width // patch_size,
+ patch_size,
+ )
+ latents = latents.permute(
+ 0, 2, 4, 1, 3, 5
+ ) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size
+ latents = latents.reshape(
+ batch_size,
+ (latent_height // patch_size) * (latent_width // patch_size),
+ num_channels_latents * patch_size * patch_size,
+ )
+
+ return latents
+
+ def unpack_latents(self, latents, height, width, vae_scale_factor=8):
+ if latents.ndim != 3:
+ raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
+
+ batch_size, num_patches, channels = latents.shape
+ patch_size = self.config.patch_size
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = patch_size * (int(height) // (vae_scale_factor * patch_size))
+ width = patch_size * (int(width) // (vae_scale_factor * patch_size))
+
+ latents = latents.view(
+ batch_size,
+ height // patch_size,
+ width // patch_size,
+ channels // (patch_size * patch_size),
+ patch_size,
+ patch_size,
+ )
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
+
+ return latents
+
+
+class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
+ """
+ A ModularPipeline for QwenImage.
+
+    This is an experimental feature and is likely to change in the future.
+
+ """
+
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ return 128
+
+ @property
+ def vae_scale_factor(self):
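+        # Derived from the VAE config when a VAE is loaded; falls back to 8 otherwise.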
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if hasattr(self, "transformer") and self.transformer is not None:
+ num_channels_latents = self.transformer.config.in_channels // 4
+ return num_channels_latents
+
+ @property
+ def is_guidance_distilled(self):
+ is_guidance_distilled = False
+ if hasattr(self, "transformer") and self.transformer is not None:
+ is_guidance_distilled = self.transformer.config.guidance_embeds
+ return is_guidance_distilled
+
+ @property
+ def requires_unconditional_embeds(self):
+ requires_unconditional_embeds = False
+
+ if hasattr(self, "guider") and self.guider is not None:
+ requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+ return requires_unconditional_embeds
+
+
+class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
+ """
+ A ModularPipeline for QwenImage-Edit.
+
+    This is an experimental feature and is likely to change in the future.
+
+ """
+
+ # YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step.
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ return 128
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if hasattr(self, "transformer") and self.transformer is not None:
+ num_channels_latents = self.transformer.config.in_channels // 4
+ return num_channels_latents
+
+ @property
+ def is_guidance_distilled(self):
+ is_guidance_distilled = False
+ if hasattr(self, "transformer") and self.transformer is not None:
+ is_guidance_distilled = self.transformer.config.guidance_embeds
+ return is_guidance_distilled
+
+ @property
+ def requires_unconditional_embeds(self):
+ requires_unconditional_embeds = False
+
+ if hasattr(self, "guider") and self.guider is not None:
+ requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+ return requires_unconditional_embeds
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
index 0ee37f5201..e84f5cad1a 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
@@ -76,6 +76,7 @@ class StableDiffusionXLModularPipeline(
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
+ # YiYi TODO: change to num_channels_latents
@property
def num_channels_unet(self):
num_channels_unet = 4
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 25d5d213cf..8ed07a72e3 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -394,6 +394,7 @@ else:
"QwenImageInpaintPipeline",
"QwenImageEditPipeline",
"QwenImageEditInpaintPipeline",
+ "QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
]
try:
@@ -714,6 +715,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .qwenimage import (
+ QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index ebabf17995..880984eeb8 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -91,6 +91,14 @@ from .pag import (
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+from .qwenimage import (
+ QwenImageControlNetPipeline,
+ QwenImageEditInpaintPipeline,
+ QwenImageEditPipeline,
+ QwenImageImg2ImgPipeline,
+ QwenImageInpaintPipeline,
+ QwenImagePipeline,
+)
from .sana import SanaPipeline
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
from .stable_diffusion import (
@@ -150,6 +158,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("cogview4-control", CogView4ControlPipeline),
+ ("qwenimage", QwenImagePipeline),
+ ("qwenimage-controlnet", QwenImageControlNetPipeline),
]
)
@@ -174,6 +184,8 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
("flux-kontext", FluxKontextPipeline),
+ ("qwenimage", QwenImageImg2ImgPipeline),
+ ("qwenimage-edit", QwenImageEditPipeline),
]
)
@@ -195,6 +207,8 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("flux-controlnet", FluxControlNetInpaintPipeline),
("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
+ ("qwenimage", QwenImageInpaintPipeline),
+ ("qwenimage-edit", QwenImageEditInpaintPipeline),
]
)
diff --git a/src/diffusers/pipelines/qwenimage/__init__.py b/src/diffusers/pipelines/qwenimage/__init__.py
index ae5cf04dc5..36d92917fd 100644
--- a/src/diffusers/pipelines/qwenimage/__init__.py
+++ b/src/diffusers/pipelines/qwenimage/__init__.py
@@ -25,6 +25,7 @@ else:
_import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"]
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
_import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"]
+ _import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
@@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_qwenimage import QwenImagePipeline
from .pipeline_qwenimage_controlnet import QwenImageControlNetPipeline
+ from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline
from .pipeline_qwenimage_edit import QwenImageEditPipeline
from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
new file mode 100644
index 0000000000..102a813ab5
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
@@ -0,0 +1,941 @@
+# Copyright 2025 Qwen-Image Team, The InstantX Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import load_image
+ >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
+
+ >>> base_model_path = "Qwen/Qwen-Image"
+ >>> controlnet_model_path = "InstantX/Qwen-Image-ControlNet-Inpainting"
+ >>> controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
+ >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
+ ... base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... ).to("cuda")
+ >>> image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png"
+ ... )
+ >>> mask_image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png"
+ ... )
+        >>> prompt = "一辆绿色的出租车行驶在路上"  # "A green taxi driving on the road"
+ >>> result = pipe(
+ ... prompt=prompt,
+ ... control_image=image,
+ ... control_mask=mask_image,
+ ... controlnet_conditioning_scale=1.0,
+ ... width=mask_image.size[0],
+ ... height=mask_image.size[1],
+ ... true_cfg_scale=4.0,
+ ... ).images[0]
+        >>> result.save("qwenimage_controlnet_inpaint.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
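+
+
+# For example (illustrative arithmetic with the defaults above): image_seq_len = 1024 gives
+# mu = 1024 * (1.15 - 0.5) / (4096 - 256) + (0.5 - 256 * (1.15 - 0.5) / (4096 - 256)) ≈ 0.63.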
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+    The QwenImage ControlNet pipeline for text-guided image inpainting.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKLQwenImage`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder ([`~transformers.Qwen2_5_VLForConditionalGeneration`]):
+            The text encoder, specifically the
+            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+        tokenizer ([`~transformers.Qwen2Tokenizer`]):
+            Tokenizer of class
+            [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2Tokenizer).
+        controlnet ([`QwenImageControlNetModel`]):
+            Provides additional conditioning to the `transformer` during the denoising process.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ controlnet: QwenImageControlNetModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
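+        # The mask is resized with the same patch-aligned scale factor, converted to single-channel
+        # grayscale, binarized, and left un-normalized so it can be used directly as a 0/1 inpaint mask.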
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ do_resize=True,
+ do_convert_grayscale=True,
+ do_normalize=False,
+ do_binarize=True,
+ )
+
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(self.device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_image_with_mask(
+ self,
+ image,
+ mask,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+ image = image.to(device=device, dtype=dtype) # (bsz, 3, height_ori, width_ori)
+
+ # Prepare mask
+ if isinstance(mask, torch.Tensor):
+ pass
+ else:
+ mask = self.mask_processor.preprocess(mask, height=height, width=width)
+ mask = mask.repeat_interleave(repeat_by, dim=0)
+ mask = mask.to(device=device, dtype=dtype) # (bsz, 1, height_ori, width_ori)
+
+ if image.ndim == 4:
+ image = image.unsqueeze(2)
+
+ if mask.ndim == 4:
+ mask = mask.unsqueeze(2)
+
+ # Get masked image
+ masked_image = image.clone()
+ masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1 # (bsz, 3, 1, height_ori, width_ori)
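+        # Pixels under the mask are filled with -1 (pure black in the [-1, 1] normalized range), so the VAE encodes
+        # an image with the inpainting region blanked out.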
+
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device)
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device
+ )
+
+ # Encode to latents
+ image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
+ image_latents = (image_latents - latents_mean) * latents_std
+ image_latents = image_latents.to(dtype) # torch.Size([1, 16, 1, height_ori//8, width_ori//8])
+
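+        # Downsample the mask to the latent grid and invert it, so that 1 marks pixels to keep and 0 marks the
+        # region to be inpainted.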
+ mask = torch.nn.functional.interpolate(
+ mask, size=(image_latents.shape[-3], image_latents.shape[-2], image_latents.shape[-1])
+ )
+ mask = 1 - mask # torch.Size([1, 1, 1, height_ori//8, width_ori//8])
+
+ control_image = torch.cat(
+ [image_latents, mask], dim=1
+ ) # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8])
+
+ control_image = control_image.permute(0, 2, 1, 3, 4) # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8])
+
+ # pack
+ control_image = self._pack_latents(
+ control_image,
+ batch_size=control_image.shape[0],
+ num_channels_latents=control_image.shape[2],
+ height=control_image.shape[3],
+ width=control_image.shape[4],
+ )
+
+ if do_classifier_free_guidance and not guess_mode:
+ control_image = torch.cat([control_image] * 2)
+
+ return control_image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 1.0,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_image: PipelineImageInput = None,
+ control_mask: PipelineImageInput = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+            true_cfg_scale (`float`, *optional*, defaults to 4.0):
+                When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+            guidance_scale (`float`, *optional*, defaults to 1.0):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are
+                closely linked to the text `prompt`, usually at the expense of lower image quality.
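+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The fraction of the total denoising steps at which the ControlNet starts being applied.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The fraction of the total denoising steps at which the ControlNet stops being applied.
+            control_image (`PipelineImageInput`):
+                The image to be inpainted. It is masked with `control_mask`, encoded by the VAE, and concatenated
+                with the mask as the ControlNet conditioning.
+            control_mask (`PipelineImageInput`):
+                The inpainting mask. White (1) regions are repainted, black (0) regions are preserved.
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before being added to
+                the residuals in the transformer.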
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
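+        # True CFG runs a second transformer pass with the negative prompt instead of relying on guidance embeddings.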
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 3. Prepare control image
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if isinstance(self.controlnet, QwenImageControlNetModel):
+ control_image = self.prepare_image_with_mask(
+ image=control_image,
+ mask=control_mask,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
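+        # controlnet_keep holds a per-step multiplier that disables the ControlNet outside the
+        # [control_guidance_start, control_guidance_end] fraction of the schedule.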
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # controlnet
+ controlnet_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image.to(dtype=latents.dtype, device=device),
+ conditioning_scale=cond_scale,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ return_dict=False,
+ )
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
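+                    # Rescale the combined prediction to the norm of the conditional prediction, a guidance-rescale
+                    # style correction that counteracts over-saturation at high CFG scales.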
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 3dd00b2ce3..2fba9986e8 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -429,8 +429,64 @@ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
+# this part is adapted from calcuis (gguf.org)
+# more info: https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py
+
+
+def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None):
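+    # IQ4_NL block layout: a float16 scale `d` followed by block_size // 2 bytes, each packing two 4-bit indices
+    # into the fixed non-linear `kvalues` lookup table. Dequantized weight = d * kvalues[index].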
+ kvalues = torch.tensor(
+ [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+ dtype=torch.float32,
+ device=blocks.device,
+ )
+ n_blocks = blocks.shape[0]
+ d, qs = split_block_dims(blocks, 2)
+ d = d.view(torch.float16).to(dtype)
+ qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+ [0, 4], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, 1, 2, 1))
+ qs = (qs & 15).reshape((n_blocks, -1)).to(torch.int64)
+ kvalues = kvalues.view(1, 1, 16)
+ qs = qs.unsqueeze(-1)
+ qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], 16), 2, qs)
+ qs = qs.squeeze(-1).to(dtype)
+ return d * qs
+
+
+def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
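+    # IQ4_XS super-block layout: a float16 super-scale `d`, 16 bits of high scale bits, QK_K // 64 bytes of low
+    # scale nibbles, then QK_K // 2 bytes of packed 4-bit indices into the same non-linear `kvalues` table.
+    # Each group of 32 weights is scaled by d * ((scale_l | scale_h << 4) - 32).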
+ kvalues = torch.tensor(
+ [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+ dtype=torch.float32,
+ device=blocks.device,
+ )
+ n_blocks = blocks.shape[0]
+ d, scales_h, scales_l, qs = split_block_dims(blocks, 2, 2, QK_K // 64)
+ d = d.view(torch.float16).to(dtype)
+ scales_h = scales_h.view(torch.int16)
+ scales_l = scales_l.reshape((n_blocks, -1, 1)) >> torch.tensor(
+ [0, 4], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, 1, 2))
+ scales_h = scales_h.reshape((n_blocks, 1, -1)) >> torch.tensor(
+ [2 * i for i in range(QK_K // 32)], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, -1, 1))
+ scales_l = scales_l.reshape((n_blocks, -1)) & 0x0F
+ scales_h = scales_h.reshape((n_blocks, -1)) & 0x03
+ scales = (scales_l | (scales_h << 4)) - 32
+ dl = (d * scales.to(dtype)).reshape((n_blocks, -1, 1))
+ shifts_q = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
+ qs = qs.reshape((n_blocks, -1, 1, 16)) >> shifts_q
+ qs = (qs & 15).reshape((n_blocks, -1, 32)).to(torch.int64)
+ kvalues = kvalues.view(1, 1, 1, 16)
+ qs = qs.unsqueeze(-1)
+ qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], qs.shape[2], 16), 3, qs)
+ qs = qs.squeeze(-1).to(dtype)
+ return (dl * qs).reshape(n_blocks, -1)
+
+
GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
dequantize_functions = {
+ gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL,
+ gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS,
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 91eefc5c10..00792fa55a 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -32,6 +32,66 @@ class FluxModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class QwenImageAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1757,6 +1817,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class QwenImageControlNetInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class QwenImageControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/pipelines/marigold/test_marigold_intrinsics.py b/tests/pipelines/marigold/test_marigold_intrinsics.py
index 3f7ab9bf6e..7db14b67ce 100644
--- a/tests/pipelines/marigold/test_marigold_intrinsics.py
+++ b/tests/pipelines/marigold/test_marigold_intrinsics.py
@@ -34,6 +34,7 @@ from diffusers import (
)
from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -416,7 +417,7 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
expected_slice: np.ndarray = None,
model_id: str = "prs-eth/marigold-iid-appearance-v1-1",
image_url: str = "https://marigoldmonodepth.github.io/images/einstein.jpg",
- atol: float = 1e-4,
+ atol: float = 1e-3,
**pipe_kwargs,
):
from_pretrained_kwargs = {}
@@ -531,11 +532,41 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.62655,
+ 0.62477,
+ 0.62161,
+ 0.62452,
+ 0.62454,
+ 0.62454,
+ 0.62255,
+ 0.62647,
+ 0.63379,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.61572,
+                        0.61377,
+ 0.61182,
+ 0.61426,
+ 0.61377,
+ 0.61426,
+ 0.61279,
+ 0.61572,
+ 0.62354,
+ ]
+ ),
+ }
+ )
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.61572, 0.61377, 0.61182, 0.61426, 0.61377, 0.61426, 0.61279, 0.61572, 0.62354]),
+ expected_slice=expected_slices.get_expectation(),
num_inference_steps=1,
processing_resolution=768,
ensemble_size=3,
@@ -545,11 +576,41 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.62988,
+ 0.62792,
+ 0.62548,
+ 0.62841,
+ 0.62792,
+ 0.62792,
+ 0.62646,
+ 0.62939,
+ 0.63721,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.61914,
+ 0.6167,
+ 0.61475,
+ 0.61719,
+ 0.61719,
+ 0.61768,
+ 0.61572,
+ 0.61914,
+ 0.62695,
+ ]
+ ),
+ }
+ )
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.61914, 0.6167, 0.61475, 0.61719, 0.61719, 0.61768, 0.61572, 0.61914, 0.62695]),
+ expected_slice=expected_slices.get_expectation(),
num_inference_steps=1,
processing_resolution=768,
ensemble_size=4,