yiyixuxu
2026-01-19 08:10:31 +01:00
parent 1f2dbc9dd2
commit fb15752d55
10 changed files with 216 additions and 108 deletions

View File

@@ -397,6 +397,7 @@ INPUT_PARAM_TEMPLATES = {
"description": "Additional kwargs for attention processors.",
},
"denoiser_input_fields": {
"name": None,
"kwargs_type": "denoiser_input_fields",
"description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
},
@@ -509,6 +510,7 @@ OUTPUT_PARAM_TEMPLATES = {
}
@dataclass
class InputParam:
"""Specification for an input parameter."""
@@ -519,20 +521,22 @@ class InputParam:
description: str = ""
kwargs_type: str = None
def __post_init__(self):
if self.required and self.default is not None:
raise ValueError(f"InputParam '{self.name}' cannot be both required and have a default value")
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@classmethod
def template(cls, name: str, note: str = None, **overrides) -> "InputParam":
def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam":
"""Get template for name if exists, otherwise raise ValueError."""
if name not in INPUT_PARAM_TEMPLATES:
raise ValueError(f"InputParam template for {name} not found")
if template_name not in INPUT_PARAM_TEMPLATES:
raise ValueError(f"InputParam template for {template_name} not found")
template_kwargs = INPUT_PARAM_TEMPLATES[name].copy()
template_kwargs = INPUT_PARAM_TEMPLATES[template_name].copy()
# Determine the actual param name:
# 1. From overrides if provided
# 2. From template if present
# 3. Fall back to template_name
name = overrides.pop("name", template_kwargs.pop("name", template_name))
if note and "description" in template_kwargs:
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"
@@ -541,6 +545,7 @@ class InputParam:
return cls(name=name, **template_kwargs)
@dataclass
class OutputParam:
"""Specification for an output parameter."""
@@ -555,12 +560,18 @@ class OutputParam:
)
@classmethod
def template(cls, name: str, note: str = None, **overrides) -> "OutputParam":
def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam":
"""Get template for name if exists, otherwise raise ValueError."""
if name not in OUTPUT_PARAM_TEMPLATES:
raise ValueError(f"OutputParam template for {name} not found")
if template_name not in OUTPUT_PARAM_TEMPLATES:
raise ValueError(f"OutputParam template for {template_name} not found")
template_kwargs = OUTPUT_PARAM_TEMPLATES[name].copy()
template_kwargs = OUTPUT_PARAM_TEMPLATES[template_name].copy()
# Determine the actual param name:
# 1. From overrides if provided
# 2. From template if present
# 3. Fall back to template_name
name = overrides.pop("name", template_kwargs.pop("name", template_name))
if note and "description" in template_kwargs:
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"

View File

@@ -146,8 +146,8 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="height", type_hint=int, description="updated to default value if not provided"),
OutputParam(name="width", type_hint=int, description="updated to default value if not provided"),
OutputParam(name="height", type_hint=int, description="if not set, updated to default value"),
OutputParam(name="width", type_hint=int, description="if not set, updated to default value"),
OutputParam(
name="latents",
type_hint=torch.Tensor,
@@ -230,8 +230,8 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="height", type_hint=int, description="updated to default value if not provided"),
OutputParam(name="width", type_hint=int, description="updated to default value if not provided"),
OutputParam(name="height", type_hint=int, description="if not set, updated to default value"),
OutputParam(name="width", type_hint=int, description="if not set, updated to default value"),
OutputParam(
name="latents",
type_hint=torch.Tensor,
@@ -307,8 +307,13 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The initial random noised, can be generated in prepare latent step.",
),
InputParam.template("image_latents", note="Can be generated from vae encoder and packed in input step."),
InputParam.template("timesteps", required=True, note="can be generated in set_timesteps step."),
InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."),
InputParam(
name="timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
]
@property
@@ -322,7 +327,7 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The scalednoisy latents to use for inpainting/image-to-image denoising.",
description="The scaled noisy latents to use for inpainting/image-to-image denoising.",
),
]
@@ -383,8 +388,8 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The processed mask to use for the inpainting process.",
),
InputParam.template("height", required=True, note="should be updated in prepare latents step."),
InputParam.template("width", required=True, note="should be updated in prepare latents step."),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("dtype"),
]
@@ -447,7 +452,12 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
return [
InputParam.template("num_inference_steps"),
InputParam.template("sigmas"),
InputParam.template("latents", required=True, description="The initial random noised latents for the denoising process, used to calculate the image sequence length. Can be generated in prepare latents step."),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The initial random noised latents for the denoising process. Can be generated in prepare latents step."
),
]
@property
@@ -456,7 +466,6 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
OutputParam(
name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
),
OutputParam(name="num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -515,8 +524,11 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"),
OutputParam(name="num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"),
OutputParam(
name="timesteps",
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process."
),
]
@torch.no_grad()
@@ -568,7 +580,12 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
return [
InputParam.template("num_inference_steps"),
InputParam.template("sigmas"),
InputParam.template("latents", required=True, description="The latents to use for the denoising process. Can be generated in prepare latents step."),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare latents step."
),
InputParam.template("strength", default=0.9),
]
@@ -583,7 +600,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
OutputParam(
name="num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
description="The number of denoising steps to perform at inference time. Updated based on strength.",
),
]
@@ -643,8 +660,8 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam.template("batch_size"),
InputParam.template("height", note="should be updated in prepare latents step."),
InputParam.template("width", note="should be updated in prepare latents step."),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]
@@ -711,8 +728,8 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
InputParam.template("batch_size"),
InputParam(name="image_height", required=True, type_hint=int, description="The height of the reference image. Can be generated in input step."),
InputParam(name="image_width", required=True, type_hint=int, description="The width of the reference image. Can be generated in input step."),
InputParam.template("height", required=True, note="should be updated in prepare latents step."),
InputParam.template("width", required=True, note="should be updated in prepare latents step."),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]
@@ -788,10 +805,10 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam.template("batch_size"),
InputParam(name="image_height", required=True, type_hint=List[int], descrption="The heights of the reference images. Can be generated in input step."),
InputParam(name="image_height", required=True, type_hint=List[int], description="The heights of the reference images. Can be generated in input step."),
InputParam(name="image_width", required=True, type_hint=List[int], description="The widths of the reference images. Can be generated in input step."),
InputParam.template("height", required=True, note="should be updated in prepare latents step."),
InputParam.template("width", required=True, note="should be updated in prepare latents step."),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]
@@ -863,8 +880,8 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
return [
InputParam.template("batch_size"),
InputParam.template("layers"),
InputParam.template("height", required=True, note="should be updated in prepare latents step."),
InputParam.template("width", required=True, note="should be updated in prepare latents step."),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]
@@ -950,8 +967,18 @@ class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
InputParam.template("control_guidance_start"),
InputParam.template("control_guidance_end"),
InputParam.template("controlnet_conditioning_scale"),
InputParam("control_image_latents", required=True, type_hint=torch.Tensor, description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step."),
InputParam.template("timesteps", required=True, note="Can be generated in set_timesteps step."),
InputParam(
name="control_image_latents",
required=True,
type_hint=torch.Tensor,
description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step."
),
InputParam(
name="timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
]
@property
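
The convention these hunks apply throughout the file: keep InputParam.template(...) for shared parameters, appending at most a short note, and write an explicit InputParam(...) whenever the description is bespoke to the block. A sketch of the two styles, using parameters from the hunks above:

# Shared parameter: the canonical description, optionally annotated with a note.
InputParam.template("height", required=True)
InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step.")

# Bespoke description: spelled out rather than overriding the template.
InputParam(
    name="timesteps",
    required=True,
    type_hint=torch.Tensor,
    description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
)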

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Any, Dict, List
import torch
@@ -47,15 +47,24 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("height", required=True, note="should be updated in input and prepare latents step."),
InputParam.template("width", required=True, note="should be updated in input and prepare latents step."),
InputParam.template("latents", required=True, description="The latents to decode, can be generated in the denoise step."),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to decode, can be generated in the denoise step."
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("latents", note="unpacked to B, C, 1, H, W"),
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The denoisedlatents unpacked to B, C, 1, H, W"
),
]
@torch.no_grad()
@@ -87,9 +96,14 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("latents", required=True, description="The latents to decode, can be generated in the denoise step."),
InputParam.template("height", required=True, note="should be updated in prepare latents step."),
InputParam.template("width", required=True, note="should be updated in prepare latents step."),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step."
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("layers"),
]
@@ -135,7 +149,12 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("latents", required=True, description="The latents to decode, can be generated in the denoise step and unpacked in the after denoise step."),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step."
),
]
@property
@@ -192,7 +211,12 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("latents", required=True, description="The latents to decode, can be generated in the denoise step and unpacked in the after denoise step."),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step."
),
InputParam.template("output_type"),
]
@@ -266,7 +290,12 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image tensor from decoders step"),
InputParam(
name="images",
required=True,
type_hint=torch.Tensor,
description="the generated image tensor from decoders step"
),
InputParam.template("output_type"),
]
@@ -315,9 +344,17 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image tensor from decoders step"),
InputParam(
name="images",
required=True,
type_hint=torch.Tensor,
description="the generated image tensor from decoders step"
),
InputParam.template("output_type"),
InputParam("mask_overlay_kwargs", description="The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep."),
InputParam(
name="mask_overlay_kwargs",
type_hint=Dict[str, Any],
description="The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep."),
]
@property

View File

@@ -49,7 +49,12 @@ class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("latents", required=True, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
]
@torch.no_grad()
@@ -74,8 +79,13 @@ class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam.template("latents", required=True, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."),
InputParam.template("image_latents", note="Can be encoded in vae_encoder step and packed in prepare_image_latents step."),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
),
InputParam.template("image_latents", note="generated in vae encoder step and updated in input step."),
]
@torch.no_grad()
@@ -119,10 +129,13 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam.template("controlnet_conditioning_scale", note="Can be generated in prepare_controlnet_inputs step."),
InputParam.template("controlnet_keep", note="Can be generated in prepare_controlnet_inputs step."),
InputParam.template("num_inference_steps", required=True, note="Can be updated in set_timesteps step."),
InputParam.template("denoiser_input_fields")
InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."),
InputParam(
name="controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step."
),
]
@torch.no_grad()
@@ -184,8 +197,13 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam.template("attention_kwargs"),
InputParam.template("latents", required=True, description="The latents to use for the denoising process. Can be generated in prepare_latents step."),
InputParam.template("num_inference_steps", required=True, note="should be updated in set_timesteps step."),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step."
),
InputParam.template("num_inference_steps"),
InputParam.template("denoiser_input_fields"),
InputParam(
"img_shapes",
@@ -275,8 +293,13 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam.template("attention_kwargs"),
InputParam.template("latents", required=True, description="The latents to use for the denoising process. Can be generated in prepare_latents step."),
InputParam.template("num_inference_steps", required=True, note="should be updated in set_timesteps step."),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step."
),
InputParam.template("num_inference_steps"),
InputParam.template("denoiser_input_fields"),
InputParam(
"img_shapes",
@@ -404,14 +427,19 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam.template("image_latents", note="Can be generated from vae encoder and packed in input step."),
InputParam.template("image_latents", note="Can be generated from vae encoder step and updated in input step."),
InputParam(
"initial_noise",
required=True,
type_hint=torch.Tensor,
description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam.template("timesteps", required=True, note="should be updated in set_timesteps step."),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
]
@torch.no_grad()
@@ -452,8 +480,13 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
@property
def loop_inputs(self) -> List[InputParam]:
return [
InputParam.template("timesteps", required=True, note="should be generated in set_timesteps step."),
InputParam.template("num_inference_steps", required=True, note="should be updated in set_timesteps step."),
InputParam(
name="timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
),
InputParam.template("num_inference_steps", required=True),
]
@torch.no_grad()

View File

@@ -1145,7 +1145,7 @@ class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks):
@property
def description(self) -> str:
return "Image Preprocess step. Images can be resized first using QwenImageEditResizeStep."
return "Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of processed images."
@property
def expected_components(self) -> List[ComponentSpec]:

View File

@@ -139,8 +139,8 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("batch_size"),
OutputParam.template("dtype"),
OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"),
OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"),
OutputParam.template("prompt_embeds", note="batch-expanded"),
OutputParam.template("prompt_embeds_mask", note="batch-expanded"),
OutputParam.template("negative_prompt_embeds", note="batch-expanded"),
@@ -307,8 +307,8 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
# `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided
if len(self._image_latent_inputs) > 0:
outputs.append(OutputParam(name="height", type_hint=int, note="updated based on image size if not provided"))
outputs.append(OutputParam(name="width", type_hint=int, note="updated based on image size if not provided"))
outputs.append(OutputParam(name="height", type_hint=int, description="if not provided, updated to image height"))
outputs.append(OutputParam(name="width", type_hint=int, description="if not provided, updated to image width"))
# image latent inputs are modified in place (patchified and batch-expanded)
for input_param in self._image_latent_inputs:
@@ -476,8 +476,8 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
# `height`/`width` are updated if any image latent inputs are provided
if len(self._image_latent_inputs) > 0:
outputs.append(OutputParam(name="height", type_hint=int, description="updated based on image size if not provided"))
outputs.append(OutputParam(name="width", type_hint=int, description="updated based on image size if not provided"))
outputs.append(OutputParam(name="height", type_hint=int, description="if not provided, updated to image height"))
outputs.append(OutputParam(name="width", type_hint=int, description="if not provided, updated to image width"))
# image latent inputs are modified in place (patchified, concatenated, and batch-expanded)
for input_param in self._image_latent_inputs:
@@ -658,8 +658,8 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
]
if len(self._image_latent_inputs) > 0:
outputs.append(OutputParam(name="height", type_hint=int, description="updated based on image size if not provided"))
outputs.append(OutputParam(name="width", type_hint=int, description="updated based on image size if not provided"))
outputs.append(OutputParam(name="height", type_hint=int, description="if not provided, updated to image height"))
outputs.append(OutputParam(name="width", type_hint=int, description="if not provided, updated to image width"))
# Add outputs for image latent inputs (patchified with layered patchifier and batch-expanded)
for input_param in self._image_latent_inputs:
@@ -759,8 +759,8 @@ class QwenImageControlNetInputsStep(ModularPipelineBlocks):
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="control_image_latents", type_hint=torch.Tensor, description="The control image latents (patchified and batch-expanded)."),
OutputParam(name="height", type_hint=int, description="updated based on control image size if not provided"),
OutputParam(name="width", type_hint=int, description="updated based on control image size if not provided"),
OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"),
OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"),
]
@torch.no_grad()
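
A hedged sketch of the height/width fill-in rule these descriptions document; the block_state names are hypothetical, only the rule itself is taken from the hunks above:

# User-supplied values are left untouched; missing ones are derived
# from the image (or control image) latent inputs.
if block_state.height is None:
    block_state.height = image_height
if block_state.width is None:
    block_state.width = image_width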

View File

@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InsertableDict, OutputParam, InputParam
from .before_denoise import (
QwenImageControlNetBeforeDenoiserStep,
QwenImageCreateMaskLatentsStep,
@@ -319,7 +319,7 @@ class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
"""
model_name = "qwenimage"
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()]
block_names = ["text_inputs", "additional_inputs"]
@property
@@ -373,7 +373,7 @@ class QwenImageInpaintInputStep(SequentialPipelineBlocks):
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
additional_batch_inputs=[InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")]
),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -512,7 +512,7 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -598,7 +598,7 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -682,7 +682,7 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -777,7 +777,7 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -880,7 +880,7 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -981,7 +981,7 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -1042,7 +1042,7 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -1279,5 +1279,5 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.images(),
OutputParam.template("images"),
]
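
A migration sketch for callers, grounded in the hunks above: the dynamic OutputParam.latents()/OutputParam.images() helpers give way to the generic template lookup, image_latents no longer needs to be passed explicitly to QwenImageAdditionalInputsStep, and additional_batch_inputs now takes full InputParam specs instead of bare strings:

# old: OutputParam.latents() / OutputParam.images()
outputs = [OutputParam.template("latents")]
images = [OutputParam.template("images")]

# old: QwenImageAdditionalInputsStep(
#     image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
# )
step = QwenImageAdditionalInputsStep(
    additional_batch_inputs=[
        InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image"),
    ],
)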

View File

@@ -13,10 +13,11 @@
# limitations under the License.
from typing import Optional
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InsertableDict, OutputParam, InputParam
from .before_denoise import (
QwenImageCreateMaskLatentsStep,
QwenImageEditRoPEInputsStep,
@@ -206,7 +207,7 @@ class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
block_classes = [
QwenImageEditResizeStep(),
QwenImageEditInpaintProcessImagesInputStep(),
QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
QwenImageVaeEncoderStep(),
]
block_names = ["resize", "preprocess", "encode"]
@@ -286,7 +287,7 @@ class QwenImageEditInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
QwenImageAdditionalInputsStep(),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -344,8 +345,7 @@ class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
QwenImageAdditionalInputsStep(additional_batch_inputs=[InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")]
),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -485,7 +485,7 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -571,7 +571,7 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -605,7 +605,7 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -698,7 +698,7 @@ class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -816,5 +816,5 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.images(),
OutputParam.template("images"),
]

View File

@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InsertableDict, OutputParam, InputParam
from .before_denoise import (
QwenImageEditPlusRoPEInputsStep,
QwenImagePrepareLatentsStep,
@@ -211,7 +211,7 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageTextInputsStep(),
QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
QwenImageEditPlusAdditionalInputsStep(),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -302,7 +302,7 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -446,5 +446,5 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.images(),
OutputParam.template("images"),
]

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
@@ -255,7 +255,7 @@ class QwenImageLayeredInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-layered"
block_classes = [
QwenImageTextInputsStep(),
QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
QwenImageLayeredAdditionalInputsStep(),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -342,7 +342,7 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.latents(),
OutputParam.template("latents"),
]
@@ -484,5 +484,5 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam.images(),
OutputParam.template("images"),
]