Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
@@ -20,6 +20,12 @@ All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or Nu

[[autodoc]] image_processor.VaeImageProcessor

## InpaintProcessor

The [`InpaintProcessor`] accepts `mask` and `image` inputs and processes them together. Optionally, it can accept `padding_mask_crop` and apply a mask overlay.

[[autodoc]] image_processor.InpaintProcessor
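
A minimal usage sketch based on the `preprocess`/`postprocess` signatures introduced in this PR; the file names and the stand-in tensor below are placeholders, not part of the API:

```py
import torch
from PIL import Image
from diffusers.image_processor import InpaintProcessor

processor = InpaintProcessor(vae_scale_factor=8)

image = Image.open("input.png").convert("RGB")  # placeholder image
mask = Image.open("mask.png").convert("L")      # placeholder mask

# Preprocess image and mask together; padding_mask_crop crops a padded region around the mask.
processed_image, processed_mask, postprocess_kwargs = processor.preprocess(
    image, mask=mask, height=512, width=512, padding_mask_crop=32
)

# ... run an inpainting pipeline on processed_image / processed_mask ...
denoised = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for the decoded pipeline output

# Postprocess and paste the result back onto the original image via the mask overlay.
out = processor.postprocess(denoised, output_type="pil", **postprocess_kwargs)
```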

## VaeImageProcessorLDM3D

The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.

@@ -385,6 +385,10 @@ else:
        [
            "FluxAutoBlocks",
            "FluxModularPipeline",
            "QwenImageAutoBlocks",
            "QwenImageEditAutoBlocks",
            "QwenImageEditModularPipeline",
            "QwenImageModularPipeline",
            "StableDiffusionXLAutoBlocks",
            "StableDiffusionXLModularPipeline",
            "WanAutoBlocks",
@@ -1038,6 +1042,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .modular_pipelines import (
        FluxAutoBlocks,
        FluxModularPipeline,
        QwenImageAutoBlocks,
        QwenImageEditAutoBlocks,
        QwenImageEditModularPipeline,
        QwenImageModularPipeline,
        StableDiffusionXLAutoBlocks,
        StableDiffusionXLModularPipeline,
        WanAutoBlocks,

@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
    from ..models.transformers.transformer_flux import FluxAttnProcessor
    from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
    from ..models.transformers.transformer_wan import WanAttnProcessor2_0

    # AttnProcessor2_0
@@ -140,6 +141,14 @@ def _register_attention_processors_metadata():
        metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
    )

    # QwenDoubleStreamAttnProcessor2_0
    AttentionProcessorRegistry.register(
        model_class=QwenDoubleStreamAttnProcessor2_0,
        metadata=AttentionProcessorMetadata(
            skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
        ),
    )


def _register_transformer_blocks_metadata():
    from ..models.attention import BasicTransformerBlock
@@ -298,4 +307,5 @@ _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___h
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
# fmt: on

@@ -523,6 +523,7 @@ class VaeImageProcessor(ConfigMixin):
                size=(height, width),
            )
            image = self.pt_to_numpy(image)

        return image

    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
@@ -838,6 +839,137 @@ class VaeImageProcessor(ConfigMixin):
        return image


class InpaintProcessor(ConfigMixin):
    """
    Image processor for inpainting image and mask.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        vae_latent_channels: int = 4,
        resample: str = "lanczos",
        reducing_gap: int = None,
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
        mask_do_normalize: bool = False,
        mask_do_binarize: bool = True,
        mask_do_convert_grayscale: bool = True,
    ):
        super().__init__()

        self._image_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )
        self._mask_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=mask_do_normalize,
            do_binarize=mask_do_binarize,
            do_convert_grayscale=mask_do_convert_grayscale,
        )

    def preprocess(
        self,
        image: PIL.Image.Image,
        mask: PIL.Image.Image = None,
        height: int = None,
        width: int = None,
        padding_mask_crop: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Preprocess the image and mask.
        """
        if mask is None and padding_mask_crop is not None:
            raise ValueError("mask must be provided if padding_mask_crop is provided")

        # if mask is None, same behavior as regular image processor
        if mask is None:
            return self._image_processor.preprocess(image, height=height, width=width)

        if padding_mask_crop is not None:
            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        processed_image = self._image_processor.preprocess(
            image,
            height=height,
            width=width,
            crops_coords=crops_coords,
            resize_mode=resize_mode,
        )

        processed_mask = self._mask_processor.preprocess(
            mask,
            height=height,
            width=width,
            resize_mode=resize_mode,
            crops_coords=crops_coords,
        )

        if crops_coords is not None:
            postprocessing_kwargs = {
                "crops_coords": crops_coords,
                "original_image": image,
                "original_mask": mask,
            }
        else:
            postprocessing_kwargs = {
                "crops_coords": None,
                "original_image": None,
                "original_mask": None,
            }

        return processed_image, processed_mask, postprocessing_kwargs

    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        original_image: Optional[PIL.Image.Image] = None,
        original_mask: Optional[PIL.Image.Image] = None,
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
        """
        Postprocess the image, optionally applying the mask overlay.
        """
        image = self._image_processor.postprocess(
            image,
            output_type=output_type,
        )
        # optionally apply the mask overlay
        if crops_coords is not None and (original_image is None or original_mask is None):
            raise ValueError("original_image and original_mask must be provided if crops_coords is provided")

        elif crops_coords is not None and output_type != "pil":
            raise ValueError("output_type must be 'pil' if crops_coords is provided")

        elif crops_coords is not None:
            image = [
                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
            ]

        return image


class VaeImageProcessorLDM3D(VaeImageProcessor):
    """
    Image processor for VAE LDM3D.

@@ -47,6 +47,12 @@ else:
    _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
    _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
    _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
    _import_structure["qwenimage"] = [
        "QwenImageAutoBlocks",
        "QwenImageModularPipeline",
        "QwenImageEditModularPipeline",
        "QwenImageEditAutoBlocks",
    ]
    _import_structure["components_manager"] = ["ComponentsManager"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -68,6 +74,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        SequentialPipelineBlocks,
    )
    from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
    from .qwenimage import (
        QwenImageAutoBlocks,
        QwenImageEditAutoBlocks,
        QwenImageEditModularPipeline,
        QwenImageModularPipeline,
    )
    from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
    from .wan import WanAutoBlocks, WanModularPipeline
else:

@@ -56,6 +56,8 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
        ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
        ("wan", "WanModularPipeline"),
        ("flux", "FluxModularPipeline"),
        ("qwenimage", "QwenImageModularPipeline"),
        ("qwenimage-edit", "QwenImageEditModularPipeline"),
    ]
)

@@ -64,6 +66,8 @@ MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
        ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
        ("WanModularPipeline", "WanAutoBlocks"),
        ("FluxModularPipeline", "FluxAutoBlocks"),
        ("QwenImageModularPipeline", "QwenImageAutoBlocks"),
        ("QwenImageEditModularPipeline", "QwenImageEditAutoBlocks"),
    ]
)

@@ -133,8 +137,8 @@ class PipelineState:
        Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
        intermediates dict.
        """
        if name in self.intermediates:
            return self.intermediates[name]
        if name in self.values:
            return self.values[name]
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __repr__(self):
@@ -548,8 +552,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):

    def __init__(self):
        sub_blocks = InsertableDict()
        for block_name, block_cls in zip(self.block_names, self.block_classes):
            sub_blocks[block_name] = block_cls()
        for block_name, block in zip(self.block_names, self.block_classes):
            if inspect.isclass(block):
                sub_blocks[block_name] = block()
            else:
                sub_blocks[block_name] = block
        self.sub_blocks = sub_blocks
        if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
            raise ValueError(
@@ -830,7 +837,9 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
        return expected_configs

    @classmethod
    def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks":
    def from_blocks_dict(
        cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
    ) -> "SequentialPipelineBlocks":
        """Creates a SequentialPipelineBlocks instance from a dictionary of blocks.

        Args:
@@ -852,12 +861,19 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
        instance.block_classes = [block.__class__ for block in sub_blocks.values()]
        instance.block_names = list(sub_blocks.keys())
        instance.sub_blocks = sub_blocks

        if description is not None:
            instance.description = description

        return instance

    def __init__(self):
        sub_blocks = InsertableDict()
        for block_name, block_cls in zip(self.block_names, self.block_classes):
            sub_blocks[block_name] = block_cls()
        for block_name, block in zip(self.block_names, self.block_classes):
            if inspect.isclass(block):
                sub_blocks[block_name] = block()
            else:
                sub_blocks[block_name] = block
        self.sub_blocks = sub_blocks

    def _get_inputs(self):
@@ -1280,8 +1296,11 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):

    def __init__(self):
        sub_blocks = InsertableDict()
        for block_name, block_cls in zip(self.block_names, self.block_classes):
            sub_blocks[block_name] = block_cls()
        for block_name, block in zip(self.block_names, self.block_classes):
            if inspect.isclass(block):
                sub_blocks[block_name] = block()
            else:
                sub_blocks[block_name] = block
        self.sub_blocks = sub_blocks

    @classmethod

src/diffusers/modular_pipelines/qwenimage/__init__.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["encoders"] = ["QwenImageTextEncoderStep"]
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "AUTO_BLOCKS",
        "CONTROLNET_BLOCKS",
        "EDIT_AUTO_BLOCKS",
        "EDIT_BLOCKS",
        "EDIT_INPAINT_BLOCKS",
        "IMAGE2IMAGE_BLOCKS",
        "INPAINT_BLOCKS",
        "TEXT2IMAGE_BLOCKS",
        "QwenImageAutoBlocks",
        "QwenImageEditAutoBlocks",
    ]
    _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .encoders import (
            QwenImageTextEncoderStep,
        )
        from .modular_blocks import (
            ALL_BLOCKS,
            AUTO_BLOCKS,
            CONTROLNET_BLOCKS,
            EDIT_AUTO_BLOCKS,
            EDIT_BLOCKS,
            EDIT_INPAINT_BLOCKS,
            IMAGE2IMAGE_BLOCKS,
            INPAINT_BLOCKS,
            TEXT2IMAGE_BLOCKS,
            QwenImageAutoBlocks,
            QwenImageEditAutoBlocks,
        )
        from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

src/diffusers/modular_pipelines/qwenimage/before_denoise.py (new file, 727 lines)
@@ -0,0 +1,727 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
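
A quick numeric check of the linear interpolation above, using only the default parameters from the signature (this helper is self-contained, so it can be exercised directly):

```python
# Standalone check of calculate_shift with its default parameters (same math as above).
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

print(calculate_shift(256))             # 0.5  -> base_seq_len maps to base_shift
print(calculate_shift(4096))            # 1.15 -> max_seq_len maps to max_shift
print(round(calculate_shift(1024), 3))  # ~0.63 -> intermediate sequence lengths interpolate linearly
```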
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(scheduler, num_inference_steps, strength):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
if hasattr(scheduler, "set_begin_index"):
|
||||
scheduler.set_begin_index(t_start * scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
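
A minimal sketch of the `strength` arithmetic in `get_timesteps` above: with a scheduler order of 1, a strength of 0.9 skips the first 10% of the schedule and denoises over the remaining steps (numbers below are illustrative only):

```python
# Illustrative only: how strength trims the schedule for image-to-image / inpainting.
num_inference_steps, strength = 50, 0.9
init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 45.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 5
print(t_start, num_inference_steps - t_start)  # 5 45 -> denoise only the last 45 of 50 steps
```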
|
||||
|
||||
|
||||
# Prepare Latents steps
|
||||
|
||||
|
||||
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare initial random noise for the generation process"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="generator"),
|
||||
InputParam(
|
||||
name="batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs, can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
vae_scale_factor=components.vae_scale_factor,
|
||||
)
|
||||
|
||||
device = components._execution_device
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
|
||||
# we can update the height and width here since they're used to generate the initial latents
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.latents = randn_tensor(
|
||||
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
|
||||
)
|
||||
block_state.latents = components.pachifier.pack_latents(block_state.latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
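
A rough shape check for the latent preparation above, assuming the usual QwenImage values (`vae_scale_factor=8`) and the standard 2x2 patch packing performed by the pachifier; the 16 latent channels are an assumption for illustration, not something this block defines:

```python
# Hypothetical numbers: 1024x1024 target, vae_scale_factor=8, 16 latent channels.
batch_size, num_channels_latents, vae_scale_factor = 1, 16, 8
height = width = 1024

latent_height = 2 * (int(height) // (vae_scale_factor * 2))  # 128
latent_width = 2 * (int(width) // (vae_scale_factor * 2))    # 128

shape = (batch_size, num_channels_latents, 1, latent_height, latent_width)
# The pachifier folds 2x2 latent patches into the channel dim, so the packed sequence is
# (batch_size, (latent_height // 2) * (latent_width // 2), num_channels_latents * 4) = (1, 4096, 64);
# that image_seq_len of 4096 is the max_image_seq_len referenced by calculate_shift above.
print(shape)  # (1, 16, 1, 128, 128)
```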
|
||||
|
||||
|
||||
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial random noised, can be generated in prepare latent step.",
|
||||
),
|
||||
InputParam(
|
||||
name="image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="initial_noise",
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial random noised used for inpainting denoising.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(image_latents, latents):
|
||||
if image_latents.shape[0] != latents.shape[0]:
|
||||
raise ValueError(
|
||||
f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
|
||||
)
|
||||
|
||||
if image_latents.ndim != 3:
|
||||
raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
image_latents=block_state.image_latents,
|
||||
latents=block_state.latents,
|
||||
)
|
||||
|
||||
# prepare latent timestep
|
||||
latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
|
||||
|
||||
# make copy of initial_noise
|
||||
block_state.initial_noise = block_state.latents
|
||||
|
||||
# scale noise
|
||||
block_state.latents = components.scheduler.scale_noise(
|
||||
block_state.image_latents, latent_timestep, block_state.latents
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
name="processed_mask_image",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The processed mask to use for the inpainting process.",
|
||||
),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="dtype", required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
|
||||
height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
|
||||
width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
|
||||
block_state.mask = torch.nn.functional.interpolate(
|
||||
block_state.processed_mask_image,
|
||||
size=(height_latents, width_latents),
|
||||
)
|
||||
|
||||
block_state.mask = block_state.mask.unsqueeze(2)
|
||||
block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1)
|
||||
block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype)
|
||||
|
||||
block_state.mask = components.pachifier.pack_latents(block_state.mask)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
# Set Timesteps steps
|
||||
|
||||
|
||||
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="num_inference_steps", default=50),
|
||||
InputParam(name="sigmas"),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process, used to calculate the image sequence length.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
sigmas = (
|
||||
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
|
||||
if block_state.sigmas is None
|
||||
else block_state.sigmas
|
||||
)
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len=block_state.latents.shape[1],
|
||||
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
|
||||
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
|
||||
base_shift=components.scheduler.config.get("base_shift", 0.5),
|
||||
max_shift=components.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
scheduler=components.scheduler,
|
||||
num_inference_steps=block_state.num_inference_steps,
|
||||
device=device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
components.scheduler.set_begin_index(0)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="num_inference_steps", default=50),
|
||||
InputParam(name="sigmas"),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process, used to calculate the image sequence length.",
|
||||
),
|
||||
InputParam(name="strength", default=0.9),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="timesteps",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
sigmas = (
|
||||
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
|
||||
if block_state.sigmas is None
|
||||
else block_state.sigmas
|
||||
)
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len=block_state.latents.shape[1],
|
||||
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
|
||||
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
|
||||
base_shift=components.scheduler.config.get("base_shift", 0.5),
|
||||
max_shift=components.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
scheduler=components.scheduler,
|
||||
num_inference_steps=block_state.num_inference_steps,
|
||||
device=device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
block_state.timesteps, block_state.num_inference_steps = get_timesteps(
|
||||
scheduler=components.scheduler,
|
||||
num_inference_steps=block_state.num_inference_steps,
|
||||
strength=block_state.strength,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
# other inputs for denoiser
|
||||
|
||||
## RoPE inputs for denoiser
|
||||
|
||||
|
||||
class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(
|
||||
1,
|
||||
block_state.height // components.vae_scale_factor // 2,
|
||||
block_state.width // components.vae_scale_factor // 2,
|
||||
)
|
||||
]
|
||||
* block_state.batch_size
|
||||
]
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(
|
||||
name="resized_image", required=True, type_hint=torch.Tensor, description="The resized image input"
|
||||
),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# for edit, image size can be different from the target size (height/width)
|
||||
image = (
|
||||
block_state.resized_image[0] if isinstance(block_state.resized_image, list) else block_state.resized_image
|
||||
)
|
||||
image_width, image_height = image.size
|
||||
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(
|
||||
1,
|
||||
block_state.height // components.vae_scale_factor // 2,
|
||||
block_state.width // components.vae_scale_factor // 2,
|
||||
),
|
||||
(1, image_height // components.vae_scale_factor // 2, image_width // components.vae_scale_factor // 2),
|
||||
]
|
||||
] * block_state.batch_size
|
||||
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
## ControlNet inputs for denoiser
|
||||
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("controlnet", QwenImageControlNetModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("control_guidance_start", default=0.0),
|
||||
InputParam("control_guidance_end", default=1.0),
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("control_image_latents", required=True),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
controlnet = unwrap_module(components.controlnet)
|
||||
|
||||
# control_guidance_start/control_guidance_end (align format)
|
||||
if not isinstance(block_state.control_guidance_start, list) and isinstance(
|
||||
block_state.control_guidance_end, list
|
||||
):
|
||||
block_state.control_guidance_start = len(block_state.control_guidance_end) * [
|
||||
block_state.control_guidance_start
|
||||
]
|
||||
elif not isinstance(block_state.control_guidance_end, list) and isinstance(
|
||||
block_state.control_guidance_start, list
|
||||
):
|
||||
block_state.control_guidance_end = len(block_state.control_guidance_start) * [
|
||||
block_state.control_guidance_end
|
||||
]
|
||||
elif not isinstance(block_state.control_guidance_start, list) and not isinstance(
|
||||
block_state.control_guidance_end, list
|
||||
):
|
||||
mult = (
|
||||
len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1
|
||||
)
|
||||
block_state.control_guidance_start, block_state.control_guidance_end = (
|
||||
mult * [block_state.control_guidance_start],
|
||||
mult * [block_state.control_guidance_end],
|
||||
)
|
||||
|
||||
# controlnet_conditioning_scale (align format)
|
||||
if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance(
|
||||
block_state.controlnet_conditioning_scale, float
|
||||
):
|
||||
block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult
|
||||
|
||||
# controlnet_keep
|
||||
block_state.controlnet_keep = []
|
||||
for i in range(len(block_state.timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
|
||||
for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
|
||||
]
|
||||
block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
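
A small sketch of the `controlnet_keep` schedule computed above for a single controlnet; with `control_guidance_start=0.0` and `control_guidance_end=0.5` the controlnet is only applied during the first half of the timesteps (the values below are made up for illustration):

```python
# Illustration of the keep-mask logic above for one controlnet over 10 timesteps.
timesteps = list(range(10))  # stand-in for the real timestep tensor
control_guidance_start, control_guidance_end = [0.0], [0.5]

controlnet_keep = []
for i in range(len(timesteps)):
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0])  # single controlnet -> scalar per step

print(controlnet_keep)  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```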
|
||||
src/diffusers/modular_pipelines/qwenimage/decoders.py (new file, 203 lines)
@@ -0,0 +1,203 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import InpaintProcessor, VaeImageProcessor
|
||||
from ...models import AutoencoderKLQwenImage
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the latents to images"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(
|
||||
name="latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to decode, can be generated in the denoise step",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents, block_state.height, block_state.width
|
||||
)
|
||||
block_state.latents = block_state.latents.to(components.vae.dtype)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean)
|
||||
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||
.to(block_state.latents.device, block_state.latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||
1, components.vae.config.z_dim, 1, 1, 1
|
||||
).to(block_state.latents.device, block_state.latents.dtype)
|
||||
block_state.latents = block_state.latents / latents_std + latents_mean
|
||||
block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0]
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "postprocess the generated image"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("images", required=True, description="the generated image from decoders step"),
|
||||
InputParam(
|
||||
name="output_type",
|
||||
default="pil",
|
||||
type_hint=str,
|
||||
description="The type of the output images, can be 'pil', 'np', 'pt'",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(output_type):
|
||||
if output_type not in ["pil", "np", "pt"]:
|
||||
raise ValueError(f"Invalid output_type: {output_type}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.output_type)
|
||||
|
||||
block_state.images = components.image_processor.postprocess(
|
||||
image=block_state.images,
|
||||
output_type=block_state.output_type,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "postprocess the generated image, optional apply the mask overally to the original image.."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_mask_processor",
|
||||
InpaintProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("images", required=True, description="the generated image from decoders step"),
|
||||
InputParam(
|
||||
name="output_type",
|
||||
default="pil",
|
||||
type_hint=str,
|
||||
description="The type of the output images, can be 'pil', 'np', 'pt'",
|
||||
),
|
||||
InputParam("mask_overlay_kwargs"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(output_type, mask_overlay_kwargs):
|
||||
if output_type not in ["pil", "np", "pt"]:
|
||||
raise ValueError(f"Invalid output_type: {output_type}")
|
||||
|
||||
if mask_overlay_kwargs and output_type != "pil":
|
||||
raise ValueError("only support output_type 'pil' for mask overlay")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs)
|
||||
|
||||
if block_state.mask_overlay_kwargs is None:
|
||||
mask_overlay_kwargs = {}
|
||||
else:
|
||||
mask_overlay_kwargs = block_state.mask_overlay_kwargs
|
||||
|
||||
block_state.images = components.image_mask_processor.postprocess(
|
||||
image=block_state.images,
|
||||
**mask_overlay_kwargs,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
src/diffusers/modular_pipelines/qwenimage/denoise.py (new file, 668 lines)
@@ -0,0 +1,668 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
# one timestep
|
||||
block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
|
||||
block_state.latent_model_input = block_state.latents
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
# one timestep
|
||||
|
||||
block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1)
|
||||
block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("controlnet", QwenImageControlNetModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that runs the controlnet before the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"control_image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
|
||||
),
|
||||
InputParam(
|
||||
"controlnet_conditioning_scale",
|
||||
type_hint=float,
|
||||
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
|
||||
),
|
||||
InputParam(
|
||||
"controlnet_keep",
|
||||
required=True,
|
||||
type_hint=List[float],
|
||||
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description=(
|
||||
"All conditional model inputs for the denoiser. "
|
||||
"It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int):
|
||||
# cond_scale for the timestep (controlnet input)
|
||||
if isinstance(block_state.controlnet_keep[i], list):
|
||||
block_state.cond_scale = [
|
||||
c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])
|
||||
]
|
||||
else:
|
||||
controlnet_cond_scale = block_state.controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
|
||||
|
||||
# run controlnet for the guidance batch
|
||||
controlnet_block_samples = components.controlnet(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
controlnet_cond=block_state.control_image_latents,
|
||||
conditioning_scale=block_state.cond_scale,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
encoder_hidden_states=block_state.prompt_embeds,
|
||||
encoder_hidden_states_mask=block_state.prompt_embeds_mask,
|
||||
txt_seq_lens=block_state.txt_seq_lens,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that denoise the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("transformer", QwenImageTransformer2DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||
),
|
||||
InputParam(
|
||||
"img_shapes",
|
||||
required=True,
|
||||
type_hint=List[Tuple[int, int]],
|
||||
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
|
||||
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
|
||||
# YiYi TODO: add cache context
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
**block_state.additional_cond_kwargs,
|
||||
)[0]
|
||||
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
guider_output = components.guider(guider_state)
|
||||
|
||||
# apply guidance rescale
|
||||
pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True)
|
||||
pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True)
|
||||
block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm)
|
||||
|
||||
return components, block_state
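# Editorial note (illustrative, not part of this diff): the rescale above renormalizes the
# guided prediction to the norm of the conditional prediction along the last dimension, i.e.
#     noise_pred = pred * ||pred_cond|| / ||pred||
# a guidance-rescale style correction commonly used to reduce over-saturation at high
# guidance scales.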
|
||||
|
||||
|
||||
class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that denoise the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("transformer", QwenImageTransformer2DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||
),
|
||||
InputParam(
|
||||
"img_shapes",
|
||||
required=True,
|
||||
type_hint=List[Tuple[int, int]],
|
||||
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
|
||||
"txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
|
||||
}
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||
|
||||
# YiYi TODO: add cache context
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
**block_state.additional_cond_kwargs,
|
||||
)[0]
|
||||
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
guider_output = components.guider(guider_state)
|
||||
|
||||
pred = guider_output.pred[:, : block_state.latents.size(1)]
|
||||
pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)]
|
||||
|
||||
# apply guidance rescale
|
||||
pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True)
|
||||
pred_norm = torch.norm(pred, dim=-1, keepdim=True)
|
||||
block_state.noise_pred = pred * (pred_cond_norm / pred_norm)
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that updates the latents. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
latents_dtype = block_state.latents.dtype
|
||||
block_state.latents = components.scheduler.step(
|
||||
block_state.noise_pred,
|
||||
t,
|
||||
block_state.latents,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if block_state.latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
block_state.latents = block_state.latents.to(latents_dtype)
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that updates the latents using mask and image_latents for inpainting. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `QwenImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"mask",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"initial_noise",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.init_latents_proper = block_state.image_latents
|
||||
if i < len(block_state.timesteps) - 1:
|
||||
block_state.noise_timestep = block_state.timesteps[i + 1]
|
||||
block_state.init_latents_proper = components.scheduler.scale_noise(
|
||||
block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise
|
||||
)
|
||||
|
||||
block_state.latents = (
|
||||
1 - block_state.mask
|
||||
) * block_state.init_latents_proper + block_state.mask * block_state.latents
|
||||
|
||||
return components, block_state
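# Editorial note (illustrative, not part of this diff): the blend above keeps the original
# image content wherever mask == 0 (re-noised to the next timestep via `scale_noise`) and
# keeps the freshly denoised latents wherever mask == 1, so only the masked region is edited.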
|
||||
|
||||
|
||||
class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Pipeline block that iteratively denoise the latents over `timesteps`. "
|
||||
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
|
||||
)
|
||||
|
||||
@property
|
||||
def loop_expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.num_warmup_steps = max(
|
||||
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
|
||||
)
|
||||
|
||||
block_state.additional_cond_kwargs = {}
|
||||
|
||||
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(block_state.timesteps):
|
||||
components, block_state = self.loop_step(components, block_state, i=i, t=t)
|
||||
if i == len(block_state.timesteps) - 1 or (
|
||||
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
# composing the denoising loops
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
    block_classes = [
        QwenImageLoopBeforeDenoiser,
        QwenImageLoopDenoiser,
        QwenImageLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `QwenImageLoopBeforeDenoiser`\n"
            " - `QwenImageLoopDenoiser`\n"
            " - `QwenImageLoopAfterDenoiser`\n"
            "This block supports text2image and image2image tasks for QwenImage."
        )


# composing the inpainting denoising loops
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
    block_classes = [
        QwenImageLoopBeforeDenoiser,
        QwenImageLoopDenoiser,
        QwenImageLoopAfterDenoiser,
        QwenImageLoopAfterDenoiserInpaint,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `QwenImageLoopBeforeDenoiser`\n"
            " - `QwenImageLoopDenoiser`\n"
            " - `QwenImageLoopAfterDenoiser`\n"
            " - `QwenImageLoopAfterDenoiserInpaint`\n"
            "This block supports inpainting tasks for QwenImage."
        )


# composing the controlnet denoising loops
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
    block_classes = [
        QwenImageLoopBeforeDenoiser,
        QwenImageLoopBeforeDenoiserControlNet,
        QwenImageLoopDenoiser,
        QwenImageLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `QwenImageLoopBeforeDenoiser`\n"
            " - `QwenImageLoopBeforeDenoiserControlNet`\n"
            " - `QwenImageLoopDenoiser`\n"
            " - `QwenImageLoopAfterDenoiser`\n"
            "This block supports text2img/img2img tasks with controlnet for QwenImage."
        )


# composing the controlnet inpainting denoising loops
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
    block_classes = [
        QwenImageLoopBeforeDenoiser,
        QwenImageLoopBeforeDenoiserControlNet,
        QwenImageLoopDenoiser,
        QwenImageLoopAfterDenoiser,
        QwenImageLoopAfterDenoiserInpaint,
    ]
    block_names = [
        "before_denoiser",
        "before_denoiser_controlnet",
        "denoiser",
        "after_denoiser",
        "after_denoiser_inpaint",
    ]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `QwenImageLoopBeforeDenoiser`\n"
            " - `QwenImageLoopBeforeDenoiserControlNet`\n"
            " - `QwenImageLoopDenoiser`\n"
            " - `QwenImageLoopAfterDenoiser`\n"
            " - `QwenImageLoopAfterDenoiserInpaint`\n"
            "This block supports inpainting tasks with controlnet for QwenImage."
        )


# composing the edit denoising loops
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
    block_classes = [
        QwenImageEditLoopBeforeDenoiser,
        QwenImageEditLoopDenoiser,
        QwenImageLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `QwenImageEditLoopBeforeDenoiser`\n"
            " - `QwenImageEditLoopDenoiser`\n"
            " - `QwenImageLoopAfterDenoiser`\n"
            "This block supports QwenImage Edit."
        )


class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
    block_classes = [
        QwenImageEditLoopBeforeDenoiser,
        QwenImageEditLoopDenoiser,
        QwenImageLoopAfterDenoiser,
        QwenImageLoopAfterDenoiserInpaint,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `QwenImageEditLoopBeforeDenoiser`\n"
            " - `QwenImageEditLoopDenoiser`\n"
            " - `QwenImageLoopAfterDenoiser`\n"
            " - `QwenImageLoopAfterDenoiserInpaint`\n"
            "This block supports inpainting tasks for QwenImage Edit."
        )
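# Editorial sketch (not part of this diff): because each denoise step above is built on
# `QwenImageDenoiseLoopWrapper` (a `LoopSequentialPipelineBlocks`), the per-iteration
# behavior can be recombined from the loop blocks defined in this file. The composition
# below mirrors `QwenImageInpaintDenoiseStep` and is illustrative only.
#
# class MyCustomDenoiseStep(QwenImageDenoiseLoopWrapper):
#     block_classes = [
#         QwenImageLoopBeforeDenoiser,
#         QwenImageLoopDenoiser,
#         QwenImageLoopAfterDenoiser,
#         QwenImageLoopAfterDenoiserInpaint,
#     ]
#     block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]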
|
||||
src/diffusers/modular_pipelines/qwenimage/encoders.py (new file, 857 lines)
@@ -0,0 +1,857 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist
|
||||
from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel
|
||||
from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import unwrap_module
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
|
||||
def get_qwen_prompt_embeds(
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
prompt_template_encode_start_idx: int = 34,
|
||||
tokenizer_max_length: int = 1024,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
template = prompt_template_encode
|
||||
drop_idx = prompt_template_encode_start_idx
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = tokenizer(
|
||||
txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(device)
|
||||
encoder_hidden_states = text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = encoder_hidden_states.hidden_states[-1]
|
||||
|
||||
split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
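# Editorial sketch (illustrative, not part of this diff): how the helper above might be
# called. `text_encoder` and `tokenizer` are assumed to be already-loaded Qwen2.5-VL
# components matching the signature shown above.
#
# prompt_embeds, prompt_embeds_mask = get_qwen_prompt_embeds(
#     text_encoder,
#     tokenizer,
#     prompt="a cat sitting on a windowsill",
#     device=torch.device("cuda"),
# )
# # prompt_embeds: [batch, seq_len, hidden_dim]; prompt_embeds_mask: [batch, seq_len]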
|
||||
|
||||
|
||||
def get_qwen_prompt_embeds_edit(
|
||||
text_encoder,
|
||||
processor,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
prompt_template_encode_start_idx: int = 64,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
template = prompt_template_encode
|
||||
drop_idx = prompt_template_encode_start_idx
|
||||
txt = [template.format(e) for e in prompt]
|
||||
|
||||
model_inputs = processor(
|
||||
text=txt,
|
||||
images=image,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
outputs = text_encoder(
|
||||
input_ids=model_inputs.input_ids,
|
||||
attention_mask=model_inputs.attention_mask,
|
||||
pixel_values=model_inputs.pixel_values,
|
||||
image_grid_thw=model_inputs.image_grid_thw,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states[-1]
|
||||
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image
|
||||
def encode_vae_image(
|
||||
image: torch.Tensor,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
generator: torch.Generator,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
latent_channels: int = 16,
|
||||
sample_mode: str = "argmax",
|
||||
):
|
||||
if not isinstance(image, torch.Tensor):
|
||||
raise ValueError(f"Expected image to be a tensor, got {type(image)}.")
|
||||
|
||||
# preprocessed image should be a 4D tensor: batch_size, num_channels, height, width
|
||||
if image.dim() == 4:
|
||||
image = image.unsqueeze(2)
|
||||
elif image.dim() != 5:
|
||||
raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
|
||||
latents_mean = (
|
||||
torch.tensor(vae.config.latents_mean)
|
||||
.view(1, latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(vae.config.latents_std)
|
||||
.view(1, latent_channels, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
image_latents = (image_latents - latents_mean) / latents_std
|
||||
|
||||
return image_latents
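# Editorial sketch (illustrative, not part of this diff): encoding a preprocessed image.
# A 4D input is unsqueezed to 5D internally and the posterior is normalized with the VAE's
# configured latents_mean / latents_std. `vae` is assumed to be a loaded AutoencoderKLQwenImage.
#
# image = torch.randn(1, 3, 1024, 1024)  # preprocessed image in [-1, 1]
# image_latents = encode_vae_image(
#     image=image,
#     vae=vae,
#     generator=torch.Generator("cpu").manual_seed(0),
#     device=torch.device("cuda"),
#     dtype=torch.bfloat16,
# )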
|
||||
|
||||
|
||||
class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
|
||||
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
|
||||
|
||||
This block resizes an input image tensor and exposes the resized result under configurable input and output
|
||||
names. Use this when you need to wire the resize step to different image fields (e.g., "image",
|
||||
"control_image")
|
||||
|
||||
Args:
|
||||
input_name (str, optional): Name of the image field to read from the
|
||||
pipeline state. Defaults to "image".
|
||||
output_name (str, optional): Name of the resized image field to write
|
||||
back to the pipeline state. Defaults to "resized_image".
|
||||
"""
|
||||
if not isinstance(input_name, str) or not isinstance(output_name, str):
|
||||
raise ValueError(
|
||||
f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
|
||||
)
|
||||
self._image_input_name = input_name
|
||||
self._resized_image_output_name = output_name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_resize_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
images = getattr(block_state, self._image_input_name)
|
||||
|
||||
if not is_valid_image_imagelist(images):
|
||||
raise ValueError(f"Images must be image or list of images but are {type(images)}")
|
||||
|
||||
if is_valid_image(images):
|
||||
images = [images]
|
||||
|
||||
image_width, image_height = images[0].size
|
||||
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
|
||||
|
||||
resized_images = [
|
||||
components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
|
||||
for image in images
|
||||
]
|
||||
|
||||
setattr(block_state, self._resized_image_output_name, resized_images)
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
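# Editorial sketch (mirrors the class docstring above; names are illustrative): the same
# block can be wired to a different image field, e.g. a control image.
#
# resize_control_step = QwenImageEditResizeDynamicStep(
#     input_name="control_image", output_name="resized_control_image"
# )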
|
||||
|
||||
|
||||
class QwenImageTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generate text_embeddings to guide the image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"),
|
||||
ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(
|
||||
name="prompt_template_encode",
|
||||
default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
),
|
||||
ConfigSpec(name="prompt_template_encode_start_idx", default=34),
|
||||
ConfigSpec(name="tokenizer_max_length", default=1024),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
|
||||
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
|
||||
InputParam(
|
||||
name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The encoder attention mask",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings mask",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(prompt, negative_prompt, max_sequence_length):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if (
|
||||
negative_prompt is not None
|
||||
and not isinstance(negative_prompt, str)
|
||||
and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
|
||||
|
||||
block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
|
||||
components.text_encoder,
|
||||
components.tokenizer,
|
||||
prompt=block_state.prompt,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
tokenizer_max_length=components.config.tokenizer_max_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
|
||||
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]
|
||||
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = block_state.negative_prompt or ""
|
||||
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
|
||||
components.text_encoder,
|
||||
components.tokenizer,
|
||||
prompt=negative_prompt,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
tokenizer_max_length=components.config.tokenizer_max_length,
|
||||
device=device,
|
||||
)
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[
|
||||
:, : block_state.max_sequence_length
|
||||
]
|
||||
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[
|
||||
:, : block_state.max_sequence_length
|
||||
]
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
|
||||
ComponentSpec("processor", Qwen2VLProcessor),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_configs(self) -> List[ConfigSpec]:
|
||||
return [
|
||||
ConfigSpec(
|
||||
name="prompt_template_encode",
|
||||
default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
),
|
||||
ConfigSpec(name="prompt_template_encode_start_idx", default=64),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
|
||||
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
|
||||
InputParam(
|
||||
name="resized_image",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The image prompt to encode, should be resized using resize step",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The encoder attention mask",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_prompt_embeds_mask",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="The negative prompt embeddings mask",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(prompt, negative_prompt):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if (
|
||||
negative_prompt is not None
|
||||
and not isinstance(negative_prompt, str)
|
||||
and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.prompt, block_state.negative_prompt)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit(
|
||||
components.text_encoder,
|
||||
components.processor,
|
||||
prompt=block_state.prompt,
|
||||
image=block_state.resized_image,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = block_state.negative_prompt or ""
|
||||
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
|
||||
components.text_encoder,
|
||||
components.processor,
|
||||
prompt=negative_prompt,
|
||||
image=block_state.resized_image,
|
||||
prompt_template_encode=components.config.prompt_template_encode,
|
||||
prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_mask_processor",
|
||||
InpaintProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("resized_image"),
|
||||
InputParam("image"),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="processed_image"),
|
||||
OutputParam(name="processed_mask_image"),
|
||||
OutputParam(
|
||||
name="mask_overlay_kwargs",
|
||||
type_hint=Dict,
|
||||
description="The kwargs for the postprocess step to apply the mask overlay",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if block_state.resized_image is None and block_state.image is None:
|
||||
raise ValueError("resized_image and image cannot be None at the same time")
|
||||
|
||||
if block_state.resized_image is None:
|
||||
image = block_state.image
|
||||
self.check_inputs(
|
||||
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
|
||||
)
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
else:
|
||||
width, height = block_state.resized_image[0].size
|
||||
image = block_state.resized_image
|
||||
|
||||
block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
|
||||
components.image_mask_processor.preprocess(
|
||||
image=image,
|
||||
mask=block_state.mask_image,
|
||||
height=height,
|
||||
width=width,
|
||||
padding_mask_crop=block_state.padding_mask_crop,
|
||||
)
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("resized_image"),
|
||||
InputParam("image"),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="processed_image"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if block_state.resized_image is None and block_state.image is None:
|
||||
raise ValueError("resized_image and image cannot be None at the same time")
|
||||
|
||||
if block_state.resized_image is None:
|
||||
image = block_state.image
|
||||
self.check_inputs(
|
||||
height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
|
||||
)
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
else:
|
||||
width, height = block_state.resized_image[0].size
|
||||
image = block_state.resized_image
|
||||
|
||||
block_state.processed_image = components.image_processor.preprocess(
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_name: str = "processed_image",
|
||||
output_name: str = "image_latents",
|
||||
):
|
||||
"""Initialize a VAE encoder step for converting images to latent representations.
|
||||
|
||||
Both the input and output names are configurable so this block can be configured to process different image
|
||||
inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
|
||||
|
||||
Args:
|
||||
input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
|
||||
Examples: "processed_image" or "processed_control_image"
|
||||
output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
|
||||
Examples: "image_latents" or "control_image_latents"
|
||||
|
||||
Examples:
|
||||
# Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep()
|
||||
|
||||
# Custom input/output names for control image QwenImageVaeEncoderDynamicStep(
|
||||
input_name="processed_control_image", output_name="control_image_latents"
|
||||
)
|
||||
"""
|
||||
self._image_input_name = input_name
|
||||
self._image_latents_output_name = output_name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
]
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(self._image_input_name, required=True),
|
||||
InputParam("generator"),
|
||||
]
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
self._image_latents_output_name,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image",
|
||||
)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = components.vae.dtype
|
||||
|
||||
image = getattr(block_state, self._image_input_name)
|
||||
|
||||
# Encode image into latents
|
||||
image_latents = encode_vae_image(
|
||||
image=image,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
)
|
||||
|
||||
setattr(block_state, self._image_latents_output_name, image_latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
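# Editorial sketch (mirrors the __init__ docstring above; names are illustrative): a
# control-image variant of the dynamic VAE encoder block.
#
# control_vae_encoder_step = QwenImageVaeEncoderDynamicStep(
#     input_name="processed_control_image", output_name="control_image_latents"
# )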
|
||||
|
||||
|
||||
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
components = [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
ComponentSpec("controlnet", QwenImageControlNetModel),
|
||||
ComponentSpec(
|
||||
"control_image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
return components
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam("control_image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"control_image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the control image",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = components.vae.dtype
|
||||
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
|
||||
controlnet = unwrap_module(components.controlnet)
|
||||
if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list):
|
||||
block_state.control_image = [block_state.control_image]
|
||||
|
||||
if isinstance(controlnet, QwenImageMultiControlNetModel):
|
||||
block_state.control_image_latents = []
|
||||
for control_image_ in block_state.control_image:
|
||||
control_image_ = components.control_image_processor.preprocess(
|
||||
image=control_image_,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
control_image_latents_ = encode_vae_image(
|
||||
image=control_image_,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
sample_mode="sample",
|
||||
)
|
||||
block_state.control_image_latents.append(control_image_latents_)
|
||||
|
||||
elif isinstance(controlnet, QwenImageControlNetModel):
|
||||
control_image = components.control_image_processor.preprocess(
|
||||
image=block_state.control_image,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
block_state.control_image_latents = encode_vae_image(
|
||||
image=control_image,
|
||||
vae=components.vae,
|
||||
generator=block_state.generator,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
latent_channels=components.num_channels_latents,
|
||||
sample_mode="sample",
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}"
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
src/diffusers/modular_pipelines/qwenimage/inputs.py (new file, 431 lines)
@@ -0,0 +1,431 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import QwenImageMultiControlNetModel
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
def repeat_tensor_to_batch_size(
|
||||
input_name: str,
|
||||
input_tensor: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_images_per_prompt: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Repeat tensor elements to match the final batch size.
|
||||
|
||||
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
|
||||
by repeating each element along dimension 0.
|
||||
|
||||
The input tensor must have batch size 1 or batch_size. The function will:
|
||||
- If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
|
||||
- If batch size equals batch_size: repeat each element num_images_per_prompt times
|
||||
|
||||
Args:
|
||||
input_name (str): Name of the input tensor (used for error messages)
|
||||
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
|
||||
batch_size (int): The base batch size (number of prompts)
|
||||
num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
|
||||
|
||||
Raises:
|
||||
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
|
||||
|
||||
Examples:
|
||||
tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
repeated  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: [4, 3]

tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
repeated  # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) - shape: [4, 3]
|
||||
"""
|
||||
# make sure input is a tensor
|
||||
if not isinstance(input_tensor, torch.Tensor):
|
||||
raise ValueError(f"`{input_name}` must be a tensor")
|
||||
|
||||
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
|
||||
if input_tensor.shape[0] == 1:
|
||||
repeat_by = batch_size * num_images_per_prompt
|
||||
elif input_tensor.shape[0] == batch_size:
|
||||
repeat_by = num_images_per_prompt
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
|
||||
)
|
||||
|
||||
# expand the tensor to match the batch_size * num_images_per_prompt
|
||||
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]:
|
||||
"""Calculate image dimensions from latent tensor dimensions.
|
||||
|
||||
This function converts latent space dimensions to image space dimensions by multiplying the latent height and width
|
||||
by the VAE scale factor.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
|
||||
Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
|
||||
vae_scale_factor (int): The scale factor used by the VAE to compress images.
|
||||
Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The calculated image dimensions as (height, width)
|
||||
|
||||
Raises:
|
||||
ValueError: If latents tensor doesn't have 4 or 5 dimensions
|
||||
|
||||
"""
|
||||
# make sure the latents are not packed
|
||||
if latents.ndim != 4 and latents.ndim != 5:
|
||||
raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")
|
||||
|
||||
latent_height, latent_width = latents.shape[-2:]
|
||||
|
||||
height = latent_height * vae_scale_factor
|
||||
width = latent_width * vae_scale_factor
|
||||
|
||||
return height, width
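# Illustrative sketch of how the two helpers above behave (comments only, not part of the original
# pipeline code; shapes assume the default vae_scale_factor of 8):
#
#     >>> latents = torch.randn(1, 16, 64, 64)  # unpacked latents
#     >>> calculate_dimension_from_latents(latents, vae_scale_factor=8)
#     (512, 512)
#     >>> repeat_tensor_to_batch_size("latents", latents, batch_size=2, num_images_per_prompt=2).shape
#     torch.Size([4, 16, 64, 64])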
|
||||
|
||||
|
||||
class QwenImageTextInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
summary_section = (
|
||||
"Text input processing step that standardizes text embeddings for the pipeline.\n"
|
||||
"This step:\n"
|
||||
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
|
||||
)
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps."
|
||||
|
||||
return summary_section + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
|
||||
InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
|
||||
InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
|
||||
InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"batch_size",
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
|
||||
),
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(
|
||||
prompt_embeds,
|
||||
prompt_embeds_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask,
|
||||
):
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")
|
||||
|
||||
if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None:
|
||||
raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`")
|
||||
|
||||
if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]:
|
||||
raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
|
||||
|
||||
elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]:
|
||||
raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`")
|
||||
|
||||
elif (
|
||||
negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]
|
||||
):
|
||||
raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
prompt_embeds=block_state.prompt_embeds,
|
||||
prompt_embeds_mask=block_state.prompt_embeds_mask,
|
||||
negative_prompt_embeds=block_state.negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
|
||||
)
|
||||
|
||||
block_state.batch_size = block_state.prompt_embeds.shape[0]
|
||||
block_state.dtype = block_state.prompt_embeds.dtype
|
||||
|
||||
_, seq_len, _ = block_state.prompt_embeds.shape
|
||||
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1)
|
||||
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len
|
||||
)
|
||||
|
||||
if block_state.negative_prompt_embeds is not None:
|
||||
_, seq_len, _ = block_state.negative_prompt_embeds.shape
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
|
||||
1, block_state.num_images_per_prompt, 1
|
||||
)
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat(
|
||||
1, block_state.num_images_per_prompt, 1
|
||||
)
|
||||
block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
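# Shape walk-through for the expansion above (illustrative comment, assuming num_images_per_prompt=2):
#   prompt_embeds:      [batch_size, seq_len, dim] -> [batch_size * 2, seq_len, dim]
#   prompt_embeds_mask: [batch_size, seq_len]      -> [batch_size * 2, seq_len]
# Each prompt's embeddings are duplicated so that every generated image gets its own copy, and the
# negative embeddings (when present) are expanded the same way.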
|
||||
|
||||
|
||||
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
|
||||
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||
|
||||
This is a dynamic block that allows you to configure which inputs to process.
|
||||
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
|
||||
list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to []. Examples: ["processed_mask_image"]
|
||||
|
||||
Examples:
|
||||
# Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
)
|
||||
"""
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# Functionality section
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
# Inputs info
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
# Add image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
# Add additional batch inputs
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents
|
||||
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
image_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=image_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, image_latent_input_name, image_latent_tensor)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="control_image_latents", required=True),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
if isinstance(components.controlnet, QwenImageMultiControlNetModel):
|
||||
control_image_latents = []
|
||||
# loop through each control_image_latents
|
||||
for i, control_image_latents_ in enumerate(block_state.control_image_latents):
|
||||
# 1. update height/width if not provided
|
||||
height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# 2. pack
|
||||
control_image_latents_ = components.pachifier.pack_latents(control_image_latents_)
|
||||
|
||||
# 3. repeat to match the batch size
|
||||
control_image_latents_ = repeat_tensor_to_batch_size(
|
||||
input_name=f"control_image_latents[{i}]",
|
||||
input_tensor=control_image_latents_,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
control_image_latents.append(control_image_latents_)
|
||||
|
||||
block_state.control_image_latents = control_image_latents
|
||||
|
||||
else:
|
||||
# 1. update height/width if not provided
|
||||
height, width = calculate_dimension_from_latents(
|
||||
block_state.control_image_latents, components.vae_scale_factor
|
||||
)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# 2. pack
|
||||
block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents)
|
||||
|
||||
# 3. repeat to match the batch size
|
||||
block_state.control_image_latents = repeat_tensor_to_batch_size(
|
||||
input_name="control_image_latents",
|
||||
input_tensor=block_state.control_image_latents,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
block_state.control_image_latents = block_state.control_image_latents
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
841
src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
Normal file
@@ -0,0 +1,841 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageControlNetBeforeDenoiserStep,
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImageEditRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
QwenImageRoPEInputsStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
|
||||
from .denoise import (
|
||||
QwenImageControlNetDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
QwenImageEditDenoiseStep,
|
||||
QwenImageEditInpaintDenoiseStep,
|
||||
QwenImageInpaintControlNetDenoiseStep,
|
||||
QwenImageInpaintDenoiseStep,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageControlNetVaeEncoderStep,
|
||||
QwenImageEditResizeDynamicStep,
|
||||
QwenImageEditTextEncoderStep,
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
QwenImageProcessImagesInputStep,
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderDynamicStep,
|
||||
)
|
||||
from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# 1. QwenImage
|
||||
|
||||
## 1.1 QwenImage/text2image
|
||||
|
||||
#### QwenImage/decode
|
||||
#### (standard decode step works for most tasks except for inpaint)
|
||||
QwenImageDecodeBlocks = InsertableDict(
|
||||
[
|
||||
("decode", QwenImageDecoderStep()),
|
||||
("postprocess", QwenImageProcessImagesOutputStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageDecodeBlocks.values()
|
||||
block_names = QwenImageDecodeBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image."
|
||||
|
||||
|
||||
#### QwenImage/text2image presets
|
||||
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("input", QwenImageTextInputsStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
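# Minimal usage sketch for the text2image preset (illustrative only; the repo id and the loading
# helpers shown here are assumptions, not something this file defines):
#
#     blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
#     pipe = blocks.init_pipeline("Qwen/Qwen-Image")  # hypothetical modular checkpoint id
#     pipe.load_components(torch_dtype=torch.bfloat16)
#     image = pipe(prompt="a cat wearing sunglasses", output="images")[0]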
|
||||
|
||||
|
||||
## 1.2 QwenImage/inpaint
|
||||
|
||||
#### QwenImage/inpaint vae encoder
|
||||
QwenImageInpaintVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
(
|
||||
"preprocess",
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
|
||||
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintVaeEncoderBlocks.values()
|
||||
block_names = QwenImageInpaintVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for inpainting tasks. It:\n"
|
||||
" - Resizes the image to the target size, based on `height` and `width`.\n"
|
||||
" - Processes and updates `image` and `mask_image`.\n"
|
||||
" - Creates `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
#### QwenImage/inpaint inputs
|
||||
QwenImageInpaintInputBlocks = InsertableDict(
|
||||
[
|
||||
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
|
||||
(
|
||||
"additional_inputs",
|
||||
QwenImageInputsDynamicStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintInputBlocks.values()
|
||||
block_names = QwenImageInpaintInputBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
# QwenImage/inpaint prepare latents
|
||||
QwenImageInpaintPrepareLatentsBlocks = InsertableDict(
|
||||
[
|
||||
("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||
("create_mask_latents", QwenImageCreateMaskLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintPrepareLatentsBlocks.values()
|
||||
block_names = QwenImageInpaintPrepareLatentsBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
|
||||
" - Add noise to the image latents to create the latents input for the denoiser.\n"
|
||||
" - Create the pachified latents `mask` based on the processedmask image.\n"
|
||||
)
|
||||
|
||||
|
||||
#### QwenImage/inpaint decode
|
||||
QwenImageInpaintDecodeBlocks = InsertableDict(
|
||||
[
|
||||
("decode", QwenImageDecoderStep()),
|
||||
("postprocess", QwenImageInpaintProcessImagesOutputStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintDecodeBlocks.values()
|
||||
block_names = QwenImageInpaintDecodeBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
|
||||
|
||||
|
||||
#### QwenImage/inpaint presets
|
||||
INPAINT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageInpaintVaeEncoderStep()),
|
||||
("input", QwenImageInpaintInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageInpaintDenoiseStep()),
|
||||
("decode", QwenImageInpaintDecodeStep()),
|
||||
]
|
||||
)
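# Compared to TEXT2IMAGE_BLOCKS, this preset swaps in the strength-aware timestep step, adds the
# inpaint latents/mask preparation step, and decodes with the inpaint-aware decode step so the
# optional mask overlay can be applied onto the original image.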
|
||||
|
||||
|
||||
## 1.3 QwenImage/img2img
|
||||
|
||||
#### QwenImage/img2img vae encoder
|
||||
QwenImageImg2ImgVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("preprocess", QwenImageProcessImagesInputStep()),
|
||||
("encode", QwenImageVaeEncoderDynamicStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = QwenImageImg2ImgVaeEncoderBlocks.values()
|
||||
block_names = QwenImageImg2ImgVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
#### QwenImage/img2img inputs
|
||||
QwenImageImg2ImgInputBlocks = InsertableDict(
|
||||
[
|
||||
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
|
||||
("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageImg2ImgInputBlocks.values()
|
||||
block_names = QwenImageImg2ImgInputBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
#### QwenImage/img2img presets
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageImg2ImgVaeEncoderStep()),
|
||||
("input", QwenImageImg2ImgInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
("denoise", QwenImageDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
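# The img2img preset mirrors the inpaint preset without the mask handling: latents are derived from
# `image_latents` and partially noised according to `strength` before denoising.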
|
||||
|
||||
|
||||
## 1.4 QwenImage/controlnet
|
||||
|
||||
#### QwenImage/controlnet presets
|
||||
CONTROLNET_BLOCKS = InsertableDict(
|
||||
[
|
||||
("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image
|
||||
("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet
|
||||
(
|
||||
"controlnet_before_denoise",
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
), # before denoise step (after set_timesteps step)
|
||||
(
|
||||
"controlnet_denoise_loop_before",
|
||||
QwenImageLoopBeforeDenoiserControlNet(),
|
||||
), # controlnet loop step (insert before the denoiseloop_denoiser)
|
||||
]
|
||||
)
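# These controlnet blocks are not a standalone workflow; they are meant to be interleaved with one of
# the task presets (see AUTO_BLOCKS below): the vae encoder/input/before-denoise steps slot in after
# their base counterparts, and the loop block is folded into the denoise loop before the denoiser.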
|
||||
|
||||
|
||||
## 1.5 QwenImage/auto encoders
|
||||
|
||||
|
||||
#### for inpaint and img2img tasks
|
||||
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
|
||||
+ " - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# for controlnet tasks
|
||||
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetVaeEncoderStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
|
||||
+ " - if `control_image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 1.6 QwenImage/auto inputs
|
||||
|
||||
|
||||
# text2image/inpaint/img2img
|
||||
class QwenImageAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep]
|
||||
block_names = ["inpaint", "img2img", "text2image"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
|
||||
" This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n"
|
||||
+ " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# controlnet
|
||||
class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetInputsStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Controlnet input step that prepare the control_image_latents input.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n"
|
||||
+ " - if `control_image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 1.7 QwenImage/auto before denoise step
|
||||
# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step
|
||||
|
||||
# QwenImage/text2image before denoise
|
||||
QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for text2image task."
|
||||
|
||||
|
||||
# QwenImage/inpaint before denoise
|
||||
QwenImageInpaintBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageInpaintBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageInpaintBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
|
||||
# QwenImage/img2img before denoise
|
||||
QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
|
||||
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
|
||||
# auto before_denoise step for text2image, inpaint, img2img tasks
|
||||
class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageInpaintBeforeDenoiseStep,
|
||||
QwenImageImg2ImgBeforeDenoiseStep,
|
||||
QwenImageText2ImageBeforeDenoiseStep,
|
||||
]
|
||||
block_names = ["inpaint", "img2img", "text2image"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n"
|
||||
+ " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# auto before_denoise step for controlnet tasks
|
||||
class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetBeforeDenoiserStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Controlnet before denoise step that prepare the controlnet input.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n"
|
||||
+ " - if `control_image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 1.8 QwenImage/auto denoise
|
||||
|
||||
|
||||
# auto denoise step for controlnet tasks: works for all tasks with controlnet
|
||||
class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep]
|
||||
block_names = ["inpaint_denoise", "denoise"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Controlnet step during the denoising process. \n"
|
||||
" This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n"
|
||||
+ " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n"
|
||||
+ " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# auto denoise step for everything: works for all tasks with or without controlnet
|
||||
class QwenImageAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageControlNetAutoDenoiseStep,
|
||||
QwenImageInpaintDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
]
|
||||
block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
|
||||
block_trigger_inputs = ["control_image_latents", "mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks. It also works with controlnet\n"
|
||||
+ " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n"
|
||||
+ " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n"
|
||||
+ " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
## 1.9 QwenImage/auto decode
|
||||
# auto decode step for inpaint and text2image tasks
|
||||
|
||||
|
||||
class QwenImageAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
|
||||
block_names = ["inpaint_decode", "decode"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Decode step that decode the latents into images. \n"
|
||||
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
|
||||
+ " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
|
||||
+ " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
## 1.10 QwenImage/auto block & presets
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageAutoVaeEncoderStep()),
|
||||
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
|
||||
("input", QwenImageAutoInputStep()),
|
||||
("controlnet_input", QwenImageOptionalControlNetInputStep()),
|
||||
("before_denoise", QwenImageAutoBeforeDenoiseStep()),
|
||||
("controlnet_before_denoise", QwenImageOptionalControlNetBeforeDenoiseStep()),
|
||||
("denoise", QwenImageAutoDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
|
||||
+ "- for image-to-image generation, you need to provide `image`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
)
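# Illustrative usage sketch for the auto blocks (the repo id and the loading helpers below are
# assumptions, not defined in this file):
#
#     blocks = QwenImageAutoBlocks()
#     pipe = blocks.init_pipeline("Qwen/Qwen-Image")  # hypothetical modular checkpoint id
#     pipe.load_components(torch_dtype=torch.bfloat16)
#     # text2image: only `prompt` is needed
#     image = pipe(prompt="a cat", output="images")[0]
#     # inpaint: providing `image` + `mask_image` switches the auto blocks to the inpaint sub-blocks
#     image = pipe(prompt="a dog", image=init_image, mask_image=mask, strength=0.9, output="images")[0]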
|
||||
|
||||
|
||||
# 2. QwenImage-Edit
|
||||
|
||||
## 2.1 QwenImage-Edit/edit
|
||||
|
||||
#### QwenImage-Edit/edit vl encoder: take both image and text prompts
|
||||
QwenImageEditVLEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("resize", QwenImageEditResizeDynamicStep()),
|
||||
("encode", QwenImageEditTextEncoderStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditVLEncoderBlocks.values()
|
||||
block_names = QwenImageEditVLEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Edit VL encoder step that encode the image an text prompts together."
|
||||
|
||||
|
||||
#### QwenImage-Edit/edit vae encoder
|
||||
QwenImageEditVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step
|
||||
("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image
|
||||
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditVaeEncoderBlocks.values()
|
||||
block_names = QwenImageEditVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
#### QwenImage-Edit/edit input
|
||||
QwenImageEditInputBlocks = InsertableDict(
|
||||
[
|
||||
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
|
||||
("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditInputBlocks.values()
|
||||
block_names = QwenImageEditInputBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the edit denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs: \n"
|
||||
" - `image_latents`.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
#### QwenImage/edit presets
|
||||
EDIT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditVaeEncoderStep()),
|
||||
("input", QwenImageEditInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditDenoiseStep()),
|
||||
("decode", QwenImageDecodeStep()),
|
||||
]
|
||||
)
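# Relative to the QwenImage text2image preset, the edit preset encodes the image together with the
# text prompt (VL encoder), also encodes the image with the VAE to obtain `image_latents`, and uses
# the edit-specific RoPE inputs and denoise steps.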
|
||||
|
||||
|
||||
## 2.2 QwenImage-Edit/edit inpaint
|
||||
|
||||
#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step
|
||||
QwenImageEditInpaintVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image
|
||||
(
|
||||
"preprocess",
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
|
||||
(
|
||||
"encode",
|
||||
QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
|
||||
), # processed_image -> image_latents
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditInpaintVaeEncoderBlocks.values()
|
||||
block_names = QwenImageEditInpaintVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
|
||||
" - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
|
||||
" - process the resized image and mask image.\n"
|
||||
" - create image latents."
|
||||
)
|
||||
|
||||
|
||||
#### QwenImage-Edit/edit inpaint presets
|
||||
EDIT_INPAINT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditInpaintVaeEncoderStep()),
|
||||
("input", QwenImageInpaintInputStep()),
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
("denoise", QwenImageEditInpaintDenoiseStep()),
|
||||
("decode", QwenImageInpaintDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
## 2.3 QwenImage-Edit/auto encoders
|
||||
|
||||
|
||||
class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageEditInpaintVaeEncoderStep,
|
||||
QwenImageEditVaeEncoderStep,
|
||||
]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations. \n"
|
||||
" This is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
|
||||
+ " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
|
||||
+ " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
|
||||
+ " - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 2.4 QwenImage-Edit/auto inputs
|
||||
class QwenImageEditAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit denoising step.\n"
|
||||
+ " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
|
||||
+ " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n"
|
||||
+ " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 2.5 QwenImage-Edit/auto before denoise
|
||||
# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step
|
||||
|
||||
#### QwenImage-Edit/edit before denoise
|
||||
QwenImageEditBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageEditBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
|
||||
|
||||
|
||||
#### QwenImage-Edit/edit inpaint before denoise
|
||||
QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||
("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
|
||||
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values()
|
||||
block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit inpaint task."
|
||||
|
||||
|
||||
# auto before_denoise step for edit and edit_inpaint tasks
|
||||
class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInpaintBeforeDenoiseStep,
|
||||
QwenImageEditBeforeDenoiseStep,
|
||||
]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n"
|
||||
+ " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
|
||||
+ " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 2.6 QwenImage-Edit/auto denoise
|
||||
|
||||
|
||||
class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
|
||||
block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep]
|
||||
block_names = ["inpaint_denoise", "denoise"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
+ "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n"
|
||||
+ " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
|
||||
+ " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## 2.7 QwenImage-Edit/auto blocks & presets
|
||||
|
||||
EDIT_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
|
||||
("input", QwenImageEditAutoInputStep()),
|
||||
("before_denoise", QwenImageEditAutoBeforeDenoiseStep()),
|
||||
("denoise", QwenImageEditAutoDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = EDIT_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
|
||||
+ "- for edit (img2img) generation, you need to provide `image`\n"
|
||||
+ "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
|
||||
)
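# Illustrative usage sketch for the edit auto blocks (repo id and loading helpers are assumptions):
#
#     blocks = QwenImageEditAutoBlocks()
#     pipe = blocks.init_pipeline("Qwen/Qwen-Image-Edit")  # hypothetical modular checkpoint id
#     pipe.load_components(torch_dtype=torch.bfloat16)
#     image = pipe(prompt="turn the cat into a tiger", image=source_image, output="images")[0]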
|
||||
|
||||
|
||||
# 3. all block presets supported in QwenImage & QwenImage-Edit
|
||||
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"text2image": TEXT2IMAGE_BLOCKS,
|
||||
"img2img": IMAGE2IMAGE_BLOCKS,
|
||||
"edit": EDIT_BLOCKS,
|
||||
"edit_inpaint": EDIT_INPAINT_BLOCKS,
|
||||
"inpaint": INPAINT_BLOCKS,
|
||||
"controlnet": CONTROLNET_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
"edit_auto": EDIT_AUTO_BLOCKS,
|
||||
}
|
||||
202
src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import QwenImageLoraLoaderMixin
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
|
||||
class QwenImagePachifier(ConfigMixin):
|
||||
"""
|
||||
A class to pack and unpack latents for QwenImage.
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
def pack_latents(self, latents):
|
||||
if latents.ndim != 4 and latents.ndim != 5:
|
||||
raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}")
|
||||
|
||||
if latents.ndim == 4:
|
||||
latents = latents.unsqueeze(2)
|
||||
|
||||
batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape
|
||||
patch_size = self.config.patch_size
|
||||
|
||||
if latent_height % patch_size != 0 or latent_width % patch_size != 0:
|
||||
raise ValueError(
|
||||
f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
|
||||
)
|
||||
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
latent_height // patch_size,
|
||||
patch_size,
|
||||
latent_width // patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(
|
||||
0, 2, 4, 1, 3, 5
|
||||
) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
(latent_height // patch_size) * (latent_width // patch_size),
|
||||
num_channels_latents * patch_size * patch_size,
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def unpack_latents(self, latents, height, width, vae_scale_factor=8):
|
||||
if latents.ndim != 3:
|
||||
raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
|
||||
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
patch_size = self.config.patch_size
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
|
||||
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
|
||||
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
height // patch_size,
|
||||
width // patch_size,
|
||||
channels // (patch_size * patch_size),
|
||||
patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
|
||||
|
||||
return latents
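# Illustrative shape round trip for the pachifier above (assumes patch_size=2, a 512x512 image, and
# vae_scale_factor=8; comments only, not executed):
#   x = torch.randn(1, 16, 64, 64)                 # unpacked latents
#   packed = QwenImagePachifier().pack_latents(x)  # -> [1, 1024, 64]  (32*32 patches, 16*2*2 channels)
#   QwenImagePachifier().unpack_latents(packed, 512, 512).shape
#   # -> torch.Size([1, 16, 1, 64, 64])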
|
||||
|
||||
|
||||
class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for QwenImage.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental feature and is likely to change in the future.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_sample_size(self):
|
||||
return 128
|
||||
|
||||
@property
|
||||
def vae_scale_factor(self):
|
||||
vae_scale_factor = 8
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def num_channels_latents(self):
|
||||
num_channels_latents = 16
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
return num_channels_latents
|
||||
|
||||
@property
|
||||
def is_guidance_distilled(self):
|
||||
is_guidance_distilled = False
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
is_guidance_distilled = self.transformer.config.guidance_embeds
|
||||
return is_guidance_distilled
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
requires_unconditional_embeds = False
|
||||
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
|
||||
|
||||
class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
    """
    A ModularPipeline for QwenImage-Edit.

    <Tip warning={true}>

    This is an experimental feature and is likely to change in the future.

    </Tip>
    """

    # YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step.
    @property
    def default_height(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_width(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_sample_size(self):
        return 128

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if hasattr(self, "vae") and self.vae is not None:
            vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 16
        if hasattr(self, "transformer") and self.transformer is not None:
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents

    @property
    def is_guidance_distilled(self):
        is_guidance_distilled = False
        if hasattr(self, "transformer") and self.transformer is not None:
            is_guidance_distilled = self.transformer.config.guidance_embeds
        return is_guidance_distilled

    @property
    def requires_unconditional_embeds(self):
        requires_unconditional_embeds = False

        if hasattr(self, "guider") and self.guider is not None:
            requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1

        return requires_unconditional_embeds

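# Editor's sketch (hypothetical toy guider, not from the source): the
# requires_unconditional_embeds property above is True exactly when an enabled
# guider compares more than one condition, e.g. classifier-free guidance, which
# needs a second, unconditional (negative-prompt) embedding branch.
class _ToyGuider:
    _enabled = True
    num_conditions = 2  # conditional + unconditional branch


_guider = _ToyGuider()
print(_guider._enabled and _guider.num_conditions > 1)  # True -> prepare negative prompt embeds
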
@@ -76,6 +76,7 @@ class StableDiffusionXLModularPipeline(
        vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    # YiYi TODO: change to num_channels_latents
    @property
    def num_channels_unet(self):
        num_channels_unet = 4

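# Editor's note (illustrative arithmetic, not part of the source): assuming the
# usual SDXL VAE configuration, block_out_channels = (128, 256, 512, 512), the
# property above evaluates to 2 ** (4 - 1) = 8, i.e. an 8x spatial compression
# between pixel space and latent space.
block_out_channels = (128, 256, 512, 512)  # assumed typical SDXL VAE configuration
vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # -> 8
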
@@ -91,6 +91,14 @@ from .pag import (
        StableDiffusionXLPAGPipeline,
    )
    from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
    from .qwenimage import (
        QwenImageControlNetPipeline,
        QwenImageEditInpaintPipeline,
        QwenImageEditPipeline,
        QwenImageImg2ImgPipeline,
        QwenImageInpaintPipeline,
        QwenImagePipeline,
    )
    from .sana import SanaPipeline
    from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
    from .stable_diffusion import (

@@ -150,6 +158,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
        ("cogview3", CogView3PlusPipeline),
        ("cogview4", CogView4Pipeline),
        ("cogview4-control", CogView4ControlPipeline),
        ("qwenimage", QwenImagePipeline),
        ("qwenimage-controlnet", QwenImageControlNetPipeline),
    ]
)

@@ -174,6 +184,8 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
        ("flux-controlnet", FluxControlNetImg2ImgPipeline),
        ("flux-control", FluxControlImg2ImgPipeline),
        ("flux-kontext", FluxKontextPipeline),
        ("qwenimage", QwenImageImg2ImgPipeline),
        ("qwenimage-edit", QwenImageEditPipeline),
    ]
)

@@ -195,6 +207,8 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
        ("flux-controlnet", FluxControlNetInpaintPipeline),
        ("flux-control", FluxControlInpaintPipeline),
        ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
        ("qwenimage", QwenImageInpaintPipeline),
        ("qwenimage-edit", QwenImageEditInpaintPipeline),
    ]
)

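# Editor's usage sketch (the checkpoint id "Qwen/Qwen-Image" and the prompt are
# illustrative assumptions): with the "qwenimage" entries registered in the
# mappings above, the auto pipelines can resolve a QwenImage checkpoint to the
# matching task pipeline.
import torch

from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
image = pipe("a capybara reading a book under a tree").images[0]
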
@@ -32,6 +32,66 @@ class FluxModularPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class QwenImageAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class QwenImageEditAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class QwenImageEditModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class QwenImageModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

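# Editor's sketch (illustrative, not from the source): the dummy classes above let
# diffusers expose these names even when torch / transformers are not installed;
# any attempt to actually use them raises a clear install hint via requires_backends.
from diffusers.utils import is_torch_available, is_transformers_available

if not (is_torch_available() and is_transformers_available()):
    from diffusers import QwenImageModularPipeline  # resolves to the dummy class above

    try:
        QwenImageModularPipeline()
    except ImportError as err:
        print(err)  # message points to the missing torch / transformers installation
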
class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]