From f50b18eec7d646bf98aef576dbb0f47ff512beaa Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 8 Sep 2025 00:27:02 -1000 Subject: [PATCH 1/5] [Modular] Qwen (#12220) * add qwen modular --- docs/source/en/api/image_processor.md | 6 + src/diffusers/__init__.py | 8 + src/diffusers/hooks/_helpers.py | 10 + src/diffusers/image_processor.py | 132 +++ src/diffusers/modular_pipelines/__init__.py | 12 + .../modular_pipelines/modular_pipeline.py | 37 +- .../modular_pipelines/qwenimage/__init__.py | 75 ++ .../qwenimage/before_denoise.py | 727 +++++++++++++++ .../modular_pipelines/qwenimage/decoders.py | 203 +++++ .../modular_pipelines/qwenimage/denoise.py | 668 ++++++++++++++ .../modular_pipelines/qwenimage/encoders.py | 857 ++++++++++++++++++ .../modular_pipelines/qwenimage/inputs.py | 431 +++++++++ .../qwenimage/modular_blocks.py | 841 +++++++++++++++++ .../qwenimage/modular_pipeline.py | 202 +++++ .../stable_diffusion_xl/modular_pipeline.py | 1 + src/diffusers/pipelines/auto_pipeline.py | 14 + .../dummy_torch_and_transformers_objects.py | 60 ++ 17 files changed, 4275 insertions(+), 9 deletions(-) create mode 100644 src/diffusers/modular_pipelines/qwenimage/__init__.py create mode 100644 src/diffusers/modular_pipelines/qwenimage/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/qwenimage/decoders.py create mode 100644 src/diffusers/modular_pipelines/qwenimage/denoise.py create mode 100644 src/diffusers/modular_pipelines/qwenimage/encoders.py create mode 100644 src/diffusers/modular_pipelines/qwenimage/inputs.py create mode 100644 src/diffusers/modular_pipelines/qwenimage/modular_blocks.py create mode 100644 src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py diff --git a/docs/source/en/api/image_processor.md b/docs/source/en/api/image_processor.md index 3e75af026d..82d1837b0b 100644 --- a/docs/source/en/api/image_processor.md +++ b/docs/source/en/api/image_processor.md @@ -20,6 +20,12 @@ All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or Nu [[autodoc]] image_processor.VaeImageProcessor +## InpaintProcessor + +The [`InpaintProcessor`] accepts `mask` and `image` inputs and process them together. Optionally, it can accept padding_mask_crop and apply mask overlay. + +[[autodoc]] image_processor.InpaintProcessor + ## VaeImageProcessorLDM3D The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs. 
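A minimal usage sketch of the new `InpaintProcessor` documented above (the class itself is added in `src/diffusers/image_processor.py` later in this patch). File names, sizes, and the placeholder tensor standing in for decoded pipeline output are illustrative, not part of the change:

import torch
from PIL import Image
from diffusers.image_processor import InpaintProcessor

processor = InpaintProcessor(vae_scale_factor=8)

init_image = Image.open("bench.png").convert("RGB")     # illustrative input
mask_image = Image.open("bench_mask.png").convert("L")  # illustrative input

# preprocess() returns the normalized image tensor, the binarized mask tensor,
# and a dict of kwargs (crops_coords / original_image / original_mask) that is
# meant to be forwarded to postprocess() so the mask overlay can be applied.
image_pt, mask_pt, overlay_kwargs = processor.preprocess(
    init_image, mask=mask_image, height=1024, width=1024, padding_mask_crop=32
)

decoded = torch.rand(1, 3, 1024, 1024)  # placeholder for the denoised, decoded image

# When crops_coords is set, output_type must be "pil"; the processed region is
# pasted back onto the original image using the original mask.
images = processor.postprocess(decoded, output_type="pil", **overlay_kwargs)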
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fa5dd6482c..4c06440172 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -385,6 +385,10 @@ else: [ "FluxAutoBlocks", "FluxModularPipeline", + "QwenImageAutoBlocks", + "QwenImageEditAutoBlocks", + "QwenImageEditModularPipeline", + "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", "WanAutoBlocks", @@ -1038,6 +1042,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .modular_pipelines import ( FluxAutoBlocks, FluxModularPipeline, + QwenImageAutoBlocks, + QwenImageEditAutoBlocks, + QwenImageEditModularPipeline, + QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, WanAutoBlocks, diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index b7a74be2e5..f6e5bdd52d 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -108,6 +108,7 @@ def _register_attention_processors_metadata(): from ..models.attention_processor import AttnProcessor2_0 from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor from ..models.transformers.transformer_flux import FluxAttnProcessor + from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0 from ..models.transformers.transformer_wan import WanAttnProcessor2_0 # AttnProcessor2_0 @@ -140,6 +141,14 @@ def _register_attention_processors_metadata(): metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor), ) + # QwenDoubleStreamAttnProcessor2 + AttentionProcessorRegistry.register( + model_class=QwenDoubleStreamAttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 + ), + ) + def _register_transformer_blocks_metadata(): from ..models.attention import BasicTransformerBlock @@ -298,4 +307,5 @@ _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___h _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states # not sure what this is yet. _skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states +_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states # fmt: on diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 6a3cf77a7d..0e3082eada 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -523,6 +523,7 @@ class VaeImageProcessor(ConfigMixin): size=(height, width), ) image = self.pt_to_numpy(image) + return image def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: @@ -838,6 +839,137 @@ class VaeImageProcessor(ConfigMixin): return image +class InpaintProcessor(ConfigMixin): + """ + Image processor for inpainting image and mask. 
+ """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + vae_latent_channels: int = 4, + resample: str = "lanczos", + reducing_gap: int = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_grayscale: bool = False, + mask_do_normalize: bool = False, + mask_do_binarize: bool = True, + mask_do_convert_grayscale: bool = True, + ): + super().__init__() + + self._image_processor = VaeImageProcessor( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + resample=resample, + reducing_gap=reducing_gap, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + self._mask_processor = VaeImageProcessor( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + resample=resample, + reducing_gap=reducing_gap, + do_normalize=mask_do_normalize, + do_binarize=mask_do_binarize, + do_convert_grayscale=mask_do_convert_grayscale, + ) + + def preprocess( + self, + image: PIL.Image.Image, + mask: PIL.Image.Image = None, + height: int = None, + width: int = None, + padding_mask_crop: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Preprocess the image and mask. + """ + if mask is None and padding_mask_crop is not None: + raise ValueError("mask must be provided if padding_mask_crop is provided") + + # if mask is None, same behavior as regular image processor + if mask is None: + return self._image_processor.preprocess(image, height=height, width=width) + + if padding_mask_crop is not None: + crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + processed_image = self._image_processor.preprocess( + image, + height=height, + width=width, + crops_coords=crops_coords, + resize_mode=resize_mode, + ) + + processed_mask = self._mask_processor.preprocess( + mask, + height=height, + width=width, + resize_mode=resize_mode, + crops_coords=crops_coords, + ) + + if crops_coords is not None: + postprocessing_kwargs = { + "crops_coords": crops_coords, + "original_image": image, + "original_mask": mask, + } + else: + postprocessing_kwargs = { + "crops_coords": None, + "original_image": None, + "original_mask": None, + } + + return processed_image, processed_mask, postprocessing_kwargs + + def postprocess( + self, + image: torch.Tensor, + output_type: str = "pil", + original_image: Optional[PIL.Image.Image] = None, + original_mask: Optional[PIL.Image.Image] = None, + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> Tuple[PIL.Image.Image, PIL.Image.Image]: + """ + Postprocess the image, optionally apply mask overlay + """ + image = self._image_processor.postprocess( + image, + output_type=output_type, + ) + # optionally apply the mask overlay + if crops_coords is not None and (original_image is None or original_mask is None): + raise ValueError("original_image and original_mask must be provided if crops_coords is provided") + + elif crops_coords is not None and output_type != "pil": + raise ValueError("output_type must be 'pil' if crops_coords is provided") + + elif crops_coords is not None: + image = [ + self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image + ] + + return image + + class VaeImageProcessorLDM3D(VaeImageProcessor): """ Image processor for VAE LDM3D. 
diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 68d707f9e0..65c22b349b 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -47,6 +47,12 @@ else: _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"] + _import_structure["qwenimage"] = [ + "QwenImageAutoBlocks", + "QwenImageModularPipeline", + "QwenImageEditModularPipeline", + "QwenImageEditAutoBlocks", + ] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -68,6 +74,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: SequentialPipelineBlocks, ) from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam + from .qwenimage import ( + QwenImageAutoBlocks, + QwenImageEditAutoBlocks, + QwenImageEditModularPipeline, + QwenImageModularPipeline, + ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import WanAutoBlocks, WanModularPipeline else: diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index c0524a1f86..78226a49b1 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -56,6 +56,8 @@ MODULAR_PIPELINE_MAPPING = OrderedDict( ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), ("wan", "WanModularPipeline"), ("flux", "FluxModularPipeline"), + ("qwenimage", "QwenImageModularPipeline"), + ("qwenimage-edit", "QwenImageEditModularPipeline"), ] ) @@ -64,6 +66,8 @@ MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict( ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"), ("WanModularPipeline", "WanAutoBlocks"), ("FluxModularPipeline", "FluxAutoBlocks"), + ("QwenImageModularPipeline", "QwenImageAutoBlocks"), + ("QwenImageEditModularPipeline", "QwenImageEditAutoBlocks"), ] ) @@ -133,8 +137,8 @@ class PipelineState: Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the intermediates dict. """ - if name in self.intermediates: - return self.intermediates[name] + if name in self.values: + return self.values[name] raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") def __repr__(self): @@ -548,8 +552,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks): def __init__(self): sub_blocks = InsertableDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - sub_blocks[block_name] = block_cls() + for block_name, block in zip(self.block_names, self.block_classes): + if inspect.isclass(block): + sub_blocks[block_name] = block() + else: + sub_blocks[block_name] = block self.sub_blocks = sub_blocks if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): raise ValueError( @@ -830,7 +837,9 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): return expected_configs @classmethod - def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": + def from_blocks_dict( + cls, blocks_dict: Dict[str, Any], description: Optional[str] = None + ) -> "SequentialPipelineBlocks": """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. 
Args: @@ -852,12 +861,19 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): instance.block_classes = [block.__class__ for block in sub_blocks.values()] instance.block_names = list(sub_blocks.keys()) instance.sub_blocks = sub_blocks + + if description is not None: + instance.description = description + return instance def __init__(self): sub_blocks = InsertableDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - sub_blocks[block_name] = block_cls() + for block_name, block in zip(self.block_names, self.block_classes): + if inspect.isclass(block): + sub_blocks[block_name] = block() + else: + sub_blocks[block_name] = block self.sub_blocks = sub_blocks def _get_inputs(self): @@ -1280,8 +1296,11 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): def __init__(self): sub_blocks = InsertableDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - sub_blocks[block_name] = block_cls() + for block_name, block in zip(self.block_names, self.block_classes): + if inspect.isclass(block): + sub_blocks[block_name] = block() + else: + sub_blocks[block_name] = block self.sub_blocks = sub_blocks @classmethod diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py new file mode 100644 index 0000000000..81cf515730 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py @@ -0,0 +1,75 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = ["QwenImageTextEncoderStep"] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "AUTO_BLOCKS", + "CONTROLNET_BLOCKS", + "EDIT_AUTO_BLOCKS", + "EDIT_BLOCKS", + "EDIT_INPAINT_BLOCKS", + "IMAGE2IMAGE_BLOCKS", + "INPAINT_BLOCKS", + "TEXT2IMAGE_BLOCKS", + "QwenImageAutoBlocks", + "QwenImageEditAutoBlocks", + ] + _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .encoders import ( + QwenImageTextEncoderStep, + ) + from .modular_blocks import ( + ALL_BLOCKS, + AUTO_BLOCKS, + CONTROLNET_BLOCKS, + EDIT_AUTO_BLOCKS, + EDIT_BLOCKS, + EDIT_INPAINT_BLOCKS, + IMAGE2IMAGE_BLOCKS, + INPAINT_BLOCKS, + TEXT2IMAGE_BLOCKS, + QwenImageAutoBlocks, + QwenImageEditAutoBlocks, + ) + from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py new file mode 
100644 index 0000000000..738a1e5d15 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -0,0 +1,727 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils.torch_utils import randn_tensor, unwrap_module +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps +def get_timesteps(scheduler, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + + return timesteps, num_inference_steps - t_start + + +# Prepare Latents steps + + +class QwenImagePrepareLatentsStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Prepare initial random noise for the generation process" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="generator"), + InputParam( + name="batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + ), + InputParam( + name="dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs, can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + ) + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + # we can update the height and width here since it's used to generate the initial + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width) + if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + block_state.latents = components.pachifier.pack_latents(block_state.latents) + + self.set_block_state(state, block_state) + + return components, state + + +class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial random noised, can be generated in prepare latent step.", + ), + InputParam( + name="image_latents", + required=True, + type_hint=torch.Tensor, + description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.", + ), + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="initial_noise", + type_hint=torch.Tensor, + description="The initial random noised used for inpainting denoising.", + ), + ] + + @staticmethod + def check_inputs(image_latents, latents): + if image_latents.shape[0] != latents.shape[0]: + raise ValueError( + f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}" + ) + + if image_latents.ndim != 3: + raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + image_latents=block_state.image_latents, + latents=block_state.latents, + ) + + # prepare latent timestep + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + + # make copy of initial_noise + block_state.initial_noise = block_state.latents + + # scale noise + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + + return components, state + + +class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name="processed_mask_image", + required=True, + type_hint=torch.Tensor, + description="The processed mask to use for the inpainting process.", + ), + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam(name="dtype", required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process." + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ + height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + block_state.mask = torch.nn.functional.interpolate( + block_state.processed_mask_image, + size=(height_latents, width_latents), + ) + + block_state.mask = block_state.mask.unsqueeze(2) + block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1) + block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype) + + block_state.mask = components.pachifier.pack_latents(block_state.mask) + + self.set_block_state(state, block_state) + + return components, state + + +# Set Timesteps steps + + +class QwenImageSetTimestepsStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for text-to-image generation. Should be run after the prepare latents step." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="num_inference_steps", default=50), + InputParam(name="sigmas"), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The latents to use for the denoising process, used to calculate the image sequence length.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process" + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + mu = calculate_shift( + image_seq_len=block_state.latents.shape[1], + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + device=device, + sigmas=sigmas, + mu=mu, + ) + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + + return components, state + + +class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for image-to-image generation and inpainting. Should be run after the prepare latents step."
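As a sanity check on the scheduling math used by these set-timesteps steps, a small worked sketch reusing `calculate_shift` and the `strength` truncation from `get_timesteps` defined earlier in this file; the 1024x1024 example and step counts are illustrative:

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    # linear map from packed latent sequence length to the flow-matching shift `mu`
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# A 1024x1024 image with vae_scale_factor=8 and 2x2 packing yields
# (1024 // 16) ** 2 = 4096 packed latent tokens, so mu reaches max_shift.
print(calculate_shift(64 * 64))  # ~1.15

# With strength=0.9 and 50 steps, get_timesteps skips the first 5 timesteps and
# denoises for the remaining 45 (pure text-to-image keeps all 50).
num_inference_steps, strength = 50, 0.9
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
print(t_start, num_inference_steps - t_start)  # 5 45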
+ + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="num_inference_steps", default=50), + InputParam(name="sigmas"), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The latents to use for the denoising process, used to calculate the image sequence length.", + ), + InputParam(name="strength", default=0.9), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="timesteps", + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + mu = calculate_shift( + image_seq_len=block_state.latents.shape[1], + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + device=device, + sigmas=sigmas, + mu=mu, + ) + + block_state.timesteps, block_state.num_inference_steps = get_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + strength=block_state.strength, + ) + + self.set_block_state(state, block_state) + + return components, state + + +# other inputs for denoiser + +## RoPE inputs for denoiser + + +class QwenImageRoPEInputsStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for the denoising process. 
Should be place after prepare_latents step" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="batch_size", required=True), + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam(name="prompt_embeds_mask"), + InputParam(name="negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="img_shapes", + type_hint=List[List[Tuple[int, int, int]]], + description="The shapes of the images latents, used for RoPE calculation", + ), + OutputParam( + name="txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.img_shapes = [ + [ + ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ) + ] + * block_state.batch_size + ] + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + + self.set_block_state(state, block_state) + + return components, state + + +class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. 
Should be place after prepare_latents step" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="batch_size", required=True), + InputParam( + name="resized_image", required=True, type_hint=torch.Tensor, description="The resized image input" + ), + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam(name="prompt_embeds_mask"), + InputParam(name="negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="img_shapes", + type_hint=List[List[Tuple[int, int, int]]], + description="The shapes of the images latents, used for RoPE calculation", + ), + OutputParam( + name="txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # for edit, image size can be different from the target size (height/width) + image = ( + block_state.resized_image[0] if isinstance(block_state.resized_image, list) else block_state.resized_image + ) + image_width, image_height = image.size + + block_state.img_shapes = [ + [ + ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ), + (1, image_height // components.vae_scale_factor // 2, image_width // components.vae_scale_factor // 2), + ] + ] * block_state.batch_size + + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + + self.set_block_state(state, block_state) + + return components, state + + +## ControlNet inputs for denoiser +class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", QwenImageControlNetModel), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("control_image_latents", required=True), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance( + block_state.control_guidance_end, list + ): + mult = ( + len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1 + ) + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance( + block_state.controlnet_conditioning_scale, float + ): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py new file mode 100644 index 0000000000..6c82fe989e --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py @@ -0,0 +1,203 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
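Picking up the `controlnet_keep` logic from `QwenImageControlNetBeforeDenoiserStep` above, a small worked sketch of the per-step gating for the single-controlnet case; the step count and guidance window are illustrative:

num_steps = 4
control_guidance_start, control_guidance_end = [0.0], [0.5]  # apply controlnet for the first half only

controlnet_keep = []
for i in range(num_steps):
    keeps = [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    # a single QwenImageControlNetModel stores a scalar per step; the multi-controlnet case keeps the list
    controlnet_keep.append(keeps[0])

print(controlnet_keep)  # [1.0, 1.0, 0.0, 0.0]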
+ +from typing import List, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import InpaintProcessor, VaeImageProcessor +from ...models import AutoencoderKLQwenImage +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier + + +logger = logging.get_logger(__name__) + + +class QwenImageDecoderStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that decodes the latents to images" + + @property + def expected_components(self) -> List[ComponentSpec]: + components = [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + return components + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The latents to decode, can be generated in the denoise step", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular + block_state.latents = components.pachifier.unpack_latents( + block_state.latents, block_state.height, block_state.width + ) + block_state.latents = block_state.latents.to(components.vae.dtype) + + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(block_state.latents.device, block_state.latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(block_state.latents.device, block_state.latents.dtype) + block_state.latents = block_state.latents / latents_std + latents_mean + block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0] + + self.set_block_state(state, block_state) + return components, state + + +class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "postprocess the generated image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("images", required=True, description="the generated image from decoders step"), + InputParam( + name="output_type", + default="pil", + type_hint=str, + description="The type of the output images, can be 'pil', 'np', 'pt'", + ), + ] + + @staticmethod + def check_inputs(output_type): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + @torch.no_grad() + def 
__call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.output_type) + + block_state.images = components.image_processor.postprocess( + image=block_state.images, + output_type=block_state.output_type, + ) + + self.set_block_state(state, block_state) + return components, state + + +class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "postprocess the generated image, optional apply the mask overally to the original image.." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("images", required=True, description="the generated image from decoders step"), + InputParam( + name="output_type", + default="pil", + type_hint=str, + description="The type of the output images, can be 'pil', 'np', 'pt'", + ), + InputParam("mask_overlay_kwargs"), + ] + + @staticmethod + def check_inputs(output_type, mask_overlay_kwargs): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + if mask_overlay_kwargs and output_type != "pil": + raise ValueError("only support output_type 'pil' for mask overlay") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs) + + if block_state.mask_overlay_kwargs is None: + mask_overlay_kwargs = {} + else: + mask_overlay_kwargs = block_state.mask_overlay_kwargs + + block_state.images = components.image_mask_processor.postprocess( + image=block_state.images, + **mask_overlay_kwargs, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py new file mode 100644 index 0000000000..d0704ee6e0 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -0,0 +1,668 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Tuple + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline + + +logger = logging.get_logger(__name__) + + +class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # one timestep + block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) + block_state.latent_model_input = block_state.latents + return components, block_state + + +class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # one timestep + + block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1) + block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) + return components, block_state + + +class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("controlnet", QwenImageControlNetModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that runs the controlnet before the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. 
`QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "controlnet_conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description=( + "All conditional model inputs for the denoiser. " + "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens." + ), + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int): + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [ + c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i]) + ] + else: + controlnet_cond_scale = block_state.controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # run controlnet for the guidance batch + controlnet_block_samples = components.controlnet( + hidden_states=block_state.latent_model_input, + controlnet_cond=block_state.control_image_latents, + conditioning_scale=block_state.cond_scale, + timestep=block_state.timestep / 1000, + img_shapes=block_state.img_shapes, + encoder_hidden_states=block_state.prompt_embeds, + encoder_hidden_states_mask=block_state.prompt_embeds_mask, + txt_seq_lens=block_state.txt_seq_lens, + return_dict=False, + ) + + block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples + + return components, block_state + + +class QwenImageLoopDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", QwenImageTransformer2DModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to use for the denoising process. Can be generated in prepare_latents step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ), + InputParam( + "img_shapes", + required=True, + type_hint=List[Tuple[int, int]], + description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + guider_input_fields = { + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"), + "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields} + + # YiYi TODO: add cache context + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep / 1000, + img_shapes=block_state.img_shapes, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + **block_state.additional_cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + + # apply guidance rescale + pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True) + pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True) + block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm) + + return components, block_state + + +class QwenImageEditLoopDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", QwenImageTransformer2DModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to use for the denoising process. Can be generated in prepare_latents step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ), + InputParam( + "img_shapes", + required=True, + type_hint=List[Tuple[int, int]], + description="The shape of the image latents for RoPE calculation. 
Can be generated in prepare_additional_inputs step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + guider_input_fields = { + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"), + "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields} + + # YiYi TODO: add cache context + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep / 1000, + img_shapes=block_state.img_shapes, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + **block_state.additional_cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + + pred = guider_output.pred[:, : block_state.latents.size(1)] + pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)] + + # apply guidance rescale + pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True) + pred_norm = torch.norm(pred, dim=-1, keepdim=True) + block_state.noise_pred = pred * (pred_cond_norm / pred_norm) + + return components, block_state + + +class QwenImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that updates the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that updates the latents using mask and image_latents for inpainting. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. 
`QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.", + ), + InputParam( + "initial_noise", + required=True, + type_hint=torch.Tensor, + description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.scale_noise( + block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise + ) + + block_state.latents = ( + 1 - block_state.mask + ) * block_state.init_latents_proper + block_state.mask * block_state.latents + + return components, block_state + + +class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def loop_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + block_state.additional_cond_kwargs = {} + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +# composing the denoising loops +class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports text2image and image2image tasks for QwenImage." + ) + + +# composing the inpainting denoising loops +class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks for QwenImage." + ) + + +# composing the controlnet denoising loops +class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopBeforeDenoiserControlNet, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopBeforeDenoiserControlNet`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports text2img/img2img tasks with controlnet for QwenImage." 
+ ) + + +# composing the controlnet denoising loops +class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopBeforeDenoiserControlNet, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = [ + "before_denoiser", + "before_denoiser_controlnet", + "denoiser", + "after_denoiser", + "after_denoiser_inpaint", + ] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopBeforeDenoiserControlNet`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks with controlnet for QwenImage." + ) + + +# composing the denoising loops +class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper): + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports QwenImage Edit." + ) + + +class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks for QwenImage Edit." + ) diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py new file mode 100644 index 0000000000..280fa6a152 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -0,0 +1,857 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
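+
+# A minimal, illustrative sketch of calling `get_qwen_prompt_embeds` (defined
+# below) directly, assuming an already-loaded Qwen2.5-VL text encoder and its
+# tokenizer and relying on the function's default prompt template and max length:
+#
+#     prompt_embeds, prompt_embeds_mask = get_qwen_prompt_embeds(
+#         text_encoder,
+#         tokenizer,
+#         prompt="a photo of a cat",
+#         device=text_encoder.device,
+#     )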
+ +from typing import Dict, List, Optional, Union + +import PIL +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist +from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions +from ...utils import logging +from ...utils.torch_utils import unwrap_module +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline + + +logger = logging.get_logger(__name__) + + +def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + +def get_qwen_prompt_embeds( + text_encoder, + tokenizer, + prompt: Union[str, List[str]] = None, + prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + prompt_template_encode_start_idx: int = 34, + tokenizer_max_length: int = 1024, + device: Optional[torch.device] = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = prompt_template_encode + drop_idx = prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer( + txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + + split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + + return prompt_embeds, encoder_attention_mask + + +def get_qwen_prompt_embeds_edit( + text_encoder, + processor, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n", + prompt_template_encode_start_idx: int = 64, + device: Optional[torch.device] = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = prompt_template_encode + drop_idx = prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + + return prompt_embeds, encoder_attention_mask + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image +def encode_vae_image( + image: torch.Tensor, + vae: AutoencoderKLQwenImage, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, + sample_mode: str = "argmax", +): + if not isinstance(image, torch.Tensor): + raise ValueError(f"Expected image to be a tensor, got {type(image)}.") + + # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + image = image.to(device=device, dtype=dtype) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(vae.config.latents_std) + .view(1, latent_channels, 1, 1, 1) 
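+        # move the (1, C, 1, 1, 1) per-channel statistics to the latents' device/dtype
+        # so they broadcast over the (B, C, T, H, W) image latents when normalizing below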
+ .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + +class QwenImageEditResizeDynamicStep(ModularPipelineBlocks): + model_name = "qwenimage" + + def __init__(self, input_name: str = "image", output_name: str = "resized_image"): + """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio. + + This block resizes an input image tensor and exposes the resized result under configurable input and output + names. Use this when you need to wire the resize step to different image fields (e.g., "image", + "control_image") + + Args: + input_name (str, optional): Name of the image field to read from the + pipeline state. Defaults to "image". + output_name (str, optional): Name of the resized image field to write + back to the pipeline state. Defaults to "resized_image". + """ + if not isinstance(input_name, str) or not isinstance(output_name, str): + raise ValueError( + f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" + ) + self._image_input_name = input_name + self._resized_image_output_name = output_name + super().__init__() + + @property + def description(self) -> str: + return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize" + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = getattr(block_state, self._image_input_name) + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + image_width, image_height = images[0].size + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + + resized_images = [ + components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + for image in images + ] + + setattr(block_state, self._resized_image_output_name, resized_images) + self.set_block_state(state, block_state) + return components, state + + +class QwenImageTextEncoderStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"), + ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + 
def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec( + name="prompt_template_encode", + default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + ), + ConfigSpec(name="prompt_template_encode_start_idx", default=34), + ConfigSpec(name="tokenizer_max_length", default=1024), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), + InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam( + name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024 + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The prompt embeddings", + ), + OutputParam( + name="prompt_embeds_mask", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The encoder attention mask", + ), + OutputParam( + name="negative_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The negative prompt embeddings", + ), + OutputParam( + name="negative_prompt_embeds_mask", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The negative prompt embeddings mask", + ), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt, max_sequence_length): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + device = components._execution_device + self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length) + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=block_state.prompt, + prompt_template_encode=components.config.prompt_template_encode, + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + tokenizer_max_length=components.config.tokenizer_max_length, + device=device, + ) + + block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length] + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length] + + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or "" + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=negative_prompt, + prompt_template_encode=components.config.prompt_template_encode, + 
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + tokenizer_max_length=components.config.tokenizer_max_length, + device=device, + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[ + :, : block_state.max_sequence_length + ] + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[ + :, : block_state.max_sequence_length + ] + + self.set_block_state(state, block_state) + return components, state + + +class QwenImageEditTextEncoderStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec( + name="prompt_template_encode", + default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n", + ), + ConfigSpec(name="prompt_template_encode_start_idx", default=64), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), + InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam( + name="resized_image", + required=True, + type_hint=torch.Tensor, + description="The image prompt to encode, should be resized using resize step", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The prompt embeddings", + ), + OutputParam( + name="prompt_embeds_mask", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The encoder attention mask", + ), + OutputParam( + name="negative_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The negative prompt embeddings", + ), + OutputParam( + name="negative_prompt_embeds_mask", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="The negative prompt embeddings mask", + ), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + 
self.check_inputs(block_state.prompt, block_state.negative_prompt) + + device = components._execution_device + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit( + components.text_encoder, + components.processor, + prompt=block_state.prompt, + image=block_state.resized_image, + prompt_template_encode=components.config.prompt_template_encode, + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + device=device, + ) + + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or "" + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit( + components.text_encoder, + components.processor, + prompt=negative_prompt, + image=block_state.resized_image, + prompt_template_encode=components.config.prompt_template_encode, + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + device=device, + ) + + self.set_block_state(state, block_state) + return components, state + + +class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("mask_image", required=True), + InputParam("resized_image"), + InputParam("image"), + InputParam("height"), + InputParam("width"), + InputParam("padding_mask_crop"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(name="processed_image"), + OutputParam(name="processed_mask_image"), + OutputParam( + name="mask_overlay_kwargs", + type_hint=Dict, + description="The kwargs for the postprocess step to apply the mask overlay", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + if block_state.resized_image is None and block_state.image is None: + raise ValueError("resized_image and image cannot be None at the same time") + + if block_state.resized_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + else: + width, height = block_state.resized_image[0].size + image = block_state.resized_image + + block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = ( + components.image_mask_processor.preprocess( + image=image, + mask=block_state.mask_image, + height=height, + width=width, + padding_mask_crop=block_state.padding_mask_crop, + ) + ) + 
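+        # mask_overlay_kwargs carries the crop/paste information the postprocess step
+        # needs to blend the inpainted region back onto the original image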
+ self.set_block_state(state, block_state) + return components, state + + +class QwenImageProcessImagesInputStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image"), + InputParam("image"), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(name="processed_image"), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + if block_state.resized_image is None and block_state.image is None: + raise ValueError("resized_image and image cannot be None at the same time") + + if block_state.resized_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + else: + width, height = block_state.resized_image[0].size + image = block_state.resized_image + + block_state.processed_image = components.image_processor.preprocess( + image=image, + height=height, + width=width, + ) + + self.set_block_state(state, block_state) + return components, state + + +class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks): + model_name = "qwenimage" + + def __init__( + self, + input_name: str = "processed_image", + output_name: str = "image_latents", + ): + """Initialize a VAE encoder step for converting images to latent representations. + + Both the input and output names are configurable so this block can be configured to process to different image + inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents"). + + Args: + input_name (str, optional): Name of the input image tensor. Defaults to "processed_image". + Examples: "processed_image" or "processed_control_image" + output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents". 
+ Examples: "image_latents" or "control_image_latents" + + Examples: + # Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep() + + # Custom input/output names for control image QwenImageVaeEncoderDynamicStep( + input_name="processed_control_image", output_name="control_image_latents" + ) + """ + self._image_input_name = input_name + self._image_latents_output_name = output_name + super().__init__() + + @property + def description(self) -> str: + return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + + @property + def expected_components(self) -> List[ComponentSpec]: + components = [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ] + return components + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(self._image_input_name, required=True), + InputParam("generator"), + ] + return inputs + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + self._image_latents_output_name, + type_hint=torch.Tensor, + description="The latents representing the reference image", + ) + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + dtype = components.vae.dtype + + image = getattr(block_state, self._image_input_name) + + # Encode image into latents + image_latents = encode_vae_image( + image=image, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + + return components, state + + +class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n" + + @property + def expected_components(self) -> List[ComponentSpec]: + components = [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ComponentSpec("controlnet", QwenImageControlNetModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + return components + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam("control_image", required=True), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + ] + return inputs + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "control_image_latents", + type_hint=torch.Tensor, + description="The latents representing the control image", + ) + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor) + + device = 
components._execution_device + dtype = components.vae.dtype + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + controlnet = unwrap_module(components.controlnet) + if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + + if isinstance(controlnet, QwenImageMultiControlNetModel): + block_state.control_image_latents = [] + for control_image_ in block_state.control_image: + control_image_ = components.control_image_processor.preprocess( + image=control_image_, + height=height, + width=width, + ) + + control_image_latents_ = encode_vae_image( + image=control_image_, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + sample_mode="sample", + ) + block_state.control_image_latents.append(control_image_latents_) + + elif isinstance(controlnet, QwenImageControlNetModel): + control_image = components.control_image_processor.preprocess( + image=block_state.control_image, + height=height, + width=width, + ) + block_state.control_image_latents = encode_vae_image( + image=control_image, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + sample_mode="sample", + ) + + else: + raise ValueError( + f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}" + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py new file mode 100644 index 0000000000..2b787c8238 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -0,0 +1,431 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +import torch + +from ...models import QwenImageMultiControlNetModel +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. 
The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent space dimensions to image space dimensions by multiplying the latent height and width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress images. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + Tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + # make sure the latents are not packed + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}") + + latent_height, latent_width = latents.shape[-2:] + + height = latent_height * vae_scale_factor + width = latent_width * vae_scale_factor + + return height, width + + +class QwenImageTextInputsStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + summary_section = ( + "Text input processing step that standardizes text embeddings for the pipeline.\n" + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. 
Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + # Placement guidance + placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps." + + return summary_section + placement_section + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"), + InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"), + InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"), + InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + ] + + @staticmethod + def check_inputs( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + ): + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None") + + if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None: + raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`") + + if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]: + raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`") + + elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]: + raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`") + + elif ( + negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0] + ): + raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`") + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + prompt_embeds=block_state.prompt_embeds, + prompt_embeds_mask=block_state.prompt_embeds_mask, + negative_prompt_embeds=block_state.negative_prompt_embeds, + negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask, + ) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( 
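+                # same expansion as for `prompt_embeds` above: tile along the sequence
+                # axis, then reshape to (batch_size * num_images_per_prompt, seq_len, -1)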
+ 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len + ) + + self.set_block_state(state, block_state) + + return components, state + + +class QwenImageInputsDynamicStep(ModularPipelineBlocks): + model_name = "qwenimage" + + def __init__( + self, + image_latent_inputs: List[str] = ["image_latents"], + additional_batch_inputs: List[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (List[str], optional): Names of image latent tensors to process. + These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or + list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"] + additional_batch_inputs (List[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. Examples: ["processed_mask_image"] + + Examples: + # Configure to process image_latents (default behavior) QwenImageInputsDynamicStep() + + # Configure to process multiple image latent inputs + QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"]) + + # Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep( + image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
+ + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 2. Patchify the image latent tensor + image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class QwenImageControlNetInputsStep(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="control_image_latents", required=True), + InputParam(name="batch_size", required=True), + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="height"), + InputParam(name="width"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if isinstance(components.controlnet, QwenImageMultiControlNetModel): + control_image_latents = [] + # loop through each control_image_latents + for i, control_image_latents_ in enumerate(block_state.control_image_latents): + # 1. 
update height/width if not provided + height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 2. pack + control_image_latents_ = components.pachifier.pack_latents(control_image_latents_) + + # 3. repeat to match the batch size + control_image_latents_ = repeat_tensor_to_batch_size( + input_name=f"control_image_latents[{i}]", + input_tensor=control_image_latents_, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + control_image_latents.append(control_image_latents_) + + block_state.control_image_latents = control_image_latents + + else: + # 1. update height/width if not provided + height, width = calculate_dimension_from_latents( + block_state.control_image_latents, components.vae_scale_factor + ) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 2. pack + block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents) + + # 3. repeat to match the batch size + block_state.control_image_latents = repeat_tensor_to_batch_size( + input_name="control_image_latents", + input_tensor=block_state.control_image_latents, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + block_state.control_image_latents = block_state.control_image_latents + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py new file mode 100644 index 0000000000..a01c742fcf --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py @@ -0,0 +1,841 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + QwenImageControlNetBeforeDenoiserStep, + QwenImageCreateMaskLatentsStep, + QwenImageEditRoPEInputsStep, + QwenImagePrepareLatentsStep, + QwenImagePrepareLatentsWithStrengthStep, + QwenImageRoPEInputsStep, + QwenImageSetTimestepsStep, + QwenImageSetTimestepsWithStrengthStep, +) +from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep +from .denoise import ( + QwenImageControlNetDenoiseStep, + QwenImageDenoiseStep, + QwenImageEditDenoiseStep, + QwenImageEditInpaintDenoiseStep, + QwenImageInpaintControlNetDenoiseStep, + QwenImageInpaintDenoiseStep, + QwenImageLoopBeforeDenoiserControlNet, +) +from .encoders import ( + QwenImageControlNetVaeEncoderStep, + QwenImageEditResizeDynamicStep, + QwenImageEditTextEncoderStep, + QwenImageInpaintProcessImagesInputStep, + QwenImageProcessImagesInputStep, + QwenImageTextEncoderStep, + QwenImageVaeEncoderDynamicStep, +) +from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep + + +logger = logging.get_logger(__name__) + +# 1. QwenImage + +## 1.1 QwenImage/text2image + +#### QwenImage/decode +#### (standard decode step works for most tasks except for inpaint) +QwenImageDecodeBlocks = InsertableDict( + [ + ("decode", QwenImageDecoderStep()), + ("postprocess", QwenImageProcessImagesOutputStep()), + ] +) + + +class QwenImageDecodeStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageDecodeBlocks.values() + block_names = QwenImageDecodeBlocks.keys() + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image." + + +#### QwenImage/text2image presets +TEXT2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageTextEncoderStep()), + ("input", QwenImageTextInputsStep()), + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + ("prepare_rope_inputs", QwenImageRoPEInputsStep()), + ("denoise", QwenImageDenoiseStep()), + ("decode", QwenImageDecodeStep()), + ] +) + + +## 1.2 QwenImage/inpaint + +#### QwenImage/inpaint vae encoder +QwenImageInpaintVaeEncoderBlocks = InsertableDict( + [ + ( + "preprocess", + QwenImageInpaintProcessImagesInputStep, + ), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs + ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents + ] +) + + +class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageInpaintVaeEncoderBlocks.values() + block_names = QwenImageInpaintVaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return ( + "This step is used for processing image and mask inputs for inpainting tasks. It:\n" + " - Resizes the image to the target size, based on `height` and `width`.\n" + " - Processes and updates `image` and `mask_image`.\n" + " - Creates `image_latents`." 
+        )
+
+
+#### QwenImage/inpaint inputs
+QwenImageInpaintInputBlocks = InsertableDict(
+    [
+        ("text_inputs", QwenImageTextInputsStep()),  # default step to process text embeddings
+        (
+            "additional_inputs",
+            QwenImageInputsDynamicStep(
+                image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+            ),
+        ),
+    ]
+)
+
+
+class QwenImageInpaintInputStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = QwenImageInpaintInputBlocks.values()
+    block_names = QwenImageInpaintInputBlocks.keys()
+
+    @property
+    def description(self):
+        return "Input step that prepares the inputs for the inpainting denoising step. It:\n" + " - makes sure the text embeddings and the additional inputs (`image_latents` and `processed_mask_image`) have a consistent batch size.\n" + " - updates height/width based on `image_latents` and patchifies `image_latents`."
+
+
+# QwenImage/inpaint prepare latents
+QwenImageInpaintPrepareLatentsBlocks = InsertableDict(
+    [
+        ("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()),
+        ("create_mask_latents", QwenImageCreateMaskLatentsStep()),
+    ]
+)
+
+
+class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = QwenImageInpaintPrepareLatentsBlocks.values()
+    block_names = QwenImageInpaintPrepareLatentsBlocks.keys()
+
+    @property
+    def description(self) -> str:
+        return (
+            "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
+            " - Adds noise to the image latents to create the latents input for the denoiser.\n"
+            " - Creates the patchified latents `mask` based on the processed mask image.\n"
+        )
+
+
+#### QwenImage/inpaint decode
+QwenImageInpaintDecodeBlocks = InsertableDict(
+    [
+        ("decode", QwenImageDecoderStep()),
+        ("postprocess", QwenImageInpaintProcessImagesOutputStep()),
+    ]
+)
+
+
+class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+    block_classes = QwenImageInpaintDecodeBlocks.values()
+    block_names = QwenImageInpaintDecodeBlocks.keys()
+
+    @property
+    def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the mask overlay to the original image."
+
+
+#### QwenImage/inpaint presets
+INPAINT_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", QwenImageTextEncoderStep()),
+        ("vae_encoder", QwenImageInpaintVaeEncoderStep()),
+        ("input", QwenImageInpaintInputStep()),
+        ("prepare_latents", QwenImagePrepareLatentsStep()),
+        ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+        ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+        ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+        ("denoise", QwenImageInpaintDenoiseStep()),
+        ("decode", QwenImageInpaintDecodeStep()),
+    ]
+)
+
+
+## 1.3 QwenImage/img2img
+
+#### QwenImage/img2img vae encoder
+QwenImageImg2ImgVaeEncoderBlocks = InsertableDict(
+    [
+        ("preprocess", QwenImageProcessImagesInputStep()),
+        ("encode", QwenImageVaeEncoderDynamicStep()),
+    ]
+)
+
+
+class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+    model_name = "qwenimage"
+
+    block_classes = QwenImageImg2ImgVaeEncoderBlocks.values()
+    block_names = QwenImageImg2ImgVaeEncoderBlocks.keys()
+
+    @property
+    def description(self) -> str:
+        return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."
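
As a usage sketch, a preset such as `INPAINT_BLOCKS` above can be assembled into a runnable pipeline. The helper methods and conventions below (`from_blocks_dict`, `init_pipeline`, `load_default_components`, `.to()`, `output="images"`) and the `Qwen/Qwen-Image` repo id are assumptions borrowed from the existing modular-pipeline API, not something this patch defines:

# Illustrative sketch only: assembling the inpaint preset into a modular pipeline.
# The helpers and repo id used here are assumed from the existing modular API.
import torch
from PIL import Image

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.qwenimage.modular_blocks import INPAINT_BLOCKS

blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
pipe = blocks.init_pipeline("Qwen/Qwen-Image")  # assumed repo id
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

init_image = Image.new("RGB", (1024, 1024), "white")  # placeholder inputs
mask = Image.new("L", (1024, 1024), 255)              # white = region to repaint

images = pipe(
    prompt="a red sports car parked on a rainy street",
    image=init_image,
    mask_image=mask,
    output="images",
)
images[0].save("qwenimage_modular_inpaint.png")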
+ + +#### QwenImage/img2img inputs +QwenImageImg2ImgInputBlocks = InsertableDict( + [ + ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings + ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])), + ] +) + + +class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageImg2ImgInputBlocks.values() + block_names = QwenImageImg2ImgInputBlocks.keys() + + @property + def description(self): + return "Input step that prepares the inputs for the img2img denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +#### QwenImage/img2img presets +IMAGE2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageTextEncoderStep()), + ("vae_encoder", QwenImageImg2ImgVaeEncoderStep()), + ("input", QwenImageImg2ImgInputStep()), + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), + ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()), + ("prepare_rope_inputs", QwenImageRoPEInputsStep()), + ("denoise", QwenImageDenoiseStep()), + ("decode", QwenImageDecodeStep()), + ] +) + + +## 1.4 QwenImage/controlnet + +#### QwenImage/controlnet presets +CONTROLNET_BLOCKS = InsertableDict( + [ + ("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image + ("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet + ( + "controlnet_before_denoise", + QwenImageControlNetBeforeDenoiserStep(), + ), # before denoise step (after set_timesteps step) + ( + "controlnet_denoise_loop_before", + QwenImageLoopBeforeDenoiserControlNet(), + ), # controlnet loop step (insert before the denoiseloop_denoiser) + ] +) + + +## 1.5 QwenImage/auto encoders + + +#### for inpaint and img2img tasks +class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" + + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n" + + " - if `mask_image` or `image` is not provided, step will be skipped." + ) + + +# for controlnet tasks +class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): + block_classes = [QwenImageControlNetVaeEncoderStep] + block_names = ["controlnet"] + block_trigger_inputs = ["control_image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n" + + " - if `control_image` is not provided, step will be skipped." 
+ ) + + +## 1.6 QwenImage/auto inputs + + +# text2image/inpaint/img2img +class QwenImageAutoInputStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep] + block_names = ["inpaint", "img2img", "text2image"] + block_trigger_inputs = ["processed_mask_image", "image_latents", None] + + @property + def description(self): + return ( + "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n" + " This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n" + + " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n" + + " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n" + + " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n" + ) + + +# controlnet +class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks): + block_classes = [QwenImageControlNetInputsStep] + block_names = ["controlnet"] + block_trigger_inputs = ["control_image_latents"] + + @property + def description(self): + return ( + "Controlnet input step that prepare the control_image_latents input.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n" + + " - if `control_image_latents` is not provided, step will be skipped." + ) + + +## 1.7 QwenImage/auto before denoise step +# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step + +# QwenImage/text2image before denoise +QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + ("prepare_rope_inputs", QwenImageRoPEInputsStep()), + ] +) + + +class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values() + block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for text2image task." + + +# QwenImage/inpaint before denoise +QwenImageInpaintBeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), + ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), + ("prepare_rope_inputs", QwenImageRoPEInputsStep()), + ] +) + + +class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageInpaintBeforeDenoiseBlocks.values() + block_names = QwenImageInpaintBeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." 
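
The auto blocks above dispatch on `block_trigger_inputs`: the first sub-block whose trigger input is actually provided in the call is run, and a trailing `None` entry acts as the default. The snippet below is a self-contained, conceptual illustration of that selection order; it is not the real `AutoPipelineBlocks` implementation:

# Conceptual sketch only -- mirrors the descriptions above: inpaint wins if
# `processed_mask_image` is present, then img2img if `image_latents` is present,
# otherwise the text2image default runs.
def resolve_sub_block(provided_inputs):
    names = ["inpaint", "img2img", "text2image"]
    triggers = ["processed_mask_image", "image_latents", None]
    for name, trigger in zip(names, triggers):
        if trigger is None or trigger in provided_inputs:
            return name

print(resolve_sub_block({"prompt"}))                                            # -> text2image
print(resolve_sub_block({"prompt", "image_latents"}))                           # -> img2img
print(resolve_sub_block({"prompt", "image_latents", "processed_mask_image"}))   # -> inpaint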
+ + +# QwenImage/img2img before denoise +QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), + ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()), + ("prepare_rope_inputs", QwenImageRoPEInputsStep()), + ] +) + + +class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values() + block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + + +# auto before_denoise step for text2image, inpaint, img2img tasks +class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [ + QwenImageInpaintBeforeDenoiseStep, + QwenImageImg2ImgBeforeDenoiseStep, + QwenImageText2ImageBeforeDenoiseStep, + ] + block_names = ["inpaint", "img2img", "text2image"] + block_trigger_inputs = ["processed_mask_image", "image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" + + "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n" + + " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n" + + " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n" + + " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n" + ) + + +# auto before_denoise step for controlnet tasks +class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [QwenImageControlNetBeforeDenoiserStep] + block_names = ["controlnet"] + block_trigger_inputs = ["control_image_latents"] + + @property + def description(self): + return ( + "Controlnet before denoise step that prepare the controlnet input.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n" + + " - if `control_image_latents` is not provided, step will be skipped." + ) + + +## 1.8 QwenImage/auto denoise + + +# auto denoise step for controlnet tasks: works for all tasks with controlnet +class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep] + block_names = ["inpaint_denoise", "denoise"] + block_trigger_inputs = ["mask", None] + + @property + def description(self): + return ( + "Controlnet step during the denoising process. 
\n" + " This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n" + + " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n" + + " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n" + ) + + +# auto denoise step for everything: works for all tasks with or without controlnet +class QwenImageAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + QwenImageControlNetAutoDenoiseStep, + QwenImageInpaintDenoiseStep, + QwenImageDenoiseStep, + ] + block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"] + block_trigger_inputs = ["control_image_latents", "mask", None] + + @property + def description(self): + return ( + "Denoise step that iteratively denoise the latents. \n" + " This is an auto pipeline block that works for inpaint/text2image/img2img tasks. It also works with controlnet\n" + + " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n" + + " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n" + + " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n" + ) + + +## 1.9 QwenImage/auto decode +# auto decode step for inpaint and text2image tasks + + +class QwenImageAutoDecodeStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep] + block_names = ["inpaint_decode", "decode"] + block_trigger_inputs = ["mask", None] + + @property + def description(self): + return ( + "Decode step that decode the latents into images. \n" + " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n" + + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n" + + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n" + ) + + +## 1.10 QwenImage/auto block & presets +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageTextEncoderStep()), + ("vae_encoder", QwenImageAutoVaeEncoderStep()), + ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()), + ("input", QwenImageAutoInputStep()), + ("controlnet_input", QwenImageOptionalControlNetInputStep()), + ("before_denoise", QwenImageAutoBeforeDenoiseStep()), + ("controlnet_before_denoise", QwenImageOptionalControlNetBeforeDenoiseStep()), + ("denoise", QwenImageAutoDenoiseStep()), + ("decode", QwenImageAutoDecodeStep()), + ] +) + + +class QwenImageAutoBlocks(SequentialPipelineBlocks): + model_name = "qwenimage" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n" + + "- for image-to-image generation, you need to provide `image`\n" + + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + + "- to run the controlnet workflow, you need to provide `control_image`\n" + + "- for text-to-image generation, all you need to provide is `prompt`" + ) + + +# 2. 
QwenImage-Edit + +## 2.1 QwenImage-Edit/edit + +#### QwenImage-Edit/edit vl encoder: take both image and text prompts +QwenImageEditVLEncoderBlocks = InsertableDict( + [ + ("resize", QwenImageEditResizeDynamicStep()), + ("encode", QwenImageEditTextEncoderStep()), + ] +) + + +class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditVLEncoderBlocks.values() + block_names = QwenImageEditVLEncoderBlocks.keys() + + @property + def description(self) -> str: + return "QwenImage-Edit VL encoder step that encode the image an text prompts together." + + +#### QwenImage-Edit/edit vae encoder +QwenImageEditVaeEncoderBlocks = InsertableDict( + [ + ("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step + ("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image + ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents + ] +) + + +class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditVaeEncoderBlocks.values() + block_names = QwenImageEditVaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "Vae encoder step that encode the image inputs into their latent representations." + + +#### QwenImage-Edit/edit input +QwenImageEditInputBlocks = InsertableDict( + [ + ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings + ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])), + ] +) + + +class QwenImageEditInputStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditInputBlocks.values() + block_names = QwenImageEditInputBlocks.keys() + + @property + def description(self): + return "Input step that prepares the inputs for the edit denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs: \n" + " - `image_latents`.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +#### QwenImage/edit presets +EDIT_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditVLEncoderStep()), + ("vae_encoder", QwenImageEditVaeEncoderStep()), + ("input", QwenImageEditInputStep()), + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), + ("denoise", QwenImageEditDenoiseStep()), + ("decode", QwenImageDecodeStep()), + ] +) + + +## 2.2 QwenImage-Edit/edit inpaint + +#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step +QwenImageEditInpaintVaeEncoderBlocks = InsertableDict( + [ + ("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image + ( + "preprocess", + QwenImageInpaintProcessImagesInputStep, + ), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs + ( + "encode", + QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"), + ), # processed_image -> image_latents + ] +) + + +class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditInpaintVaeEncoderBlocks.values() + block_names = QwenImageEditInpaintVaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return ( + "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. 
It:\n" + " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n" + " - process the resized image and mask image.\n" + " - create image latents." + ) + + +#### QwenImage-Edit/edit inpaint presets +EDIT_INPAINT_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditVLEncoderStep()), + ("vae_encoder", QwenImageEditInpaintVaeEncoderStep()), + ("input", QwenImageInpaintInputStep()), + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), + ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), + ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), + ("denoise", QwenImageEditInpaintDenoiseStep()), + ("decode", QwenImageInpaintDecodeStep()), + ] +) + + +## 2.3 QwenImage-Edit/auto encoders + + +class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [ + QwenImageEditInpaintVaeEncoderStep, + QwenImageEditVaeEncoderStep, + ] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations. \n" + " This is an auto pipeline block that works for edit and edit_inpaint tasks.\n" + + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n" + + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n" + + " - if `mask_image` or `image` is not provided, step will be skipped." + ) + + +## 2.4 QwenImage-Edit/auto inputs +class QwenImageEditAutoInputStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["processed_mask_image", "image"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit denoising step.\n" + + " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n" + + " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n" + + " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n" + + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped." + ) + + +## 2.5 QwenImage-Edit/auto before denoise +# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step + +#### QwenImage-Edit/edit before denoise +QwenImageEditBeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), + ] +) + + +class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditBeforeDenoiseBlocks.values() + block_names = QwenImageEditBeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task." 
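
The edit-inpaint preset differs from the plain edit preset only in the mask-aware steps (VAE encoder, input, strength-based timesteps, inpaint latents preparation, denoise and decode). A small sketch that diffs the two `InsertableDict` presets defined above, assuming `InsertableDict` behaves like an ordered mapping of step instances:

# Sketch: compare the edit and edit-inpaint presets defined in this file.
# Assumes InsertableDict behaves like an ordered mapping of step instances.
from diffusers.modular_pipelines.qwenimage.modular_blocks import (
    EDIT_BLOCKS,
    EDIT_INPAINT_BLOCKS,
)

for name, step in EDIT_INPAINT_BLOCKS.items():
    base = EDIT_BLOCKS.get(name)
    shared = base is not None and type(base) is type(step)
    tag = "shared with EDIT_BLOCKS" if shared else "inpaint-specific or swapped"
    print(f"{name:26s} {type(step).__name__:42s} [{tag}]")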
+ + +#### QwenImage-Edit/edit inpaint before denoise +QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), + ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), + ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), + ] +) + + +class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values() + block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit inpaint task." + + +# auto before_denoise step for edit and edit_inpaint tasks +class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks): + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInpaintBeforeDenoiseStep, + QwenImageEditBeforeDenoiseStep, + ] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["processed_mask_image", "image_latents"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" + + "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n" + + " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n" + + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" + + " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped." + ) + + +## 2.6 QwenImage-Edit/auto denoise + + +class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks): + model_name = "qwenimage-edit" + + block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep] + block_names = ["inpaint_denoise", "denoise"] + block_trigger_inputs = ["processed_mask_image", "image_latents"] + + @property + def description(self): + return ( + "Denoise step that iteratively denoise the latents. \n" + + "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n" + + " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n" + + " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n" + + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped." + ) + + +## 2.7 QwenImage-Edit/auto blocks & presets + +EDIT_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditVLEncoderStep()), + ("vae_encoder", QwenImageEditAutoVaeEncoderStep()), + ("input", QwenImageEditAutoInputStep()), + ("before_denoise", QwenImageEditAutoBeforeDenoiseStep()), + ("denoise", QwenImageEditAutoDenoiseStep()), + ("decode", QwenImageAutoDecodeStep()), + ] +) + + +class QwenImageEditAutoBlocks(SequentialPipelineBlocks): + model_name = "qwenimage-edit" + block_classes = EDIT_AUTO_BLOCKS.values() + block_names = EDIT_AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n" + + "- for edit (img2img) generation, you need to provide `image`\n" + + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + ) + + +# 3. 
all block presets supported in QwenImage & QwenImage-Edit + + +ALL_BLOCKS = { + "text2image": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "edit": EDIT_BLOCKS, + "edit_inpaint": EDIT_INPAINT_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "auto": AUTO_BLOCKS, + "edit_auto": EDIT_AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py new file mode 100644 index 0000000000..fe9757f41b --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py @@ -0,0 +1,202 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import QwenImageLoraLoaderMixin +from ..modular_pipeline import ModularPipeline + + +class QwenImagePachifier(ConfigMixin): + """ + A class to pack and unpack latents for QwenImage. + """ + + config_name = "config.json" + + @register_to_config + def __init__( + self, + patch_size: int = 2, + ): + super().__init__() + + def pack_latents(self, latents): + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}") + + if latents.ndim == 4: + latents = latents.unsqueeze(2) + + batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape + patch_size = self.config.patch_size + + if latent_height % patch_size != 0 or latent_width % patch_size != 0: + raise ValueError( + f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}" + ) + + latents = latents.view( + batch_size, + num_channels_latents, + latent_height // patch_size, + patch_size, + latent_width // patch_size, + patch_size, + ) + latents = latents.permute( + 0, 2, 4, 1, 3, 5 + ) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size + latents = latents.reshape( + batch_size, + (latent_height // patch_size) * (latent_width // patch_size), + num_channels_latents * patch_size * patch_size, + ) + + return latents + + def unpack_latents(self, latents, height, width, vae_scale_factor=8): + if latents.ndim != 3: + raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}") + + batch_size, num_patches, channels = latents.shape + patch_size = self.config.patch_size + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
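+        # For example, with height=1024, width=1024, vae_scale_factor=8 and patch_size=2,
+        # the two lines below give height = width = 2 * (1024 // 16) = 128 latent pixels.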
+ height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + + latents = latents.view( + batch_size, + height // patch_size, + width // patch_size, + channels // (patch_size * patch_size), + patch_size, + patch_size, + ) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width) + + return latents + + +class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin): + """ + A ModularPipeline for QwenImage. + + + + This is an experimental feature and is likely to change in the future. + + + """ + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def is_guidance_distilled(self): + is_guidance_distilled = False + if hasattr(self, "transformer") and self.transformer is not None: + is_guidance_distilled = self.transformer.config.guidance_embeds + return is_guidance_distilled + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + +class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin): + """ + A ModularPipeline for QwenImage-Edit. + + + + This is an experimental feature and is likely to change in the future. + + + """ + + # YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step. 
+ @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def is_guidance_distilled(self): + is_guidance_distilled = False + if hasattr(self, "transformer") and self.transformer is not None: + is_guidance_distilled = self.transformer.config.guidance_embeds + return is_guidance_distilled + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py index 0ee37f5201..e84f5cad1a 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -76,6 +76,7 @@ class StableDiffusionXLModularPipeline( vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) return vae_scale_factor + # YiYi TODO: change to num_channels_latents @property def num_channels_unet(self): num_channels_unet = 4 diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index ebabf17995..880984eeb8 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -91,6 +91,14 @@ from .pag import ( StableDiffusionXLPAGPipeline, ) from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline +from .qwenimage import ( + QwenImageControlNetPipeline, + QwenImageEditInpaintPipeline, + QwenImageEditPipeline, + QwenImageImg2ImgPipeline, + QwenImageInpaintPipeline, + QwenImagePipeline, +) from .sana import SanaPipeline from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline from .stable_diffusion import ( @@ -150,6 +158,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), ("cogview4-control", CogView4ControlPipeline), + ("qwenimage", QwenImagePipeline), + ("qwenimage-controlnet", QwenImageControlNetPipeline), ] ) @@ -174,6 +184,8 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( ("flux-controlnet", FluxControlNetImg2ImgPipeline), ("flux-control", FluxControlImg2ImgPipeline), ("flux-kontext", FluxKontextPipeline), + ("qwenimage", QwenImageImg2ImgPipeline), + ("qwenimage-edit", QwenImageEditPipeline), ] ) @@ -195,6 +207,8 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( ("flux-controlnet", FluxControlNetInpaintPipeline), ("flux-control", FluxControlInpaintPipeline), ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline), + ("qwenimage", QwenImageInpaintPipeline), + ("qwenimage-edit", QwenImageEditInpaintPipeline), ] ) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py 
b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 91eefc5c10..cd4d965e57 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -32,6 +32,66 @@ class FluxModularPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class QwenImageAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class QwenImageEditAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class QwenImageEditModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class QwenImageModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] From 4e36bb0d23a0450079560ac12d2858e2eb3f7e24 Mon Sep 17 00:00:00 2001 From: "Frank (Haofan) Wang" Date: Tue, 9 Sep 2025 08:59:26 +0800 Subject: [PATCH 2/5] Support ControlNet-Inpainting for Qwen-Image (#12301) * add qwen-image-cn-inpaint --------- Co-authored-by: github-actions[bot] Co-authored-by: yiyixuxu --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/qwenimage/__init__.py | 2 + .../pipeline_qwenimage_controlnet_inpaint.py | 941 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + 5 files changed, 962 insertions(+) create mode 100644 src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4c06440172..d96acc3818 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -510,6 +510,7 @@ else: "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", + "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", "QwenImageEditInpaintPipeline", "QwenImageEditPipeline", @@ -1163,6 +1164,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, + QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, QwenImageEditPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 
25d5d213cf..8ed07a72e3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -394,6 +394,7 @@ else: "QwenImageInpaintPipeline", "QwenImageEditPipeline", "QwenImageEditInpaintPipeline", + "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", ] try: @@ -714,6 +715,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .qwenimage import ( + QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, QwenImageEditPipeline, diff --git a/src/diffusers/pipelines/qwenimage/__init__.py b/src/diffusers/pipelines/qwenimage/__init__.py index ae5cf04dc5..36d92917fd 100644 --- a/src/diffusers/pipelines/qwenimage/__init__.py +++ b/src/diffusers/pipelines/qwenimage/__init__.py @@ -25,6 +25,7 @@ else: _import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"] _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"] _import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"] + _import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"] _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"] _import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"] _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"] @@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: else: from .pipeline_qwenimage import QwenImagePipeline from .pipeline_qwenimage_controlnet import QwenImageControlNetPipeline + from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline from .pipeline_qwenimage_edit import QwenImageEditPipeline from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py new file mode 100644 index 0000000000..102a813ab5 --- /dev/null +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -0,0 +1,941 @@ +# Copyright 2025 Qwen-Image Team, The InstantX Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.utils import load_image + >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline + + >>> base_model_path = "Qwen/Qwen-Image" + >>> controlnet_model_path = "InstantX/Qwen-Image-ControlNet-Inpainting" + >>> controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16) + >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained( + ... base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ).to("cuda") + >>> image = load_image( + ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png" + ... ) + >>> mask_image = load_image( + ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png" + ... ) + >>> prompt = "一辆绿色的出租车行驶在路上" + >>> result = pipe( + ... prompt=prompt, + ... control_image=image, + ... control_mask=mask_image, + ... controlnet_conditioning_scale=1.0, + ... width=mask_image.size[0], + ... height=mask_image.size[1], + ... true_cfg_scale=4.0, + ... 
).images[0]
+        >>> result.save("qwenimage_controlnet_inpaint.png")
+        ```
+"""
+
+
+# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. 
Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+    r"""
+    The QwenImage ControlNet pipeline for text-guided image inpainting.
+
+    Args:
+        transformer ([`QwenImageTransformer2DModel`]):
+            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+        tokenizer (`QwenTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+    """
+
+    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+    def __init__(
+        self,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        vae: AutoencoderKLQwenImage,
+        text_encoder: Qwen2_5_VLForConditionalGeneration,
+        tokenizer: Qwen2Tokenizer,
+        transformer: QwenImageTransformer2DModel,
+        controlnet: QwenImageControlNetModel,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            scheduler=scheduler,
+            controlnet=controlnet,
+        )
+        self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+        # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + do_resize=True, + do_convert_grayscale=True, + do_normalize=False, + do_binarize=True, + ) + + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_image_with_mask( + self, + image, + mask, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) # (bsz, 3, height_ori, width_ori) + + # Prepare mask + if isinstance(mask, torch.Tensor): + pass + else: + mask = self.mask_processor.preprocess(mask, height=height, width=width) + mask = mask.repeat_interleave(repeat_by, dim=0) + mask = mask.to(device=device, dtype=dtype) # (bsz, 1, height_ori, width_ori) + + if image.ndim == 4: + image = image.unsqueeze(2) + + if mask.ndim == 4: + mask = mask.unsqueeze(2) + + # Get masked image + masked_image = image.clone() + masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1 # (bsz, 3, 1, height_ori, width_ori) + + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device + ) + + # Encode to latents + image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample() + image_latents = (image_latents - latents_mean) * latents_std + image_latents = image_latents.to(dtype) # torch.Size([1, 16, 1, height_ori//8, width_ori//8]) + + mask = torch.nn.functional.interpolate( + mask, size=(image_latents.shape[-3], image_latents.shape[-2], image_latents.shape[-1]) + ) + mask = 1 - mask # torch.Size([1, 1, 1, height_ori//8, width_ori//8]) + + control_image = torch.cat( + [image_latents, mask], dim=1 + ) # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8]) + + control_image = control_image.permute(0, 2, 1, 3, 4) # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8]) + + # pack + control_image = self._pack_latents( + control_image, + batch_size=control_image.shape[0], + num_channels_latents=control_image.shape[2], + height=control_image.shape[3], + width=control_image.shape[4], + ) + + if do_classifier_free_guidance and not guess_mode: + control_image = torch.cat([control_image] * 
2)
+
+        return control_image
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        true_cfg_scale: float = 4.0,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 1.0,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
+        control_image: PipelineImageInput = None,
+        control_mask: PipelineImageInput = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        num_images_per_prompt: int = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_embeds_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+                not greater than `1`).
+            true_cfg_scale (`float`, *optional*, defaults to 4.0):
+                When > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
+            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 1.0):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
+                the text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+            [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is a list with the generated images.
+ """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 3. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, QwenImageControlNetModel): + control_image = self.prepare_image_with_mask( + image=control_image, + mask=control_mask, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet + controlnet_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image.to(dtype=latents.dtype, device=device), + conditioning_scale=cond_scale, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), + return_dict=False, + ) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), + controlnet_block_samples=controlnet_block_samples, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), + controlnet_block_samples=controlnet_block_samples, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred 
* (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index cd4d965e57..00792fa55a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1817,6 +1817,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class QwenImageControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImageControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From c222570a9b47901266fecf34222f540870c3bb1b Mon Sep 17 00:00:00 2001 From: Leo Jiang Date: Tue, 9 Sep 2025 15:28:08 +0800 Subject: [PATCH 3/5] DeepSpeed adaption for flux-kontext (#12240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: J石页 Co-authored-by: Sayak Paul --- .../train_dreambooth_lora_flux_kontext.py | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 87e0d2c29e..03c05a05e0 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -29,8 +29,9 @@ from pathlib import Path import numpy 
as np import torch import transformers -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger +from accelerate.state import AcceleratorState from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib @@ -1222,6 +1223,9 @@ def main(args): kwargs_handlers=[kwargs], ) + if accelerator.distributed_type == DistributedType.DEEPSPEED: + AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + # Disable AMP for MPS. if torch.backends.mps.is_available(): accelerator.native_amp = False @@ -1438,17 +1442,20 @@ def main(args): text_encoder_one_lora_layers_to_save = None modules_to_save = {} for model in models: - if isinstance(model, type(unwrap_model(transformer))): + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) transformer_lora_layers_to_save = get_peft_model_state_dict(model) modules_to_save["transformer"] = model - elif isinstance(model, type(unwrap_model(text_encoder_one))): + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): + model = unwrap_model(model) text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) modules_to_save["text_encoder"] = model else: raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() FluxKontextPipeline.save_lora_weights( output_dir, @@ -1461,15 +1468,25 @@ def main(args): transformer_ = None text_encoder_one_ = None - while len(models) > 0: - model = models.pop() + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))): + text_encoder_one_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + else: + transformer_ = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + text_encoder_one_ = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder" + ) lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir) @@ -2069,7 +2086,7 @@ def main(args): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: From 28106fcac4fd13e7ced5c9eb6803f107e804a08f Mon Sep 17 00:00:00 2001 From: calcuis <113646141+calcuis@users.noreply.github.com> Date: Tue, 9 Sep 2025 04:40:21 -0700 Subject: [PATCH 4/5] gguf new quant type support (with demo) (#12076) * Update utils.py not perfect but works engine: 
https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py inference example(s): https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/k6.py https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/k5.py gguf file sample(s): https://huggingface.co/calcuis/kontext-gguf/tree/main https://huggingface.co/calcuis/krea-gguf/tree/main * Apply style fixes --------- Co-authored-by: github-actions[bot] --- src/diffusers/quantizers/gguf/utils.py | 56 ++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 3dd00b2ce3..2fba9986e8 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -429,8 +429,64 @@ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None): return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32) +# this part from calcuis (gguf.org) +# more info: https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py + + +def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None): + kvalues = torch.tensor( + [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], + dtype=torch.float32, + device=blocks.device, + ) + n_blocks = blocks.shape[0] + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=blocks.device, dtype=torch.uint8 + ).reshape((1, 1, 2, 1)) + qs = (qs & 15).reshape((n_blocks, -1)).to(torch.int64) + kvalues = kvalues.view(1, 1, 16) + qs = qs.unsqueeze(-1) + qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], 16), 2, qs) + qs = qs.squeeze(-1).to(dtype) + return d * qs + + +def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None): + kvalues = torch.tensor( + [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], + dtype=torch.float32, + device=blocks.device, + ) + n_blocks = blocks.shape[0] + d, scales_h, scales_l, qs = split_block_dims(blocks, 2, 2, QK_K // 64) + d = d.view(torch.float16).to(dtype) + scales_h = scales_h.view(torch.int16) + scales_l = scales_l.reshape((n_blocks, -1, 1)) >> torch.tensor( + [0, 4], device=blocks.device, dtype=torch.uint8 + ).reshape((1, 1, 2)) + scales_h = scales_h.reshape((n_blocks, 1, -1)) >> torch.tensor( + [2 * i for i in range(QK_K // 32)], device=blocks.device, dtype=torch.uint8 + ).reshape((1, -1, 1)) + scales_l = scales_l.reshape((n_blocks, -1)) & 0x0F + scales_h = scales_h.reshape((n_blocks, -1)) & 0x03 + scales = (scales_l | (scales_h << 4)) - 32 + dl = (d * scales.to(dtype)).reshape((n_blocks, -1, 1)) + shifts_q = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape(1, 1, 2, 1) + qs = qs.reshape((n_blocks, -1, 1, 16)) >> shifts_q + qs = (qs & 15).reshape((n_blocks, -1, 32)).to(torch.int64) + kvalues = kvalues.view(1, 1, 1, 16) + qs = qs.unsqueeze(-1) + qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], qs.shape[2], 16), 3, qs) + qs = qs.squeeze(-1).to(dtype) + return (dl * qs).reshape(n_blocks, -1) + + GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES dequantize_functions = { + gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL, + gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS, gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1, From 
4067d6c4b64f2b606f9806d4a8b15d5fd5cbea1e Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Wed, 10 Sep 2025 05:36:03 +0800 Subject: [PATCH 5/5] adjust criteria for marigold-intrinsics example on XPU (#12290) adjust criteria for XPU Signed-off-by: Liu, Kaixuan Co-authored-by: Aryan --- .../marigold/test_marigold_intrinsics.py | 67 ++++++++++++++++++- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/marigold/test_marigold_intrinsics.py b/tests/pipelines/marigold/test_marigold_intrinsics.py index 3f7ab9bf6e..7db14b67ce 100644 --- a/tests/pipelines/marigold/test_marigold_intrinsics.py +++ b/tests/pipelines/marigold/test_marigold_intrinsics.py @@ -34,6 +34,7 @@ from diffusers import ( ) from ...testing_utils import ( + Expectations, backend_empty_cache, enable_full_determinism, floats_tensor, @@ -416,7 +417,7 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase): expected_slice: np.ndarray = None, model_id: str = "prs-eth/marigold-iid-appearance-v1-1", image_url: str = "https://marigoldmonodepth.github.io/images/einstein.jpg", - atol: float = 1e-4, + atol: float = 1e-3, **pipe_kwargs, ): from_pretrained_kwargs = {} @@ -531,11 +532,41 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase): ) def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self): + expected_slices = Expectations( + { + ("xpu", 3): np.array( + [ + 0.62655, + 0.62477, + 0.62161, + 0.62452, + 0.62454, + 0.62454, + 0.62255, + 0.62647, + 0.63379, + ] + ), + ("cuda", 7): np.array( + [ + 0.61572, + 0.1377, + 0.61182, + 0.61426, + 0.61377, + 0.61426, + 0.61279, + 0.61572, + 0.62354, + ] + ), + } + ) self._test_marigold_intrinsics( is_fp16=True, device=torch_device, generator_seed=0, - expected_slice=np.array([0.61572, 0.61377, 0.61182, 0.61426, 0.61377, 0.61426, 0.61279, 0.61572, 0.62354]), + expected_slice=expected_slices.get_expectation(), num_inference_steps=1, processing_resolution=768, ensemble_size=3, @@ -545,11 +576,41 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase): ) def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self): + expected_slices = Expectations( + { + ("xpu", 3): np.array( + [ + 0.62988, + 0.62792, + 0.62548, + 0.62841, + 0.62792, + 0.62792, + 0.62646, + 0.62939, + 0.63721, + ] + ), + ("cuda", 7): np.array( + [ + 0.61914, + 0.6167, + 0.61475, + 0.61719, + 0.61719, + 0.61768, + 0.61572, + 0.61914, + 0.62695, + ] + ), + } + ) self._test_marigold_intrinsics( is_fp16=True, device=torch_device, generator_seed=0, - expected_slice=np.array([0.61914, 0.6167, 0.61475, 0.61719, 0.61719, 0.61768, 0.61572, 0.61914, 0.62695]), + expected_slice=expected_slices.get_expectation(), num_inference_steps=1, processing_resolution=768, ensemble_size=4,