mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
yiyixuxu
2026-01-10 10:52:53 +01:00
parent 7b499de6d0
commit b29873dee7
8 changed files with 126 additions and 109 deletions

View File

@@ -342,6 +342,18 @@ class InputParam:
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@classmethod
def template(cls, name: str) -> Optional["InputParam"]:
"""Get template for name if exists, otherwise None."""
if hasattr(cls, name) and callable(getattr(cls, name)):
return getattr(cls, name)()
return None
# ======================================================
# InputParam templates
# ======================================================
@classmethod
def prompt(cls) -> "InputParam":
return cls(name="prompt", type_hint=str, required=True,
@@ -383,7 +395,6 @@ class InputParam:
return cls(name="generator", type_hint=torch.Generator, default=None,
description="Torch generator for deterministic generation.")
@classmethod
def sigmas(cls) -> "InputParam":
return cls(name="sigmas", type_hint=List[float], default=None,
@@ -394,6 +405,7 @@ class InputParam:
return cls(name="strength", type_hint=float, default=default,
description="Strength for img2img/inpainting.")
# images
@classmethod
def image(cls) -> "InputParam":
return cls(name="image", type_hint=PIL.Image.Image, required=True,
@@ -425,12 +437,24 @@ class InputParam:
def timesteps(cls) -> "InputParam":
return cls(name="timesteps", type_hint=torch.Tensor, default=None,
description="Timesteps for the denoising process.")
# =====================================================================
# ControlNet
# =====================================================================
@classmethod
def output_type(cls) -> "InputParam":
return cls(name="output_type", type_hint=str, default="pil",
description="Output format: 'pil', 'np', 'pt''.")
@classmethod
def attention_kwargs(cls) -> "InputParam":
return cls(name="attention_kwargs", type_hint=Dict[str, Any], default=None,
description="Additional kwargs for attention processors.")
@classmethod
def denoiser_input_fields(cls) -> "InputParam":
return cls(kwargs_type="denoiser_input_fields", type_hint=torch.Tensor,
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.")
# ControlNet
@classmethod
def control_guidance_start(cls, default: float = 0.0) -> "InputParam":
return cls(name="control_guidance_start", type_hint=float, default=default,
@@ -446,18 +470,6 @@ class InputParam:
return cls(name="controlnet_conditioning_scale", type_hint=float, default=default,
description="Scale for ControlNet conditioning.")
@classmethod
def output_type(cls) -> "InputParam":
return cls(name="output_type", type_hint=str, default="pil",
description="Output format: 'pil', 'np', 'pt', or 'latent'.")
@classmethod
def attention_kwargs(cls) -> "InputParam":
return cls(name="attention_kwargs", type_hint=Dict[str, Any], default=None,
description="Additional kwargs for attention processors.")
@dataclass
class OutputParam:
"""Specification for an output parameter."""
@@ -472,6 +484,17 @@ class OutputParam:
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
)
@classmethod
def template(cls, name: str) -> Optional["OutputParam"]:
"""Get template for name if exists, otherwise None."""
if hasattr(cls, name) and callable(getattr(cls, name)):
return getattr(cls, name)()
return None
# ======================================================
# OutputParam templates
# ======================================================
@classmethod
def images(cls) -> "OutputParam":
return cls(name="images", type_hint=List[PIL.Image.Image],

View File

@@ -228,7 +228,7 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
InputParam.latents(),
InputParam.height(),
InputParam.width(),
InputParam(name="layers", type_hint=int, default=4),
InputParam(name="layers", type_hint=int, default=4, description="Number of layers to extract from the image"),
InputParam.num_images_per_prompt(),
InputParam.generator(),
InputParam(
@@ -598,7 +598,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The latents to use for the denoising process, used to calculate the image sequence length.",
),
InputParam(name="strength", default=0.9),
InputParam.strength(0.9),
]
@property
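A hedged sketch of how the parameterized templates used here (e.g. `InputParam.strength(0.9)`, `InputParam.max_sequence_length(1024)`) behave; the 0.3 default below is an assumed placeholder, not taken from the diff:

from dataclasses import dataclass
from typing import Any

@dataclass
class InputParam:  # simplified stand-in for the real class
    name: str
    type_hint: Any = None
    default: Any = None
    description: str = ""

    @classmethod
    def strength(cls, default: float = 0.3) -> "InputParam":
        # name/type/description are fixed by the template; only the default varies
        # (0.3 is an assumed placeholder; the diff does not show the real default)
        return cls(name="strength", type_hint=float, default=default,
                   description="Strength for img2img/inpainting.")

assert InputParam.strength(0.9).default == 0.9   # per-pipeline override, as in the diff
assert InputParam.strength().name == "strength"  # everything else stays centralized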
@@ -886,7 +886,7 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="layers", required=True),
InputParam(name="layers", default=4, description="Number of layers to extract from the image"),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),

View File

@@ -91,7 +91,7 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("height", required=True, type_hint=int),
InputParam("width", required=True, type_hint=int),
InputParam("layers", required=True, type_hint=int),
InputParam("layers", default=4, description="Number of layers to extract from the image"),
]
@torch.no_grad()
@@ -141,11 +141,7 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
OutputParam.images()
]
@torch.no_grad()
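A minimal sketch of the `OutputParam.images()` template that the verbose inline literal collapses into; the description string below is an assumption, since the diff truncates the template body:

from dataclasses import dataclass
from typing import Any, List
import PIL.Image

@dataclass
class OutputParam:  # simplified stand-in for the real class
    name: str
    type_hint: Any = None
    description: str = ""

    @classmethod
    def images(cls) -> "OutputParam":
        # description here is assumed; the real template body is not shown in the diff
        return cls(name="images", type_hint=List[PIL.Image.Image],
                   description="The generated images.")

# the inline OutputParam("images", type_hint=..., description=...) above collapses to:
print(OutputParam.images())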
@@ -198,14 +194,14 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("output_type", default="pil", type_hint=str),
InputParam("latents", required=True, type_hint=torch.Tensor, description="The latents to decode, can be generated in the denoise step"),
InputParam.output_type(),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
OutputParam.images(),
]
@torch.no_grad()
@@ -273,12 +269,7 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
),
InputParam.output_type(),
]
@staticmethod
@@ -323,12 +314,7 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
),
InputParam.output_type(),
InputParam("mask_overlay_kwargs"),
]

View File

@@ -218,7 +218,7 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("attention_kwargs"),
InputParam.attention_kwargs(),
InputParam(
"latents",
required=True,
@@ -231,10 +231,7 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam.denoiser_input_fields(),
InputParam(
"img_shapes",
required=True,
@@ -322,7 +319,7 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("attention_kwargs"),
InputParam.attention_kwargs(),
InputParam(
"latents",
required=True,
@@ -335,10 +332,7 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam.denoiser_input_fields(),
InputParam(
"img_shapes",
required=True,
@@ -424,7 +418,7 @@ class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
OutputParam.latents(),
]
@torch.no_grad()

View File

@@ -301,8 +301,8 @@ class QwenImageEditResizeStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
InputParam.template(self._image_input_name) or InputParam(
name=self._image_input_name, required=True, type_hint=torch.Tensor, description="Input image for conditioning"
),
]
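The `InputParam.template(name) or InputParam(...)` idiom recurs throughout this commit in blocks whose input names are dynamic. A small runnable sketch, reusing the simplified stand-in from the earlier examples:

from dataclasses import dataclass
from typing import Any, Optional
import torch

@dataclass
class InputParam:  # simplified stand-in, as in the earlier sketches
    name: str
    required: bool = False
    type_hint: Any = None
    description: str = ""

    @classmethod
    def template(cls, name: str) -> Optional["InputParam"]:
        if hasattr(cls, name) and callable(getattr(cls, name)):
            return getattr(cls, name)()
        return None

    @classmethod
    def image(cls) -> "InputParam":
        return cls(name="image", required=True)

def resolve(name: str) -> InputParam:
    # the fallback idiom from the diff: prefer the shared template when one
    # exists, otherwise build an ad-hoc param for the dynamic input name
    return InputParam.template(name) or InputParam(name=name, required=True, type_hint=torch.Tensor)

print(resolve("image"))          # shared template wins
print(resolve("resized_image"))  # no template in this sketch -> ad-hoc fallback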
@@ -381,7 +381,7 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
InputParam.template(self._image_input_name) or InputParam(
name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
),
InputParam(
@@ -484,7 +484,7 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
InputParam.template(self._image_input_name) or InputParam(
name=self._image_input_name,
required=True,
type_hint=torch.Tensor,
@@ -564,7 +564,7 @@ class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt", type_hint=str, description="The prompt to encode"),
InputParam(name="prompt", type_hint=str, description="The prompt to encode"), # it is not required for qwenimage-layered, unlike other pipelines
InputParam(
name="resized_image",
required=True,
@@ -647,11 +647,9 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
InputParam(
name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024
),
InputParam.prompt(),
InputParam.negative_prompt(),
InputParam.max_sequence_length(1024),
]
@property
@@ -772,8 +770,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
InputParam.prompt(),
InputParam.negative_prompt(),
InputParam(
name="resized_image",
required=True,
@@ -895,8 +893,8 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
InputParam.prompt(),
InputParam.negative_prompt(),
InputParam(
name="resized_cond_image",
required=True,
@@ -1010,11 +1008,11 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("mask_image", required=True),
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
InputParam("padding_mask_crop"),
InputParam.mask_image(),
InputParam.image(),
InputParam.height(),
InputParam.width(),
InputParam.padding_mask_crop(),
]
@property
@@ -1082,9 +1080,9 @@ class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("mask_image", required=True),
InputParam("resized_image", required=True),
InputParam("padding_mask_crop"),
InputParam.mask_image(),
InputParam("resized_image", required=True, type_hint=PIL.Image.Image, description="The resized image. should be generated using a resize step"),
InputParam.padding_mask_crop(),
]
@property
@@ -1140,9 +1138,9 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
InputParam.image(),
InputParam.height(),
InputParam.width(),
]
@property
@@ -1312,7 +1310,10 @@ class QwenImageVaeEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [InputParam(self._image_input_name, required=True), InputParam("generator")]
return [
InputParam.template(self._image_input_name) or InputParam(name=self._image_input_name, required=True),
InputParam.generator(),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
@@ -1383,10 +1384,10 @@ class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam("control_image", required=True),
InputParam("height"),
InputParam("width"),
InputParam("generator"),
InputParam.control_image(),
InputParam.height(),
InputParam.width(),
InputParam.generator(),
]
return inputs

View File

@@ -129,7 +129,7 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_images_per_prompt", default=1),
InputParam.num_images_per_prompt(),
InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
@@ -269,17 +269,17 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam.num_images_per_prompt(),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
InputParam.height(),
InputParam.width(),
]
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
inputs.append(InputParam.template(input_name) or InputParam(name=input_name))
return inputs
@@ -398,17 +398,17 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam.num_images_per_prompt(),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
InputParam.height(),
InputParam.width(),
]
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
inputs.append(InputParam.template(input_name) or InputParam(name=input_name))
return inputs
@@ -544,15 +544,15 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam.num_images_per_prompt(),
InputParam(name="batch_size", required=True),
]
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
inputs.append(InputParam.template(input_name) or InputParam(name=input_name))
return inputs
@@ -638,9 +638,9 @@ class QwenImageControlNetInputsStep(ModularPipelineBlocks):
return [
InputParam(name="control_image_latents", required=True),
InputParam(name="batch_size", required=True),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="height"),
InputParam(name="width"),
InputParam.num_images_per_prompt(),
InputParam.height(),
InputParam.width(),
]
@torch.no_grad()

View File

@@ -54,7 +54,23 @@ logger = logging.get_logger(__name__)
# ====================
# 1. VAE ENCODER
# 1. TEXT ENCODER
# ====================
class QwenImageAutoTextEncoderStep(AutoPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageTextEncoderStep()]
block_names = ["text_encoder"]
block_trigger_inputs = ["prompt"]
@property
def description(self) -> str:
return (
"Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block.\n"
" - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided.\n"
" - if `prompt` is not provided, the step will be skipped."
)
# ====================
# 2. VAE ENCODER
# ====================
@@ -118,7 +134,7 @@ class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
# ====================
# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# ====================
@@ -396,7 +412,7 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
# ====================
# 3. DECODE
# 4. DECODE
# ====================
@@ -439,11 +455,11 @@ class QwenImageAutoDecodeStep(AutoPipelineBlocks):
# ====================
# 4. AUTO BLOCKS & PRESETS
# 5. AUTO BLOCKS & PRESETS
# ====================
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("text_encoder", QwenImageAutoTextEncoderStep()),
("vae_encoder", QwenImageAutoVaeEncoderStep()),
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
("denoise", QwenImageAutoCoreDenoiseStep()),

View File

@@ -129,10 +129,7 @@ class ZImageLoopDenoiser(ModularPipelineBlocks):
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam.denoiser_input_fields(),
]
guider_input_names = []
uncond_guider_input_names = []