diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
index 5ef1b98f1b..6f1010daf2 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
@@ -898,12 +898,12 @@ def make_doc_string(

     # Add components section if provided
     if expected_components and len(expected_components) > 0:
-        components_str = format_components(expected_components, indent_level=2)
+        components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
         output += components_str + "\n\n"

     # Add configs section if provided
     if expected_configs and len(expected_configs) > 0:
-        configs_str = format_configs(expected_configs, indent_level=2)
+        configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
         output += configs_str + "\n\n"

     # Add inputs section
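The `add_empty_lines` flag controls whether the section formatter separates entries with blank lines. A minimal sketch of the idea (toy function only; the real `format_components`/`format_configs` in `modular_pipeline_utils.py` have richer signatures):

import textwrap

# Toy sketch -- not the diffusers implementation.
# With add_empty_lines=False the section is rendered compactly,
# without blank lines between entries.
def format_section(title, names, indent_level=2, add_empty_lines=True):
    indent = " " * indent_level
    sep = "\n\n" if add_empty_lines else "\n"
    return title + ":\n" + sep.join(indent + n for n in names)

print(format_section("Components", ["vae (`AutoencoderKLQwenImage`)", "scheduler (`FlowMatchEulerDiscreteScheduler`)"], add_empty_lines=False))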
diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
index fc795b5f5a..0b8cd0f4b2 100644
--- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -117,8 +117,39 @@ def get_timesteps(scheduler, num_inference_steps, strength):
 # 1. PREPARE LATENTS
 # ====================

-
+# auto_docstring
 class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
+    """
+    Prepare initial random noise for the generation process
+
+    Components:
+        pachifier (`QwenImagePachifier`)
+
+    Inputs:
+        latents (`Tensor`, *optional*):
+            Pre-generated noisy latents for image generation.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        generator (`Generator`, *optional*):
+            Torch generator for deterministic generation.
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        dtype (`dtype`, *optional*, defaults to torch.float32):
+            The dtype of the model inputs; can be generated in input step.
+
+    Outputs:
+        height (`int`):
+            If not set, updated to default value
+        width (`int`):
+            If not set, updated to default value
+        latents (`Tensor`):
+            The initial latents to use for the denoising process
+    """
     model_name = "qwenimage"

     @property
@@ -201,7 +232,41 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
+    """
+    Prepare initial random noise (B, layers+1, C, H, W) for the generation process
+
+    Components:
+        pachifier (`QwenImageLayeredPachifier`)
+
+    Inputs:
+        latents (`Tensor`, *optional*):
+            Pre-generated noisy latents for image generation.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+        layers (`int`, *optional*, defaults to 4):
+            Number of layers to extract from the image
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        generator (`Generator`, *optional*):
+            Torch generator for deterministic generation.
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        dtype (`dtype`, *optional*, defaults to torch.float32):
+            The dtype of the model inputs; can be generated in input step.
+
+    Outputs:
+        height (`int`):
+            If not set, updated to default value
+        width (`int`):
+            If not set, updated to default value
+        latents (`Tensor`):
+            The initial latents to use for the denoising process
+    """
     model_name = "qwenimage-layered"

     @property
@@ -285,7 +350,29 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
+    """
+    Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps and
+    prepare_latents. Both noise and image latents should already be patchified.
+
+    Components:
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        latents (`Tensor`):
+            The initial random noise; can be generated in prepare latent step.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step and updated in
+            input step.
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+
+    Outputs:
+        initial_noise (`Tensor`):
+            The initial random noise used for inpainting denoising.
+        latents (`Tensor`):
+            The scaled noisy latents to use for inpainting/image-to-image denoising.
+    """
     model_name = "qwenimage"

     @property
@@ -366,7 +453,28 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
+    """
+    Step that creates mask latents from preprocessed mask_image by interpolating to latent space.
+
+    Components:
+        pachifier (`QwenImagePachifier`)
+
+    Inputs:
+        processed_mask_image (`Tensor`):
+            The processed mask to use for the inpainting process.
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        dtype (`dtype`, *optional*, defaults to torch.float32):
+            The dtype of the model inputs; can be generated in input step.
+
+    Outputs:
+        mask (`Tensor`):
+            The mask to use for the inpainting process.
+    """
     model_name = "qwenimage"

     @property
@@ -433,8 +541,26 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
 # 2. SET TIMESTEPS
 # ====================

-
+# auto_docstring
 class QwenImageSetTimestepsStep(ModularPipelineBlocks):
+    """
+    Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step.
+
+    Components:
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        num_inference_steps (`int`, *optional*, defaults to 50):
+            The number of denoising steps.
+        sigmas (`List`, *optional*):
+            Custom sigmas for the denoising process.
+        latents (`Tensor`):
+            The initial random noise latents for the denoising process. Can be generated in prepare latents step.
+
+    Outputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process
+    """
     model_name = "qwenimage"

     @property
@@ -500,7 +626,27 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
+    """
+    Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents.
+
+    Components:
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        num_inference_steps (`int`, *optional*, defaults to 50):
+            The number of denoising steps.
+        sigmas (`List`, *optional*):
+            Custom sigmas for the denoising process.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step and packed in
+            input step.
+
+    Outputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process.
+    """
     model_name = "qwenimage-layered"

     @property
@@ -562,7 +708,30 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
+    """
+    Step that sets the scheduler's timesteps for image-to-image generation and inpainting. Should be run after
+    prepare latents step.
+
+    Components:
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        num_inference_steps (`int`, *optional*, defaults to 50):
+            The number of denoising steps.
+        sigmas (`List`, *optional*):
+            Custom sigmas for the denoising process.
+        latents (`Tensor`):
+            The latents to use for the denoising process. Can be generated in prepare latents step.
+        strength (`float`, *optional*, defaults to 0.9):
+            Strength for img2img/inpainting.
+
+    Outputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process.
+        num_inference_steps (`int`):
+            The number of denoising steps to perform at inference time. Updated based on strength.
+    """
     model_name = "qwenimage"

     @property
@@ -646,8 +815,32 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):

 ## RoPE inputs for denoiser

-
+# auto_docstring
 class QwenImageRoPEInputsStep(ModularPipelineBlocks):
+    """
+    Step that prepares the RoPE inputs for the denoising process. Should be placed after prepare_latents step
+
+    Inputs:
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        prompt_embeds_mask (`Tensor`):
+            Mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            Mask for the negative text embeddings. Can be generated from text_encoder step.
+
+    Outputs:
+        img_shapes (`List`):
+            The shapes of the image latents, used for RoPE calculation
+        txt_seq_lens (`List`):
+            The sequence lengths of the prompt embeds, used for RoPE calculation
+        negative_txt_seq_lens (`List`):
+            The sequence lengths of the negative prompt embeds, used for RoPE calculation
+    """
     model_name = "qwenimage"

     @property
@@ -715,7 +908,36 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
+    """
+    Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit. Should be placed
+    after prepare_latents step
+
+    Inputs:
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        image_height (`int`):
+            The height of the reference image. Can be generated in input step.
+        image_width (`int`):
+            The width of the reference image. Can be generated in input step.
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        prompt_embeds_mask (`Tensor`):
+            Mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            Mask for the negative text embeddings. Can be generated from text_encoder step.
+
+    Outputs:
+        img_shapes (`List`):
+            The shapes of the image latents, used for RoPE calculation
+        txt_seq_lens (`List`):
+            The sequence lengths of the prompt embeds, used for RoPE calculation
+        negative_txt_seq_lens (`List`):
+            The sequence lengths of the negative prompt embeds, used for RoPE calculation
+    """
     model_name = "qwenimage"

     @property
@@ -790,7 +1012,38 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
+    """
+    Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit Plus. Unlike Edit,
+    Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed after
+    prepare_latents step.
+
+    Inputs:
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        image_height (`List`):
+            The heights of the reference images. Can be generated in input step.
+        image_width (`List`):
+            The widths of the reference images. Can be generated in input step.
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        prompt_embeds_mask (`Tensor`):
+            Mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            Mask for the negative text embeddings. Can be generated from text_encoder step.
+
+    Outputs:
+        img_shapes (`List`):
+            The shapes of the image latents, used for RoPE calculation
+        txt_seq_lens (`List`):
+            The sequence lengths of the prompt embeds, used for RoPE calculation
+        negative_txt_seq_lens (`List`):
+            The sequence lengths of the negative prompt embeds, used for RoPE calculation
+    """
     model_name = "qwenimage-edit-plus"

     @property
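The txt_seq_lens/negative_txt_seq_lens outputs described by these RoPE steps are, in essence, per-prompt valid-token counts derived from the embedding masks. A minimal sketch of that derivation (illustrative only, not the exact diffusers code):

import torch

# Each row of the mask marks valid prompt tokens with 1.
prompt_embeds_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()  # per-prompt token counts
print(txt_seq_lens)  # [3, 2]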
@@ -866,7 +1119,36 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
+    """
+    Step that prepares the RoPE inputs for the denoising process. Should be placed after prepare_latents step
+
+    Inputs:
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        layers (`int`, *optional*, defaults to 4):
+            Number of layers to extract from the image
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        prompt_embeds_mask (`Tensor`):
+            Mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            Mask for the negative text embeddings. Can be generated from text_encoder step.
+
+    Outputs:
+        img_shapes (`List`):
+            The shapes of the image latents, used for RoPE calculation
+        txt_seq_lens (`List`):
+            The sequence lengths of the prompt embeds, used for RoPE calculation
+        negative_txt_seq_lens (`List`):
+            The sequence lengths of the negative prompt embeds, used for RoPE calculation
+        additional_t_cond (`Tensor`):
+            The additional timestep conditioning, used for RoPE calculation
+    """
     model_name = "qwenimage-layered"

     @property
@@ -948,7 +1230,31 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):

 ## ControlNet inputs for denoiser

+
+# auto_docstring
 class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
+    """
+    Step that prepares inputs for controlnet. Insert before the denoise step, after set_timesteps step.
+
+    Components:
+        controlnet (`QwenImageControlNetModel`)
+
+    Inputs:
+        control_guidance_start (`float`, *optional*, defaults to 0.0):
+            When to start applying ControlNet.
+        control_guidance_end (`float`, *optional*, defaults to 1.0):
+            When to stop applying ControlNet.
+        controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+            Scale for ControlNet conditioning.
+        control_image_latents (`Tensor`):
+            The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
+            step.
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+
+    Outputs:
+        controlnet_keep (`List`):
+            The controlnet keep values
+    """
     model_name = "qwenimage"

     @property
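For the `controlnet_keep` output above: pipelines with control_guidance_start/end typically build one keep value per timestep, 1.0 while the step falls inside the guidance window and 0.0 outside. A hedged sketch of that pattern (the exact formula in diffusers may differ):

num_inference_steps = 10
control_guidance_start, control_guidance_end = 0.0, 0.5

controlnet_keep = [
    1.0
    - float(
        i / num_inference_steps < control_guidance_start
        or (i + 1) / num_inference_steps > control_guidance_end
    )
    for i in range(num_inference_steps)
]
print(controlnet_keep)  # 1.0 for the first half of the steps, 0.0 afterwards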
diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py
index 4476e1db9b..650bf34da7 100644
--- a/src/diffusers/modular_pipelines/qwenimage/decoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py
@@ -29,7 +29,27 @@ logger = logging.get_logger(__name__)

 # after denoising loop (unpack latents)

+
+# auto_docstring
 class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
+    """
+    Step that unpacks the latents from a 3D tensor (batch_size, sequence_length, channels) into a 5D tensor
+    (batch_size, channels, 1, height, width)
+
+    Components:
+        pachifier (`QwenImagePachifier`)
+
+    Inputs:
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        latents (`Tensor`):
+            The latents to decode; can be generated in the denoise step.
+
+    Outputs:
+        latents (`Tensor`):
+            The denoised latents, unpacked to (B, C, 1, H, W)
+    """
     model_name = "qwenimage"

     @property
@@ -80,7 +100,28 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
+    """
+    Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising.
+
+    Components:
+        pachifier (`QwenImageLayeredPachifier`)
+
+    Inputs:
+        latents (`Tensor`):
+            The denoised latents to decode; can be generated in the denoise step.
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        layers (`int`, *optional*, defaults to 4):
+            Number of layers to extract from the image
+
+    Outputs:
+        latents (`Tensor`):
+            Denoised latents, unpacked to (B, C, layers+1, H, W)
+    """
     model_name = "qwenimage-layered"

     @property
@@ -131,7 +172,23 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):

 # decode step

+
+# auto_docstring
 class QwenImageDecoderStep(ModularPipelineBlocks):
+    """
+    Step that decodes the latents to images
+
+    Components:
+        vae (`AutoencoderKLQwenImage`)
+
+    Inputs:
+        latents (`Tensor`):
+            The denoised latents to decode; can be generated in the denoise step and unpacked in the after denoise
+            step.
+
+    Outputs:
+        images (`List`):
+            Generated images (tensor output of the vae decoder)
+    """
     model_name = "qwenimage"

     @property
@@ -189,7 +246,25 @@ class QwenImageDecoderStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
+    """
+    Decode unpacked latents (B, C, layers+1, H, W) into layer images.
+
+    Components:
+        vae (`AutoencoderKLQwenImage`)
+        image_processor (`VaeImageProcessor`)
+
+    Inputs:
+        latents (`Tensor`):
+            The denoised latents to decode; can be generated in the denoise step and unpacked in the after denoise
+            step.
+        output_type (`str`, *optional*, defaults to pil):
+            Output format: 'pil', 'np', 'pt'.
+
+    Outputs:
+        images (`List`):
+            Generated images.
+    """
     model_name = "qwenimage-layered"

     @property
@@ -269,7 +344,25 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):

 # postprocess the decoded images

+
+# auto_docstring
 class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
+    """
+    Postprocess the generated image
+
+    Components:
+        image_processor (`VaeImageProcessor`)
+
+    Inputs:
+        images (`Tensor`):
+            The generated image tensor from decoders step
+        output_type (`str`, *optional*, defaults to pil):
+            Output format: 'pil', 'np', 'pt'.
+
+    Outputs:
+        images (`List`):
+            Generated images.
+    """
     model_name = "qwenimage"

     @property
@@ -323,7 +416,26 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
+    """
+    Postprocess the generated image, optionally applying the mask overlay to the original image.
+
+    Components:
+        image_mask_processor (`InpaintProcessor`)
+
+    Inputs:
+        images (`Tensor`):
+            The generated image tensor from decoders step
+        output_type (`str`, *optional*, defaults to pil):
+            Output format: 'pil', 'np', 'pt'.
+        mask_overlay_kwargs (`Dict`, *optional*):
+            The kwargs for the postprocess step to apply the mask overlay. Generated in InpaintProcessImagesInputStep.
+
+    Outputs:
+        images (`List`):
+            Generated images.
+    """
     model_name = "qwenimage"

     @property
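The `output_type` values ('pil', 'np', 'pt') map onto the standard `VaeImageProcessor.postprocess` call. A small usage sketch (assumes a decoded image tensor in [-1, 1]; illustrative):

import torch
from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)
decoded = torch.rand(1, 3, 64, 64) * 2 - 1  # stand-in for VAE decoder output
images = image_processor.postprocess(decoded, output_type="pil")
print(type(images[0]))  # a PIL.Image.Image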
diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py
index ad6a9677ac..ff6e411d76 100644
--- a/src/diffusers/modular_pipelines/qwenimage/denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py
@@ -85,7 +85,7 @@ class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
                 type_hint=torch.Tensor,
                 description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
             ),
-            InputParam.template("image_latents", note="generated in vae encoder step and updated in input step."),
+            InputParam.template("image_latents"),
         ]

     @torch.no_grad()
@@ -197,13 +197,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam.template("attention_kwargs"),
-            InputParam(
-                name="latents",
-                required=True,
-                type_hint=torch.Tensor,
-                description="The latents to use for the denoising process. Can be generated in prepare_latents step."
-            ),
-            InputParam.template("num_inference_steps"),
             InputParam.template("denoiser_input_fields"),
             InputParam(
                 "img_shapes",
@@ -293,13 +286,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
     def inputs(self) -> List[InputParam]:
         return [
             InputParam.template("attention_kwargs"),
-            InputParam(
-                name="latents",
-                required=True,
-                type_hint=torch.Tensor,
-                description="The latents to use for the denoising process. Can be generated in prepare_latents step."
-            ),
-            InputParam.template("num_inference_steps"),
             InputParam.template("denoiser_input_fields"),
             InputParam(
                 "img_shapes",
@@ -427,19 +413,19 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
                 type_hint=torch.Tensor,
                 description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
             ),
-            InputParam.template("image_latents", note="Can be generated from vae encoder step and updated in input step."),
+            InputParam.template("image_latents"),
             InputParam(
                 "initial_noise",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
             ),
-            InputParam(
-                "timesteps",
-                required=True,
-                type_hint=torch.Tensor,
-                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
-            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam.template("latents"),
         ]

     @torch.no_grad()
@@ -521,6 +507,38 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):

 # auto_docstring
 class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
+    """
+    Denoise step that iteratively denoises the latents.
+    Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method.
+    At each iteration, it runs the blocks defined in `sub_blocks` sequentially:
+        - `QwenImageLoopBeforeDenoiser`
+        - `QwenImageLoopDenoiser`
+        - `QwenImageLoopAfterDenoiser`
+    This block supports text2image and image2image tasks for QwenImage.
+
+    Components:
+        guider (`ClassifierFreeGuidance`)
+        transformer (`QwenImageTransformer2DModel`)
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+        num_inference_steps (`int`):
+            The number of denoising steps.
+        latents (`Tensor`):
+            The initial latents to use for the denoising process. Can be generated in prepare_latent step.
+        attention_kwargs (`Dict`, *optional*):
+            Additional kwargs for attention processors.
+        **denoiser_input_fields (`None`, *optional*):
+            Conditional model inputs for the denoiser, e.g. prompt_embeds, negative_prompt_embeds, etc.
+        img_shapes (`List`):
+            The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
+
+    Outputs:
+        latents (`Tensor`):
+            Denoised latents.
+    """
     model_name = "qwenimage"

     block_classes = [
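The loop-wrapper pattern this docstring describes can be pictured with a toy stand-in (not the diffusers API): the wrapper owns the timestep loop, and each sub-block gets a turn per iteration.

class ToyLoopWrapper:
    """Toy sketch: runs its sub-blocks in order once per timestep."""

    def __init__(self, sub_blocks):
        self.sub_blocks = sub_blocks  # e.g. [before_denoiser, denoiser, after_denoiser]

    def __call__(self, state):
        for t in state["timesteps"]:
            state["t"] = t
            for block in self.sub_blocks:
                state = block(state)
        return state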
+ """ model_name = "qwenimage" block_classes = [ @@ -546,6 +564,45 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): # Qwen Image (inpainting) # auto_docstring class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) + transformer (`QwenImageTransformer2DModel`) + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, @@ -572,6 +629,46 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): # Qwen Image (text2image, image2image) with controlnet # auto_docstring class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method + At each iteration, it runs blocks defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2img/img2img tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) + controlnet (`QwenImageControlNetModel`) + transformer (`QwenImageTransformer2DModel`) + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`List`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. 
@@ -572,6 +629,46 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):

 # Qwen Image (text2image, image2image) with controlnet
 # auto_docstring
 class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+    """
+    Denoise step that iteratively denoises the latents.
+    Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method.
+    At each iteration, it runs the blocks defined in `sub_blocks` sequentially:
+        - `QwenImageLoopBeforeDenoiser`
+        - `QwenImageLoopBeforeDenoiserControlNet`
+        - `QwenImageLoopDenoiser`
+        - `QwenImageLoopAfterDenoiser`
+    This block supports text2img/img2img tasks with controlnet for QwenImage.
+
+    Components:
+        guider (`ClassifierFreeGuidance`)
+        controlnet (`QwenImageControlNetModel`)
+        transformer (`QwenImageTransformer2DModel`)
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+        num_inference_steps (`int`):
+            The number of denoising steps.
+        latents (`Tensor`):
+            The initial latents to use for the denoising process. Can be generated in prepare_latent step.
+        control_image_latents (`Tensor`):
+            The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
+        controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+            Scale for ControlNet conditioning (updated in prepare_controlnet_inputs step).
+        controlnet_keep (`List`):
+            The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
+        attention_kwargs (`Dict`, *optional*):
+            Additional kwargs for attention processors.
+        **denoiser_input_fields (`None`, *optional*):
+            Conditional model inputs for the denoiser, e.g. prompt_embeds, negative_prompt_embeds, etc.
+        img_shapes (`List`):
+            The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
+
+    Outputs:
+        latents (`Tensor`):
+            Denoised latents.
+    """
     model_name = "qwenimage"
     block_classes = [
         QwenImageLoopBeforeDenoiser,
@@ -598,6 +695,53 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):

 # Qwen Image (inpainting) with controlnet
 # auto_docstring
 class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+    """
+    Denoise step that iteratively denoises the latents.
+    Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method.
+    At each iteration, it runs the blocks defined in `sub_blocks` sequentially:
+        - `QwenImageLoopBeforeDenoiser`
+        - `QwenImageLoopBeforeDenoiserControlNet`
+        - `QwenImageLoopDenoiser`
+        - `QwenImageLoopAfterDenoiser`
+        - `QwenImageLoopAfterDenoiserInpaint`
+    This block supports inpainting tasks with controlnet for QwenImage.
+
+    Components:
+        guider (`ClassifierFreeGuidance`)
+        controlnet (`QwenImageControlNetModel`)
+        transformer (`QwenImageTransformer2DModel`)
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+        num_inference_steps (`int`):
+            The number of denoising steps.
+        latents (`Tensor`):
+            The initial latents to use for the denoising process. Can be generated in prepare_latent step.
+        control_image_latents (`Tensor`):
+            The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
+        controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+            Scale for ControlNet conditioning (updated in prepare_controlnet_inputs step).
+        controlnet_keep (`List`):
+            The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
+        attention_kwargs (`Dict`, *optional*):
+            Additional kwargs for attention processors.
+        **denoiser_input_fields (`None`, *optional*):
+            Conditional model inputs for the denoiser, e.g. prompt_embeds, negative_prompt_embeds, etc.
+        img_shapes (`List`):
+            The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
+        mask (`Tensor`):
+            The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step.
+        initial_noise (`Tensor`):
+            The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.
+
+    Outputs:
+        latents (`Tensor`):
+            Denoised latents.
+    """
     model_name = "qwenimage"
     block_classes = [
         QwenImageLoopBeforeDenoiser,
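All of these denoise steps route their predictions through the guider component (`ClassifierFreeGuidance`). The classic combine rule it generalizes is sketched below (standard CFG; illustrative, not the guider's source):

import torch

def cfg_combine(pred_uncond: torch.Tensor, pred_text: torch.Tensor, guidance_scale: float = 4.0) -> torch.Tensor:
    # Push the prediction away from the unconditional branch, toward the text-conditioned one.
    return pred_uncond + guidance_scale * (pred_text - pred_uncond)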
@@ -632,6 +776,40 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):

 # Qwen Image Edit (image2image)
 # auto_docstring
 class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
+    """
+    Denoise step that iteratively denoises the latents.
+    Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method.
+    At each iteration, it runs the blocks defined in `sub_blocks` sequentially:
+        - `QwenImageEditLoopBeforeDenoiser`
+        - `QwenImageEditLoopDenoiser`
+        - `QwenImageLoopAfterDenoiser`
+    This block supports QwenImage Edit.
+
+    Components:
+        guider (`ClassifierFreeGuidance`)
+        transformer (`QwenImageTransformer2DModel`)
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+        num_inference_steps (`int`):
+            The number of denoising steps.
+        latents (`Tensor`):
+            The initial latents to use for the denoising process. Can be generated in prepare_latent step.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step.
+        attention_kwargs (`Dict`, *optional*):
+            Additional kwargs for attention processors.
+        **denoiser_input_fields (`None`, *optional*):
+            Conditional model inputs for the denoiser, e.g. prompt_embeds, negative_prompt_embeds, etc.
+        img_shapes (`List`):
+            The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
+
+    Outputs:
+        latents (`Tensor`):
+            Denoised latents.
+    """
     model_name = "qwenimage-edit"
     block_classes = [
         QwenImageEditLoopBeforeDenoiser,
@@ -656,6 +834,45 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):

 # Qwen Image Edit (inpainting)
 # auto_docstring
 class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+    """
+    Denoise step that iteratively denoises the latents.
+    Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method.
+    At each iteration, it runs the blocks defined in `sub_blocks` sequentially:
+        - `QwenImageEditLoopBeforeDenoiser`
+        - `QwenImageEditLoopDenoiser`
+        - `QwenImageLoopAfterDenoiser`
+        - `QwenImageLoopAfterDenoiserInpaint`
+    This block supports inpainting tasks for QwenImage Edit.
+
+    Components:
+        guider (`ClassifierFreeGuidance`)
+        transformer (`QwenImageTransformer2DModel`)
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+        num_inference_steps (`int`):
+            The number of denoising steps.
+        latents (`Tensor`):
+            The initial latents to use for the denoising process. Can be generated in prepare_latent step.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step.
+        attention_kwargs (`Dict`, *optional*):
+            Additional kwargs for attention processors.
+        **denoiser_input_fields (`None`, *optional*):
+            Conditional model inputs for the denoiser, e.g. prompt_embeds, negative_prompt_embeds, etc.
+        img_shapes (`List`):
+            The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
+        mask (`Tensor`):
+            The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
+        initial_noise (`Tensor`):
+            The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.
+
+    Outputs:
+        latents (`Tensor`):
+            Denoised latents.
+    """
     model_name = "qwenimage-edit"
     block_classes = [
         QwenImageEditLoopBeforeDenoiser,
@@ -682,6 +899,40 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):

 # Qwen Image Layered (image2image)
 # auto_docstring
 class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
+    """
+    Denoise step that iteratively denoises the latents.
+    Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method.
+    At each iteration, it runs the blocks defined in `sub_blocks` sequentially:
+        - `QwenImageEditLoopBeforeDenoiser`
+        - `QwenImageEditLoopDenoiser`
+        - `QwenImageLoopAfterDenoiser`
+    This block supports QwenImage Layered.
+
+    Components:
+        guider (`ClassifierFreeGuidance`)
+        transformer (`QwenImageTransformer2DModel`)
+        scheduler (`FlowMatchEulerDiscreteScheduler`)
+
+    Inputs:
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+        num_inference_steps (`int`):
+            The number of denoising steps.
+        latents (`Tensor`):
+            The initial latents to use for the denoising process. Can be generated in prepare_latent step.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step.
+        attention_kwargs (`Dict`, *optional*):
+            Additional kwargs for attention processors.
+        **denoiser_input_fields (`None`, *optional*):
+            Conditional model inputs for the denoiser, e.g. prompt_embeds, negative_prompt_embeds, etc.
+        img_shapes (`List`):
+            The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
+
+    Outputs:
+        latents (`Tensor`):
+            Denoised latents.
+    """
     model_name = "qwenimage-layered"
     block_classes = [
         QwenImageEditLoopBeforeDenoiser,
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index 9a83f0d717..083ee507cc 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -276,7 +276,23 @@ def encode_vae_image(
 # # In most of our other pipelines, resizing is done as part of the image preprocessing step.
 # ====================

+
+# auto_docstring
 class QwenImageEditResizeStep(ModularPipelineBlocks):
+    """
+    Image resize step that resizes the image to the target area while maintaining the aspect ratio.
+
+    Components:
+        image_resize_processor (`VaeImageProcessor`)
+
+    Inputs:
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
+
+    Outputs:
+        resized_image (`List`):
+            The resized images
+    """

     model_name = "qwenimage-edit"
@@ -334,7 +350,24 @@ class QwenImageEditResizeStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageLayeredResizeStep(ModularPipelineBlocks):
+    """
+    Image resize step that resizes the image to a target area (defined by the resolution parameter from the user)
+    while maintaining the aspect ratio.
+
+    Components:
+        image_resize_processor (`VaeImageProcessor`)
+
+    Inputs:
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
+        resolution (`int`, *optional*, defaults to 640):
+            The target area to resize the image to; can be 1024 or 640
+
+    Outputs:
+        resized_image (`List`):
+            The resized images
+    """
     model_name = "qwenimage-layered"

     @property
@@ -405,7 +438,26 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditPlusResizeStep(ModularPipelineBlocks):
+    """
+    Resize images for the QwenImage Edit Plus pipeline.
+    Produces two outputs: resized_image (1024x1024 target area) for VAE encoding, and resized_cond_image (384x384
+    target area) for VL text encoding. Each image is resized independently based on its own aspect ratio.
+
+    Components:
+        image_resize_processor (`VaeImageProcessor`)
+
+    Inputs:
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
+
+    Outputs:
+        resized_image (`List`):
+            Images resized to 1024x1024 target area for VAE encoding
+        resized_cond_image (`List`):
+            Images resized to 384x384 target area for VL text encoding
+    """

     model_name = "qwenimage-edit-plus"
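"Resize to a target area while maintaining the aspect ratio", used by all three resize steps above, boils down to scaling both sides by sqrt(target_area / current_area). A sketch, with the rounding multiple as an assumption (the real steps round to model-specific multiples):

import math

def fit_to_area(width: int, height: int, target_area: int = 1024 * 1024, multiple: int = 32):
    # Scale factor that maps the current area onto the target area.
    scale = math.sqrt(target_area / (width * height))
    return (
        max(multiple, round(width * scale / multiple) * multiple),
        max(multiple, round(height * scale / multiple) * multiple),
    )

print(fit_to_area(1920, 1080))  # roughly preserves 16:9 at about one megapixel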
@@ -488,7 +540,30 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks):
 # ====================
 # 2. GET IMAGE PROMPT
 # ====================

+
+# auto_docstring
 class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks):
+    """
+    Auto-caption step that generates a text prompt from the input image if none is provided. Uses the VL model
+    (text_encoder) to generate a description of the image. If a prompt is already provided, this step passes through
+    unchanged.
+
+    Components:
+        text_encoder (`Qwen2_5_VLForConditionalGeneration`)
+        processor (`Qwen2VLProcessor`)
+
+    Inputs:
+        prompt (`str`, *optional*):
+            The prompt or prompts to guide image generation.
+        resized_image (`Image`):
+            The image to generate the caption from; should be resized using the resize step
+        use_en_prompt (`bool`, *optional*, defaults to False):
+            Whether to use English prompt template
+
+    Outputs:
+        prompt (`str`):
+            The prompt or prompts to guide image generation. If not provided, updated using image caption
+    """

     model_name = "qwenimage-layered"
@@ -530,6 +605,16 @@ class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks):
             ),
         ]

+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="prompt",
+                type_hint=str,
+                description="The prompt or prompts to guide image generation. If not provided, updated using image caption",
+            ),
+        ]
+
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
@@ -567,7 +652,35 @@ class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks):
 # ====================
 # 3. TEXT ENCODER
 # ====================

+
+# auto_docstring
 class QwenImageTextEncoderStep(ModularPipelineBlocks):
+    """
+    Text Encoder step that generates text embeddings to guide the image generation.
+
+    Components:
+        text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use
+        tokenizer (`Qwen2Tokenizer`): The tokenizer to use
+        guider (`ClassifierFreeGuidance`)
+
+    Inputs:
+        prompt (`str`):
+            The prompt or prompts to guide image generation.
+        negative_prompt (`str`, *optional*):
+            The prompt or prompts not to guide the image generation.
+        max_sequence_length (`int`, *optional*, defaults to 1024):
+            Maximum sequence length for prompt encoding.
+
+    Outputs:
+        prompt_embeds (`Tensor`):
+            The prompt embeddings.
+        prompt_embeds_mask (`Tensor`):
+            The encoder attention mask.
+        negative_prompt_embeds (`Tensor`):
+            The negative prompt embeddings.
+        negative_prompt_embeds_mask (`Tensor`):
+            The negative prompt embeddings mask.
+    """
     model_name = "qwenimage"

     def __init__(self):
@@ -670,7 +783,34 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
+    """
+    Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image
+    generation.
+
+    Components:
+        text_encoder (`Qwen2_5_VLForConditionalGeneration`)
+        processor (`Qwen2VLProcessor`)
+        guider (`ClassifierFreeGuidance`)
+
+    Inputs:
+        prompt (`str`):
+            The prompt or prompts to guide image generation.
+        negative_prompt (`str`, *optional*):
+            The prompt or prompts not to guide the image generation.
+        resized_image (`Image`):
+            The image prompt to encode; should be resized using the resize step
+
+    Outputs:
+        prompt_embeds (`Tensor`):
+            The prompt embeddings.
+        prompt_embeds_mask (`Tensor`):
+            The encoder attention mask.
+        negative_prompt_embeds (`Tensor`):
+            The negative prompt embeddings.
+        negative_prompt_embeds_mask (`Tensor`):
+            The negative prompt embeddings mask.
+    """
     model_name = "qwenimage"

     def __init__(self):
@@ -766,7 +906,34 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
+    """
+    Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together to generate text
+    embeddings for guiding image generation.
+
+    Components:
+        text_encoder (`Qwen2_5_VLForConditionalGeneration`)
+        processor (`Qwen2VLProcessor`)
+        guider (`ClassifierFreeGuidance`)
+
+    Inputs:
+        prompt (`str`):
+            The prompt or prompts to guide image generation.
+        negative_prompt (`str`, *optional*):
+            The prompt or prompts not to guide the image generation.
+        resized_cond_image (`Tensor`):
+            The image(s) to encode; can be a single image or list of images, and should be resized to 384x384 using
+            the resize step
+
+    Outputs:
+        prompt_embeds (`Tensor`):
+            The prompt embeddings.
+        prompt_embeds_mask (`Tensor`):
+            The encoder attention mask.
+        negative_prompt_embeds (`Tensor`):
+            The negative prompt embeddings.
+        negative_prompt_embeds_mask (`Tensor`):
+            The negative prompt embeddings mask.
+    """

     model_name = "qwenimage-edit-plus"
@@ -874,7 +1041,35 @@ class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks):
 # ====================
 # 4. IMAGE PREPROCESS
 # ====================

+
+# auto_docstring
 class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
+    """
+    Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be
+    resized to the given height and width.
+
+    Components:
+        image_mask_processor (`InpaintProcessor`)
+
+    Inputs:
+        mask_image (`Image`):
+            Mask image for inpainting.
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+        padding_mask_crop (`int`, *optional*):
+            Padding for mask cropping in inpainting.
+
+    Outputs:
+        processed_image (`Tensor`):
+            The processed image
+        processed_mask_image (`Tensor`):
+            The processed mask image
+        mask_overlay_kwargs (`Dict`):
+            The kwargs for the postprocess step to apply the mask overlay
+    """
     model_name = "qwenimage"

     @property
@@ -954,7 +1149,30 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks):
+    """
+    Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be
+    resized first.
+
+    Components:
+        image_mask_processor (`InpaintProcessor`)
+
+    Inputs:
+        mask_image (`Image`):
+            Mask image for inpainting.
+        resized_image (`Image`):
+            The resized image; should be generated using a resize step
+        padding_mask_crop (`int`, *optional*):
+            Padding for mask cropping in inpainting.
+
+    Outputs:
+        processed_image (`Tensor`):
+            The processed image
+        processed_mask_image (`Tensor`):
+            The processed mask image
+        mask_overlay_kwargs (`Dict`):
+            The kwargs for the postprocess step to apply the mask overlay
+    """
     model_name = "qwenimage-edit"

     @property
@@ -1025,7 +1243,26 @@ class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
+    """
+    Image Preprocess step. Will resize the image to the given height and width.
+
+    Components:
+        image_processor (`VaeImageProcessor`)
+
+    Inputs:
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+
+    Outputs:
+        processed_image (`Tensor`):
+            The processed image
+    """
     model_name = "qwenimage"

     @property
@@ -1087,7 +1324,22 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks):
+    """
+    Image Preprocess step. Images need to be resized first.
+
+    Components:
+        image_processor (`VaeImageProcessor`)
+
+    Inputs:
+        resized_image (`List`):
+            The resized image; should be generated using a resize step
+
+    Outputs:
+        processed_image (`Tensor`):
+            The processed image
+    """
     model_name = "qwenimage-edit"

     @property
@@ -1140,7 +1392,22 @@ class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks):
+    """
+    Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of
+    processed images.
+
+    Components:
+        image_processor (`VaeImageProcessor`)
+
+    Inputs:
+        resized_image (`List`):
+            The resized image; should be generated using a resize step
+
+    Outputs:
+        processed_image (`Tensor`):
+            The processed image
+    """
     model_name = "qwenimage-edit-plus"

     @property
@@ -1204,8 +1471,26 @@ class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks):
 # ====================
 # 5. VAE ENCODER
 # ====================

+
+# auto_docstring
 class QwenImageVaeEncoderStep(ModularPipelineBlocks):
-    """VAE encoder that handles both single images and lists of images with varied resolutions."""
+    """
+    VAE Encoder step that converts `processed_image` into the latent representation `image_latents`. Handles both
+    single images and lists of images with varied resolutions.
+
+    Components:
+        vae (`AutoencoderKLQwenImage`)
+
+    Inputs:
+        processed_image (`Tensor`):
+            The image tensor to encode
+        generator (`Generator`, *optional*):
+            Torch generator for deterministic generation.
+
+    Outputs:
+        image_latents (`Tensor`):
+            The latent representation of the input image.
+    """

     model_name = "qwenimage"
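The encode path that `processed_image` takes is the usual AutoencoderKL-style call. A hedged, self-contained sketch (tiny assumed config; the pipeline uses `AutoencoderKLQwenImage`, which is video-shaped and also applies latents_mean/std normalization, so treat this as the general idea rather than the exact code):

import torch
from diffusers import AutoencoderKL

# Tiny illustrative VAE -- the config here is assumed, not QwenImage's.
vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
    block_out_channels=(32,),
    layers_per_block=1,
    latent_channels=4,
)
image = torch.rand(1, 3, 64, 64) * 2 - 1  # preprocessed image in [-1, 1]
generator = torch.Generator().manual_seed(0)
# Sample a latent from the VAE posterior, deterministically via the generator.
latents = vae.encode(image).latent_dist.sample(generator=generator)
print(latents.shape)  # (1, 4, H', W') latent sample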
@@ -1297,7 +1582,30 @@ class QwenImageVaeEncoderStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
+    """
+    VAE Encoder step that converts `control_image` into the latent representation `control_image_latents`.
+
+    Components:
+        vae (`AutoencoderKLQwenImage`)
+        controlnet (`QwenImageControlNetModel`)
+        control_image_processor (`VaeImageProcessor`)
+
+    Inputs:
+        control_image (`Image`):
+            Control image for ControlNet conditioning.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+        generator (`Generator`, *optional*):
+            Torch generator for deterministic generation.
+
+    Outputs:
+        control_image_latents (`Tensor`):
+            The latents representing the control image
+    """
     model_name = "qwenimage"

     @property
@@ -1411,7 +1719,20 @@ class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
 # ====================
 # 6. PERMUTE LATENTS
 # ====================

+
+# auto_docstring
 class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks):
+    """
+    Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing.
+
+    Inputs:
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step.
+
+    Outputs:
+        image_latents (`Tensor`):
+            The latent representation of the input image. (permuted from (B, C, 1, H, W) to (B, 1, C, H, W))
+    """
     model_name = "qwenimage-layered"

     @property
diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py
index b237031b91..0e03242e5e 100644
--- a/src/diffusers/modular_pipelines/qwenimage/inputs.py
+++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -109,7 +109,42 @@ def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: in

     return height, width

+# auto_docstring
 class QwenImageTextInputsStep(ModularPipelineBlocks):
+    """
+    Text input processing step that standardizes text embeddings for the pipeline.
+
+    This step:
+        1. Determines `batch_size` and `dtype` based on `prompt_embeds`
+        2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)
+
+    This block should be placed after all encoder steps to process the text embeddings before they are used in
+    subsequent pipeline steps.
+
+    Inputs:
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        prompt_embeds (`Tensor`):
+            Text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        prompt_embeds_mask (`Tensor`):
+            Mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds (`Tensor`, *optional*):
+            Negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            Mask for the negative text embeddings. Can be generated from text_encoder step.
+
+    Outputs:
+        batch_size (`int`):
+            The batch size of the prompt embeddings
+        dtype (`dtype`):
+            The data type of the prompt embeddings
+        prompt_embeds (`Tensor`):
+            The prompt embeddings. (batch-expanded)
+        prompt_embeds_mask (`Tensor`):
+            The encoder attention mask. (batch-expanded)
+        negative_prompt_embeds (`Tensor`):
+            The negative prompt embeddings. (batch-expanded)
+        negative_prompt_embeds_mask (`Tensor`):
+            The negative prompt embeddings mask. (batch-expanded)
+    """
     model_name = "qwenimage"

     @property
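"Batch-expanded" here means each prompt's embeddings are repeated num_images_per_prompt times so every model input ends up with batch_size * num_images_per_prompt rows. A minimal sketch:

import torch

prompt_embeds = torch.randn(2, 10, 64)  # (batch_size, seq_len, dim)
num_images_per_prompt = 3
# Repeat each prompt's rows back to back: [p0, p0, p0, p1, p1, p1]
expanded = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
print(expanded.shape)  # torch.Size([6, 10, 64])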
@@ -217,8 +252,47 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
-    """Input step for QwenImage: update height/width, expand batch, patchify."""
+    """
+    Input processing step that:
+        1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size
+        2. For additional batch inputs: Expands batch dimensions to match final batch size
+
+    Configured inputs:
+        - Image latent inputs: ['image_latents']
+
+    This block should be placed after the encoder steps and the text input step.
+
+    Components:
+        pachifier (`QwenImagePachifier`)
+
+    Inputs:
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step.
+
+    Outputs:
+        image_height (`int`):
+            The image height calculated from the image latents dimension
+        image_width (`int`):
+            The image width calculated from the image latents dimension
+        height (`int`):
+            If not provided, updated to image height
+        width (`int`):
+            If not provided, updated to image width
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
+            batch-expanded)
+    """

     model_name = "qwenimage"
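"Patchify" here turns spatial latents into token sequences; the (B, seq, C*4) shape mentioned in the layered steps corresponds to 2x2 spatial patches. An illustrative rearrangement (not the QwenImagePachifier source):

import torch

latents = torch.randn(1, 16, 8, 8)  # (B, C, H, W)
B, C, H, W = latents.shape
# Group 2x2 neighborhoods into tokens: (B, C, H, W) -> (B, H/2 * W/2, C*4)
patched = (
    latents.reshape(B, C, H // 2, 2, W // 2, 2)
    .permute(0, 2, 4, 1, 3, 5)
    .reshape(B, (H // 2) * (W // 2), C * 4)
)
print(patched.shape)  # torch.Size([1, 16, 64])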
@@ -385,8 +459,48 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
-    """Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
+    """
+    Input processing step for Edit Plus that:
+        1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch
+        2. For additional batch inputs: Expands batch dimensions to match final batch size
+    Height/width defaults to the last image in the list.
+
+    Configured inputs:
+        - Image latent inputs: ['image_latents']
+
+    This block should be placed after the encoder steps and the text input step.
+
+    Components:
+        pachifier (`QwenImagePachifier`)
+
+    Inputs:
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step.
+
+    Outputs:
+        image_height (`List`):
+            The image heights calculated from the image latents dimension
+        image_width (`List`):
+            The image widths calculated from the image latents dimension
+        height (`int`):
+            If not provided, updated to image height
+        width (`int`):
+            If not provided, updated to image width
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified,
+            concatenated, and batch-expanded)
+    """

     model_name = "qwenimage-edit-plus"
@@ -571,8 +685,44 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):

 # same as QwenImageAdditionalInputsStep, but with layered pachifier.

+
+# auto_docstring
 class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
-    """Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier."""
+    """
+    Input processing step for Layered that:
+        1. For image latent inputs: Updates height/width if None, patchifies with the layered pachifier, and expands
+           batch size
+        2. For additional batch inputs: Expands batch dimensions to match final batch size
+
+    Configured inputs:
+        - Image latent inputs: ['image_latents']
+
+    This block should be placed after the encoder steps and the text input step.
+
+    Components:
+        pachifier (`QwenImageLayeredPachifier`)
+
+    Inputs:
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step.
+
+    Outputs:
+        image_height (`int`):
+            The image height calculated from the image latents dimension
+        image_width (`int`):
+            The image width calculated from the image latents dimension
+        height (`int`):
+            If not provided, updated to image height
+        width (`int`):
+            If not provided, updated to image width
+        image_latents (`Tensor`):
+            Image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified
+            with the layered pachifier and batch-expanded)
+    """

     model_name = "qwenimage-layered"
@@ -738,7 +888,32 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):

         return components, state

+# auto_docstring
 class QwenImageControlNetInputsStep(ModularPipelineBlocks):
+    """
+    Prepare the `control_image_latents` input for controlnet. Insert after all the other input steps.
+
+    Inputs:
+        control_image_latents (`Tensor`):
+            The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
+            step.
+        batch_size (`int`, *optional*, defaults to 1):
+            Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
+            be generated in input step.
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+
+    Outputs:
+        control_image_latents (`Tensor`):
+            The control image latents (patchified and batch-expanded).
+ height (`int`): + if not provided, updated to control image height + width (`int`): + if not provided, updated to control image width + """ model_name = "qwenimage" @property diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py index 46f0b6f6ff..b50e41bb50 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -65,26 +65,10 @@ class QwenImageAutoTextEncoderStep(AutoPipelineBlocks): Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block. Components: - text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use - tokenizer (`Qwen2Tokenizer`): The tokenizer to use - guider (`ClassifierFreeGuidance`) - Configs: - - prompt_template_encode (default: <|im_start|>system - Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|> - <|im_start|>user - {}<|im_end|> - <|im_start|>assistant - ) - - prompt_template_encode_start_idx (default: 34) - - tokenizer_max_length (default: 1024) - Inputs: prompt (`str`, *optional*): The prompt or prompts to guide image generation. @@ -95,13 +79,13 @@ class QwenImageAutoTextEncoderStep(AutoPipelineBlocks): Outputs: prompt_embeds (`Tensor`): - The prompt embeddings + The prompt embeddings. prompt_embeds_mask (`Tensor`): - The encoder attention mask + The encoder attention mask. negative_prompt_embeds (`Tensor`): - The negative prompt embeddings + The negative prompt embeddings. negative_prompt_embeds_mask (`Tensor`): - The negative prompt embeddings mask + The negative prompt embeddings mask. """ model_name = "qwenimage" @@ -130,16 +114,14 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): - Creates `image_latents`. Components: - image_mask_processor (`InpaintProcessor`) - vae (`AutoencoderKLQwenImage`) Inputs: mask_image (`Image`): Mask image for inpainting. - image (`Image`): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): @@ -150,14 +132,14 @@ Torch generator for deterministic generation. Outputs: - processed_image (`None`): - TODO: Add description. - processed_mask_image (`None`): - TODO: Add description. + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image mask_overlay_kwargs (`Dict`): The kwargs for the postprocess step to apply the mask overlay image_latents (`Tensor`): - The latents representing the reference image(s). Single tensor or list depending on input. + The latent representation of the input image. """ model_name = "qwenimage" @@ -180,14 +162,12 @@ class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): Vae encoder step that preprocess and encode the image inputs into their latent representations. Components: - image_processor (`VaeImageProcessor`) - vae (`AutoencoderKLQwenImage`) Inputs: - image (`Image`): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. height (`int`, *optional*): The height in pixels of the generated image.
width (`int`, *optional*): @@ -196,10 +176,10 @@ class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): Torch generator for deterministic generation. Outputs: - processed_image (`None`): - TODO: Add description. + processed_image (`Tensor`): + The processed image image_latents (`Tensor`): - The latents representing the reference image(s). Single tensor or list depending on input. + The latent representation of the input image. """ model_name = "qwenimage" @@ -238,11 +218,8 @@ class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): - if `control_image` is not provided, step will be skipped. Components: - vae (`AutoencoderKLQwenImage`) - controlnet (`QwenImageControlNetModel`) - control_image_processor (`VaeImageProcessor`) Inputs: @@ -286,36 +263,50 @@ class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): Input step that prepares the inputs for the img2img denoising step. It: Components: - pachifier (`QwenImagePachifier`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. Outputs: batch_size (`int`): - Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + The batch size of the prompt embeddings dtype (`dtype`): - Data type of model tensor inputs (determined by `prompt_embeds`) + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) image_height (`int`): The image height calculated from the image latents dimension image_width (`int`): The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) """ model_name = "qwenimage" @@ -335,38 +326,54 @@ class QwenImageInpaintInputStep(SequentialPipelineBlocks): Input step that prepares the inputs for the inpainting denoising step. 
It: Components: - pachifier (`QwenImagePachifier`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. - processed_mask_image (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image Outputs: batch_size (`int`): - Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + The batch size of the prompt embeddings dtype (`dtype`): - Data type of model tensor inputs (determined by `prompt_embeds`) + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) image_height (`int`): The image height calculated from the image latents dimension image_width (`int`): The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) """ model_name = "qwenimage" @@ -394,30 +401,31 @@ class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): - Create the pachified latents `mask` based on the processed mask image. Components: - scheduler (`FlowMatchEulerDiscreteScheduler`) - pachifier (`QwenImagePachifier`) Inputs: latents (`Tensor`): The initial random noised, can be generated in prepare latent step. image_latents (`Tensor`): - The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step. + image latents used to guide the image generation. Can be generated from the vae encoder and updated in the + input step. timesteps (`Tensor`): The timesteps to use for the denoising process. Can be generated in set_timesteps step. processed_mask_image (`Tensor`): The processed mask to use for the inpainting process. - height (`None`): - TODO: Add description. - width (`None`): - TODO: Add description.
- dtype (`None`): - TODO: Add description. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. Outputs: initial_noise (`Tensor`): The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. mask (`Tensor`): The mask to use for the inpainting process. """ @@ -445,26 +453,22 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.). Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. latents (`Tensor`, *optional*): Pre-generated noisy latents for image generation. height (`int`, *optional*): @@ -479,7 +483,7 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): Custom sigmas for the denoising process. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. Outputs: @@ -523,34 +527,30 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task. Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. - processed_mask_image (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image latents (`Tensor`, *optional*): Pre-generated noisy latents for image generation. generator (`Generator`, *optional*): @@ -563,7 +563,7 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): Strength for img2img/inpainting. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. Outputs: @@ -609,32 +609,28 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task. Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. latents (`Tensor`, *optional*): Pre-generated noisy latents for image generation. generator (`Generator`, *optional*): @@ -647,7 +643,7 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): Strength for img2img/inpainting. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. Outputs: @@ -693,30 +689,25 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.). 
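The "core denoise" steps documented in this file all wrap the same inner loop: iterate over the scheduler's timesteps, predict the flow with the transformer, and take a scheduler step. A minimal sketch of that loop follows; the keyword names on the transformer call are assumptions for illustration, and the real blocks also thread guidance, RoPE image shapes, and (for the ControlNet variants) control residuals through the call.

```python
# Minimal sketch of the denoise loop these core blocks wrap; call signatures
# below are assumptions, not the exact diffusers implementation.
import torch

def denoise_loop(transformer, scheduler, latents, timesteps, prompt_embeds, prompt_embeds_mask):
    for t in timesteps:
        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        noise_pred = transformer(
            hidden_states=latents,                  # patchified latent tokens
            timestep=timestep / 1000,               # flow-match time in [0, 1]
            encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_mask=prompt_embeds_mask,
            return_dict=False,
        )[0]
        # Euler step of the flow-match scheduler toward the clean sample
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latents
```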
Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - controlnet (`QwenImageControlNetModel`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. - control_image_latents (`None`): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): @@ -735,12 +726,9 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): When to stop applying ControlNet. controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): Scale for ControlNet conditioning. - **denoiser_input_fields (`None`, *optional*): - All conditional model inputs for the denoiser. It should contain prompt_embeds/negative_prompt_embeds, - txt_seq_lens/negative_txt_seq_lens. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. Outputs: @@ -788,38 +776,33 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task. Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - controlnet (`QwenImageControlNetModel`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. 
width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. - processed_mask_image (`None`, *optional*): - TODO: Add description. - control_image_latents (`None`): - TODO: Add description. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step. latents (`Tensor`, *optional*): Pre-generated noisy latents for image generation. generator (`Generator`, *optional*): @@ -836,12 +819,9 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): When to stop applying ControlNet. controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): Scale for ControlNet conditioning. - **denoiser_input_fields (`None`, *optional*): - All conditional model inputs for the denoiser. It should contain prompt_embeds/negative_prompt_embeds, - txt_seq_lens/negative_txt_seq_lens. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. Outputs: @@ -891,36 +871,31 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task. Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - controlnet (`QwenImageControlNetModel`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. - control_image_latents (`None`): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step. latents (`Tensor`, *optional*): Pre-generated noisy latents for image generation. generator (`Generator`, *optional*): @@ -937,12 +912,9 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): When to stop applying ControlNet. 
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): Scale for ControlNet conditioning. - **denoiser_input_fields (`None`, *optional*): - All conditional model inputs for the denoiser. It should contain prompt_embeds/negative_prompt_embeds, - txt_seq_lens/negative_txt_seq_lens. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. Outputs: @@ -1058,20 +1030,18 @@ class QwenImageDecodeStep(SequentialPipelineBlocks): Decode step that decodes the latents to images and postprocesses the generated image. Components: - vae (`AutoencoderKLQwenImage`) - image_processor (`VaeImageProcessor`) Inputs: latents (`Tensor`): - The latents to decode, can be generated in the denoise step + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step. output_type (`str`, *optional*, defaults to pil): Output format: 'pil', 'np', 'pt'. Outputs: images (`List`): - Generated images. + Generated images. (tensor output of the vae decoder.) """ model_name = "qwenimage" @@ -1090,22 +1060,20 @@ class QwenImageInpaintDecodeStep(SequentialPipelineBlocks): Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the mask overlay to the original image. Components: - vae (`AutoencoderKLQwenImage`) - image_mask_processor (`InpaintProcessor`) Inputs: latents (`Tensor`): - The latents to decode, can be generated in the denoise step + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step. output_type (`str`, *optional*, defaults to pil): Output format: 'pil', 'np', 'pt'. - mask_overlay_kwargs (`None`, *optional*): - TODO: Add description. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. Generated in InpaintProcessImagesInputStep. Outputs: images (`List`): - Generated images. + Generated images. (tensor output of the vae decoder.) """ model_name = "qwenimage" @@ -1157,42 +1125,18 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks): - for text-to-image generation, all you need to provide is `prompt` Components: - text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use - tokenizer (`Qwen2Tokenizer`): The tokenizer to use - guider (`ClassifierFreeGuidance`) - image_mask_processor (`InpaintProcessor`) - vae (`AutoencoderKLQwenImage`) - image_processor (`VaeImageProcessor`) - controlnet (`QwenImageControlNetModel`) - control_image_processor (`VaeImageProcessor`) - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - transformer (`QwenImageTransformer2DModel`) - Configs: - - prompt_template_encode (default: <|im_start|>system - Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|> - <|im_start|>user - {}<|im_end|> - <|im_start|>assistant - ) - - prompt_template_encode_start_idx (default: 34) - - tokenizer_max_length (default: 1024) - Inputs: prompt (`str`, *optional*): The prompt or prompts to guide image generation. @@ -1202,8 +1146,8 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks): Maximum sequence length for prompt encoding. mask_image (`Image`, *optional*): Mask image for inpainting.
- image (`Image`, *optional*): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): @@ -1216,14 +1160,14 @@ Control image for ControlNet conditioning. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. latents (`Tensor`): Pre-generated noisy latents for image generation. num_inference_steps (`int`): @@ -1232,29 +1176,26 @@ Custom sigmas for the denoising process. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. - image_latents (`None`, *optional*): - TODO: Add description. - processed_mask_image (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image strength (`float`, *optional*, defaults to 0.9): Strength for img2img/inpainting. - control_image_latents (`None`, *optional*): - TODO: Add description. + control_image_latents (`Tensor`, *optional*): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step. control_guidance_start (`float`, *optional*, defaults to 0.0): When to start applying ControlNet. control_guidance_end (`float`, *optional*, defaults to 1.0): When to stop applying ControlNet. controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): Scale for ControlNet conditioning. - **denoiser_input_fields (`None`, *optional*): - All conditional model inputs for the denoiser. It should contain prompt_embeds/negative_prompt_embeds, - txt_seq_lens/negative_txt_seq_lens. output_type (`str`, *optional*, defaults to pil): Output format: 'pil', 'np', 'pt'. - mask_overlay_kwargs (`None`, *optional*): - TODO: Add description. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. Generated in InpaintProcessImagesInputStep.
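The `control_guidance_start`/`control_guidance_end`/`controlnet_conditioning_scale` inputs listed above define when, and how strongly, ControlNet residuals are injected into the transformer. A simplified sketch of that schedule; the exact per-step computation in the blocks is an assumption here:

```python
# Sketch of the ControlNet guidance window implied by the inputs above;
# a simplified reading, not the exact diffusers implementation.
def controlnet_scale_at(step: int, num_steps: int, start: float, end: float, scale: float) -> float:
    progress = step / num_steps
    # residuals are injected only while progress falls inside [start, end)
    return scale if start <= progress < end else 0.0

# e.g. with start=0.0, end=0.5 the ControlNet conditioning is dropped for the
# second half of a 50-step run:
scales = [controlnet_scale_at(i, 50, 0.0, 0.5, 1.0) for i in range(50)]
```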
Outputs: images (`List`): diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py index 158763ce91..0c1fa00842 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py @@ -63,29 +63,14 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): QwenImage-Edit VL encoder step that encode the image and text prompts together. Components: - image_resize_processor (`VaeImageProcessor`) - text_encoder (`Qwen2_5_VLForConditionalGeneration`) - processor (`Qwen2VLProcessor`) - guider (`ClassifierFreeGuidance`) - Configs: - - prompt_template_encode (default: <|im_start|>system - Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|> - <|im_start|>user - <|vision_start|><|image_pad|><|vision_end|>{}<|im_end|> - <|im_start|>assistant - ) - - prompt_template_encode_start_idx (default: 64) - Inputs: - image (`Image`): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. prompt (`str`): The prompt or prompts to guide image generation. negative_prompt (`str`, *optional*): @@ -95,13 +80,13 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): resized_image (`List`): The resized images prompt_embeds (`Tensor`): - The prompt embeddings + The prompt embeddings. prompt_embeds_mask (`Tensor`): - The encoder attention mask + The encoder attention mask. negative_prompt_embeds (`Tensor`): - The negative prompt embeddings + The negative prompt embeddings. negative_prompt_embeds_mask (`Tensor`): - The negative prompt embeddings mask + The negative prompt embeddings mask. """ model_name = "qwenimage-edit" @@ -128,26 +113,23 @@ class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): Vae encoder step that encode the image inputs into their latent representations. Components: - image_resize_processor (`VaeImageProcessor`) - image_processor (`VaeImageProcessor`) - vae (`AutoencoderKLQwenImage`) Inputs: - image (`Image`): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. generator (`Generator`, *optional*): Torch generator for deterministic generation. Outputs: resized_image (`List`): The resized images - processed_image (`None`): - TODO: Add description. + processed_image (`Tensor`): + The processed image image_latents (`Tensor`): - The latents representing the reference image(s). Single tensor or list depending on input. + The latent representation of the input image. """ model_name = "qwenimage-edit" @@ -173,16 +155,13 @@ class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): - create image latents. Components: - image_resize_processor (`VaeImageProcessor`) - image_mask_processor (`InpaintProcessor`) - vae (`AutoencoderKLQwenImage`) Inputs: - image (`Image`): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. mask_image (`Image`): Mask image for inpainting. 
padding_mask_crop (`int`, *optional*): @@ -193,14 +172,14 @@ class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): Outputs: resized_image (`List`): The resized images - processed_image (`None`): - TODO: Add description. - processed_mask_image (`None`): - TODO: Add description. + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image mask_overlay_kwargs (`Dict`): The kwargs for the postprocess step to apply the mask overlay image_latents (`Tensor`): - The latents representing the reference image(s). Single tensor or list depending on input. + The latent representation of the input image. """ model_name = "qwenimage-edit" @@ -252,36 +231,50 @@ class QwenImageEditInputStep(SequentialPipelineBlocks): - update height/width based `image_latents`, patchify `image_latents`. Components: - pachifier (`QwenImagePachifier`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. Outputs: batch_size (`int`): - Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + The batch size of the prompt embeddings dtype (`dtype`): - Data type of model tensor inputs (determined by `prompt_embeds`) + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) image_height (`int`): The image height calculated from the image latents dimension image_width (`int`): The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) """ model_name = "qwenimage-edit" @@ -308,38 +301,54 @@ class QwenImageEditInpaintInputStep(SequentialPipelineBlocks): - update height/width based `image_latents`, patchify `image_latents`. Components: - pachifier (`QwenImagePachifier`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
- prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. - processed_mask_image (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image Outputs: batch_size (`int`): - Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + The batch size of the prompt embeddings dtype (`dtype`): - Data type of model tensor inputs (determined by `prompt_embeds`) + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) image_height (`int`): The image height calculated from the image latents dimension image_width (`int`): The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) """ model_name = "qwenimage-edit" @@ -368,30 +377,31 @@ class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks): - Create the patchified latents `mask` based on the processed mask image. Components: - scheduler (`FlowMatchEulerDiscreteScheduler`) - pachifier (`QwenImagePachifier`) Inputs: latents (`Tensor`): The initial random noised, can be generated in prepare latent step. image_latents (`Tensor`): - The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step. + image latents used to guide the image generation. Can be generated from the vae encoder and updated in the + input step. timesteps (`Tensor`): The timesteps to use for the denoising process. Can be generated in set_timesteps step. processed_mask_image (`Tensor`): The processed mask to use for the inpainting process. - height (`None`): - TODO: Add description. - width (`None`): - TODO: Add description. - dtype (`None`): - TODO: Add description. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image.
+ dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. Outputs: initial_noise (`Tensor`): The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. mask (`Tensor`): The mask to use for the inpainting process. """ @@ -416,32 +426,28 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): Core denoising workflow for QwenImage-Edit edit (img2img) task. Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. latents (`Tensor`, *optional*): Pre-generated noisy latents for image generation. generator (`Generator`, *optional*): @@ -452,7 +458,7 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): Custom sigmas for the denoising process. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. Outputs: @@ -496,34 +502,30 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): Core denoising workflow for QwenImage-Edit edit inpaint task. Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. - processed_mask_image (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image latents (`Tensor`, *optional*): Pre-generated noisy latents for image generation. generator (`Generator`, *optional*): @@ -536,7 +538,7 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): Strength for img2img/inpainting. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. Outputs: @@ -621,20 +623,18 @@ class QwenImageEditDecodeStep(SequentialPipelineBlocks): Decode step that decodes the latents to images and postprocesses the generated image. Components: - vae (`AutoencoderKLQwenImage`) - image_processor (`VaeImageProcessor`) Inputs: latents (`Tensor`): - The latents to decode, can be generated in the denoise step + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step. output_type (`str`, *optional*, defaults to pil): Output format: 'pil', 'np', 'pt'. Outputs: images (`List`): - Generated images. + Generated images. (tensor output of the vae decoder.) """ model_name = "qwenimage-edit" @@ -653,22 +653,20 @@ class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks): Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the mask overlay to the original image. Components: - vae (`AutoencoderKLQwenImage`) - image_mask_processor (`InpaintProcessor`) Inputs: latents (`Tensor`): - The latents to decode, can be generated in the denoise step + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step. output_type (`str`, *optional*, defaults to pil): Output format: 'pil', 'np', 'pt'. - mask_overlay_kwargs (`None`, *optional*): - TODO: Add description. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. Generated in InpaintProcessImagesInputStep. Outputs: images (`List`): - Generated images. + Generated images. (tensor output of the vae decoder.)
""" model_name = "qwenimage-edit" @@ -724,41 +722,20 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks): - for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` Components: - image_resize_processor (`VaeImageProcessor`) - text_encoder (`Qwen2_5_VLForConditionalGeneration`) - processor (`Qwen2VLProcessor`) - guider (`ClassifierFreeGuidance`) - image_mask_processor (`InpaintProcessor`) - vae (`AutoencoderKLQwenImage`) - image_processor (`VaeImageProcessor`) - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - transformer (`QwenImageTransformer2DModel`) - Configs: - - prompt_template_encode (default: <|im_start|>system - Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|> - <|im_start|>user - <|vision_start|><|image_pad|><|vision_end|>{}<|im_end|> - <|im_start|>assistant - ) - - prompt_template_encode_start_idx (default: 64) - Inputs: - image (`Image`): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. prompt (`str`): The prompt or prompts to guide image generation. negative_prompt (`str`, *optional*): @@ -775,10 +752,10 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks): The height in pixels of the generated image. width (`int`): The width in pixels of the generated image. - image_latents (`None`): - TODO: Add description. - processed_mask_image (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image latents (`Tensor`): Pre-generated noisy latents for image generation. num_inference_steps (`int`): @@ -789,12 +766,12 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks): Strength for img2img/inpainting. attention_kwargs (`Dict`, *optional*): Additional kwargs for attention processors. - denoiser_input_fields (`Tensor`, *optional*): + **denoiser_input_fields (`None`, *optional*): conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. output_type (`str`, *optional*, defaults to pil): Output format: 'pil', 'np', 'pt'. - mask_overlay_kwargs (`None`, *optional*): - TODO: Add description. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep. Outputs: images (`List`): diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py index a16dee1c75..726c000f4b 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py @@ -55,47 +55,32 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together. 
Components: - image_resize_processor (`VaeImageProcessor`) - text_encoder (`Qwen2_5_VLForConditionalGeneration`) - processor (`Qwen2VLProcessor`) - guider (`ClassifierFreeGuidance`) - Configs: - - prompt_template_encode (default: <|im_start|>system - Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|> - <|im_start|>user - {}<|im_end|> - <|im_start|>assistant - ) - - img_template_encode (default: Picture {}: <|vision_start|><|image_pad|><|vision_end|>) - - prompt_template_encode_start_idx (default: 64) - Inputs: - image (`Image`): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. prompt (`str`): The prompt or prompts to guide image generation. negative_prompt (`str`, *optional*): The prompt or prompts not to guide the image generation. Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding resized_cond_image (`List`): - The resized images + Images resized to 384x384 target area for VL text encoding prompt_embeds (`Tensor`): - The prompt embeddings + The prompt embeddings. prompt_embeds_mask (`Tensor`): - The encoder attention mask + The encoder attention mask. negative_prompt_embeds (`Tensor`): - The negative prompt embeddings + The negative prompt embeddings. negative_prompt_embeds_mask (`Tensor`): - The negative prompt embeddings mask + The negative prompt embeddings mask. """ model_name = "qwenimage-edit-plus" @@ -122,26 +107,25 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): Each image is resized independently based on its own aspect ratio to 1024x1024 target area. Components: - image_resize_processor (`VaeImageProcessor`) - image_processor (`VaeImageProcessor`) - vae (`AutoencoderKLQwenImage`) Inputs: - image (`Image`): - Input image for img2img, editing, or conditioning. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. generator (`Generator`, *optional*): Torch generator for deterministic generation. Outputs: resized_image (`List`): - The resized images - processed_image (`None`): - TODO: Add description. + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + processed_image (`Tensor`): + The processed image image_latents (`Tensor`): - The latents representing the reference image(s). Single tensor or list depending on input. + The latent representation of the input image. """ model_name = "qwenimage-edit-plus" @@ -176,36 +160,50 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks): - Defaults height/width from last image in the list. Components: - pachifier (`QwenImagePachifier`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. Outputs: batch_size (`int`): - Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + The batch size of the prompt embeddings dtype (`dtype`): - Data type of model tensor inputs (determined by `prompt_embeds`) + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) image_height (`List`): The image heights calculated from the image latents dimension image_width (`List`): The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) """ model_name = "qwenimage-edit-plus" @@ -233,32 +231,28 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): Core denoising workflow for QwenImage-Edit Plus edit (img2img) task. Components: - pachifier (`QwenImagePachifier`) - scheduler (`FlowMatchEulerDiscreteScheduler`) - guider (`ClassifierFreeGuidance`) - transformer (`QwenImageTransformer2DModel`) Inputs: num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - prompt_embeds (`None`): - TODO: Add description. - prompt_embeds_mask (`None`): - TODO: Add description. - negative_prompt_embeds (`None`, *optional*): - TODO: Add description. - negative_prompt_embeds_mask (`None`, *optional*): - TODO: Add description. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. - image_latents (`None`, *optional*): - TODO: Add description. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. latents (`Tensor`, *optional*): Pre-generated noisy latents for image generation. 
@@ -233,32 +231,28 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
     Core denoising workflow for QwenImage-Edit Plus edit (img2img) task.
 
     Components:
-        pachifier (`QwenImagePachifier`)
-        scheduler (`FlowMatchEulerDiscreteScheduler`)
-        guider (`ClassifierFreeGuidance`)
-        transformer (`QwenImageTransformer2DModel`)
 
     Inputs:
         num_images_per_prompt (`int`, *optional*, defaults to 1):
             The number of images to generate per prompt.
-        prompt_embeds (`None`):
-            TODO: Add description.
-        prompt_embeds_mask (`None`):
-            TODO: Add description.
-        negative_prompt_embeds (`None`, *optional*):
-            TODO: Add description.
-        negative_prompt_embeds_mask (`None`, *optional*):
-            TODO: Add description.
+        prompt_embeds (`Tensor`):
+            text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        prompt_embeds_mask (`Tensor`):
+            mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds (`Tensor`, *optional*):
+            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            mask for the negative text embeddings. Can be generated from text_encoder step.
         height (`int`, *optional*):
             The height in pixels of the generated image.
         width (`int`, *optional*):
             The width in pixels of the generated image.
-        image_latents (`None`, *optional*):
-            TODO: Add description.
+        image_latents (`Tensor`):
+            image latents used to guide the image generation. Can be generated from vae_encoder step.
         latents (`Tensor`, *optional*):
             Pre-generated noisy latents for image generation.
         generator (`Generator`, *optional*):
@@ -269,7 +263,7 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
             Custom sigmas for the denoising process.
         attention_kwargs (`Dict`, *optional*):
             Additional kwargs for attention processors.
-        denoiser_input_fields (`Tensor`, *optional*):
+        **denoiser_input_fields (`None`, *optional*):
             conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
 
     Outputs:
@@ -317,20 +311,18 @@ class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
     Decode step that decodes the latents to images and postprocesses the generated image.
 
     Components:
-        vae (`AutoencoderKLQwenImage`)
-        image_processor (`VaeImageProcessor`)
 
     Inputs:
         latents (`Tensor`):
-            The latents to decode, can be generated in the denoise step
+            The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.
         output_type (`str`, *optional*, defaults to pil):
             Output format: 'pil', 'np', 'pt'.
 
     Outputs:
         images (`List`):
-            Generated images.
+            Generated images. (tensor output of the vae decoder.)
     """
 
     model_name = "qwenimage-edit-plus"
 
@@ -365,41 +357,19 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
       - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area.
 
     Components:
-        image_resize_processor (`VaeImageProcessor`)
-        text_encoder (`Qwen2_5_VLForConditionalGeneration`)
-        processor (`Qwen2VLProcessor`)
-        guider (`ClassifierFreeGuidance`)
-        image_processor (`VaeImageProcessor`)
-        vae (`AutoencoderKLQwenImage`)
-        pachifier (`QwenImagePachifier`)
-        scheduler (`FlowMatchEulerDiscreteScheduler`)
-        transformer (`QwenImageTransformer2DModel`)
-
-    Configs:
-      - prompt_template_encode (default: <|im_start|>system
-        Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>
-        <|im_start|>user
-        {}<|im_end|>
-        <|im_start|>assistant
-        )
-      - img_template_encode (default: Picture {}: <|vision_start|><|image_pad|><|vision_end|>)
-      - prompt_template_encode_start_idx (default: 64)
-
     Inputs:
-        image (`Image`):
-            Input image for img2img, editing, or conditioning.
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
         prompt (`str`):
             The prompt or prompts to guide image generation.
         negative_prompt (`str`, *optional*):
@@ -420,7 +390,7 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
             Custom sigmas for the denoising process.
         attention_kwargs (`Dict`, *optional*):
             Additional kwargs for attention processors.
-        denoiser_input_fields (`Tensor`, *optional*):
+        **denoiser_input_fields (`None`, *optional*):
             conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
         output_type (`str`, *optional*, defaults to pil):
             Output format: 'pil', 'np', 'pt'.
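The `denoiser_input_fields` entries above change from a single `Tensor` to a `**`-prefixed variadic field, which signals that the conditioning inputs are forwarded as keyword arguments. A generic sketch of that pattern (the denoiser call signature here is illustrative, not the QwenImage transformer's):

    import torch


    def denoise_step(latents: torch.Tensor, timestep: torch.Tensor, denoiser, **denoiser_input_fields):
        # denoiser_input_fields might carry prompt_embeds, prompt_embeds_mask, etc.;
        # whatever the upstream blocks produced is forwarded untouched.
        return denoiser(latents, timestep, **denoiser_input_fields)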
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py
index 2471750f2e..37a06e9af2 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py
@@ -56,73 +56,19 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
     QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not provided.
 
     Components:
-        image_resize_processor (`VaeImageProcessor`)
-        text_encoder (`Qwen2_5_VLForConditionalGeneration`)
-        processor (`Qwen2VLProcessor`)
-        tokenizer (`Qwen2Tokenizer`): The tokenizer to use
-        guider (`ClassifierFreeGuidance`)
-
-    Configs:
-      - image_caption_prompt_en (default: <|im_start|>system
-        You are a helpful assistant.<|im_end|>
-        <|im_start|>user
-        # Image Annotator
-        You are a professional image annotator. Please write an image caption based on the input image:
-        1. Write the caption using natural, descriptive language without structured formats or rich text.
-        2. Enrich caption details by including:
-           - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on
-           - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, attachment relations, action relations, comparative relations, causal relations, and so on
-           - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on
-           - Identify the text clearly visible in the image, without translation or explanation, and highlight it in the caption with quotation marks
-        3. Maintain authenticity and accuracy:
-           - Avoid generalizations
-           - Describe all visible information in the image, while do not add information not explicitly shown in the image
-        <|vision_start|><|image_pad|><|vision_end|><|im_end|>
-        <|im_start|>assistant
-        )
-      - image_caption_prompt_cn (default: <|im_start|>system
-        You are a helpful assistant.<|im_end|>
-        <|im_start|>user
-        # 图像标注器
-        你是一个专业的图像标注器。请基于输入图像,撰写图注:
-        1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。
-        2. 通过加入以下内容,丰富图注细节:
-           - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等
-           - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等
-           - 环境细节:例如天气、光照、颜色、纹理、气氛等
-           - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调
-        3. 保持真实性与准确性:
-           - 不要使用笼统的描述
-           - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容
-        <|vision_start|><|image_pad|><|vision_end|><|im_end|>
-        <|im_start|>assistant
-        )
-      - prompt_template_encode (default: <|im_start|>system
-        Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
-        <|im_start|>user
-        {}<|im_end|>
-        <|im_start|>assistant
-        )
-      - prompt_template_encode_start_idx (default: 34)
-      - tokenizer_max_length (default: 1024)
-
     Inputs:
-        image (`Image`):
-            Input image for img2img, editing, or conditioning.
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
         resolution (`int`, *optional*, defaults to 640):
             The target area to resize the image to, can be 1024 or 640
         prompt (`str`, *optional*):
-            The prompt to encode
+            The prompt or prompts to guide image generation.
         use_en_prompt (`bool`, *optional*, defaults to False):
             Whether to use English prompt template
         negative_prompt (`str`, *optional*):
@@ -133,14 +79,16 @@
 
     Outputs:
         resized_image (`List`):
             The resized images
+        prompt (`str`):
+            The prompt or prompts to guide image generation. If not provided, updated using image caption
         prompt_embeds (`Tensor`):
-            The prompt embeddings
+            The prompt embeddings.
         prompt_embeds_mask (`Tensor`):
-            The encoder attention mask
+            The encoder attention mask.
         negative_prompt_embeds (`Tensor`):
-            The negative prompt embeddings
+            The negative prompt embeddings.
         negative_prompt_embeds_mask (`Tensor`):
-            The negative prompt embeddings mask
+            The negative prompt embeddings mask.
     """
 
     model_name = "qwenimage-layered"
 
@@ -168,16 +116,13 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
     Vae encoder step that encode the image inputs into their latent representations.
 
     Components:
-        image_resize_processor (`VaeImageProcessor`)
-        image_processor (`VaeImageProcessor`)
-        vae (`AutoencoderKLQwenImage`)
 
     Inputs:
-        image (`Image`):
-            Input image for img2img, editing, or conditioning.
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
         resolution (`int`, *optional*, defaults to 640):
             The target area to resize the image to, can be 1024 or 640
         generator (`Generator`, *optional*):
@@ -186,10 +131,10 @@
 
     Outputs:
         resized_image (`List`):
             The resized images
-        processed_image (`None`):
-            TODO: Add description.
+        processed_image (`Tensor`):
+            The processed image
         image_latents (`Tensor`):
-            The latents representing the reference image(s). Single tensor or list depending on input.
+            The latent representation of the input image.
     """
 
     model_name = "qwenimage-layered"
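Per the text-encoder docstring above, a missing `prompt` is filled in by captioning the input image, with `use_en_prompt` choosing between the English and Chinese caption templates. Schematically, with the helper names hypothetical:

    def resolve_prompt(prompt, image, caption_image, use_en_prompt, image_caption_prompt_en, image_caption_prompt_cn):
        if prompt:
            return prompt
        # Fall back to a VL-generated caption using the selected template;
        # caption_image stands in for the Qwen2.5-VL captioning call.
        template = image_caption_prompt_en if use_en_prompt else image_caption_prompt_cn
        return caption_image(image, template)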
@@ -220,36 +165,46 @@ class QwenImageLayeredInputStep(SequentialPipelineBlocks):
       - update height/width based `image_latents`, patchify `image_latents`.
 
     Components:
-        pachifier (`QwenImageLayeredPachifier`)
 
     Inputs:
         num_images_per_prompt (`int`, *optional*, defaults to 1):
             The number of images to generate per prompt.
-        prompt_embeds (`None`):
-            TODO: Add description.
-        prompt_embeds_mask (`None`):
-            TODO: Add description.
-        negative_prompt_embeds (`None`, *optional*):
-            TODO: Add description.
-        negative_prompt_embeds_mask (`None`, *optional*):
-            TODO: Add description.
-        image_latents (`None`, *optional*):
-            TODO: Add description.
+        prompt_embeds (`Tensor`):
+            text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        prompt_embeds_mask (`Tensor`):
+            mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds (`Tensor`, *optional*):
+            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            mask for the negative text embeddings. Can be generated from text_encoder step.
+        image_latents (`Tensor`):
+            image latents used to guide the image generation. Can be generated from vae_encoder step.
 
     Outputs:
         batch_size (`int`):
-            Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
+            The batch size of the prompt embeddings
         dtype (`dtype`):
-            Data type of model tensor inputs (determined by `prompt_embeds`)
+            The data type of the prompt embeddings
+        prompt_embeds (`Tensor`):
+            The prompt embeddings. (batch-expanded)
+        prompt_embeds_mask (`Tensor`):
+            The encoder attention mask. (batch-expanded)
+        negative_prompt_embeds (`Tensor`):
+            The negative prompt embeddings. (batch-expanded)
+        negative_prompt_embeds_mask (`Tensor`):
+            The negative prompt embeddings mask. (batch-expanded)
         image_height (`int`):
             The image height calculated from the image latents dimension
         image_width (`int`):
             The image width calculated from the image latents dimension
         height (`int`):
-            The height of the image output
+            if not provided, updated to image height
         width (`int`):
-            The width of the image output
+            if not provided, updated to image width
+        image_latents (`Tensor`):
+            image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified with layered
+            pachifier and batch-expanded)
     """
 
     model_name = "qwenimage-layered"
 
@@ -275,28 +230,24 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
     Core denoising workflow for QwenImage-Layered img2img task.
 
     Components:
-        pachifier (`QwenImageLayeredPachifier`)
-        scheduler (`FlowMatchEulerDiscreteScheduler`)
-        guider (`ClassifierFreeGuidance`)
-        transformer (`QwenImageTransformer2DModel`)
 
     Inputs:
         num_images_per_prompt (`int`, *optional*, defaults to 1):
             The number of images to generate per prompt.
-        prompt_embeds (`None`):
-            TODO: Add description.
-        prompt_embeds_mask (`None`):
-            TODO: Add description.
-        negative_prompt_embeds (`None`, *optional*):
-            TODO: Add description.
-        negative_prompt_embeds_mask (`None`, *optional*):
-            TODO: Add description.
-        image_latents (`None`, *optional*):
-            TODO: Add description.
+        prompt_embeds (`Tensor`):
+            text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        prompt_embeds_mask (`Tensor`):
+            mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds (`Tensor`, *optional*):
+            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            mask for the negative text embeddings. Can be generated from text_encoder step.
+        image_latents (`Tensor`):
+            image latents used to guide the image generation. Can be generated from vae_encoder step.
         latents (`Tensor`, *optional*):
             Pre-generated noisy latents for image generation.
         layers (`int`, *optional*, defaults to 4):
@@ -309,7 +260,7 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
             Custom sigmas for the denoising process.
         attention_kwargs (`Dict`, *optional*):
             Additional kwargs for attention processors.
-        denoiser_input_fields (`Tensor`, *optional*):
+        **denoiser_input_fields (`None`, *optional*):
             conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
 
     Outputs:
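For the layered blocks, "patchified with layered pachifier" plausibly means folding 2x2 latent patches into the channel dimension while flattening the layer and spatial axes into one sequence, for latents shaped (B, layers+1, C, H, W). A shape-level sketch under that assumption; QwenImageLayeredPachifier's exact packing may differ:

    import torch


    def pack_layered_latents(latents: torch.Tensor) -> torch.Tensor:
        b, l, c, h, w = latents.shape  # l = layers + 1
        latents = latents.view(b, l, c, h // 2, 2, w // 2, 2)
        latents = latents.permute(0, 1, 3, 5, 2, 4, 6)  # (B, L, H/2, W/2, C, 2, 2)
        return latents.reshape(b, l * (h // 2) * (w // 2), c * 4)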
@@ -366,83 +317,24 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
     Auto Modular pipeline for layered denoising tasks using QwenImage-Layered.
 
     Components:
-        image_resize_processor (`VaeImageProcessor`)
-        text_encoder (`Qwen2_5_VLForConditionalGeneration`)
-        processor (`Qwen2VLProcessor`)
-        tokenizer (`Qwen2Tokenizer`): The tokenizer to use
-        guider (`ClassifierFreeGuidance`)
-        image_processor (`VaeImageProcessor`)
-        vae (`AutoencoderKLQwenImage`)
-        pachifier (`QwenImageLayeredPachifier`)
-        scheduler (`FlowMatchEulerDiscreteScheduler`)
-        transformer (`QwenImageTransformer2DModel`)
-
-    Configs:
-      - image_caption_prompt_en (default: <|im_start|>system
-        You are a helpful assistant.<|im_end|>
-        <|im_start|>user
-        # Image Annotator
-        You are a professional image annotator. Please write an image caption based on the input image:
-        1. Write the caption using natural, descriptive language without structured formats or rich text.
-        2. Enrich caption details by including:
-           - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on
-           - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, attachment relations, action relations, comparative relations, causal relations, and so on
-           - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on
-           - Identify the text clearly visible in the image, without translation or explanation, and highlight it in the caption with quotation marks
-        3. Maintain authenticity and accuracy:
-           - Avoid generalizations
-           - Describe all visible information in the image, while do not add information not explicitly shown in the image
-        <|vision_start|><|image_pad|><|vision_end|><|im_end|>
-        <|im_start|>assistant
-        )
-      - image_caption_prompt_cn (default: <|im_start|>system
-        You are a helpful assistant.<|im_end|>
-        <|im_start|>user
-        # 图像标注器
-        你是一个专业的图像标注器。请基于输入图像,撰写图注:
-        1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。
-        2. 通过加入以下内容,丰富图注细节:
-           - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等
-           - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等
-           - 环境细节:例如天气、光照、颜色、纹理、气氛等
-           - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调
-        3. 保持真实性与准确性:
-           - 不要使用笼统的描述
-           - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容
-        <|vision_start|><|image_pad|><|vision_end|><|im_end|>
-        <|im_start|>assistant
-        )
-      - prompt_template_encode (default: <|im_start|>system
-        Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
-        <|im_start|>user
-        {}<|im_end|>
-        <|im_start|>assistant
-        )
-      - prompt_template_encode_start_idx (default: 34)
-      - tokenizer_max_length (default: 1024)
-
     Inputs:
-        image (`Image`):
-            Input image for img2img, editing, or conditioning.
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
         resolution (`int`, *optional*, defaults to 640):
             The target area to resize the image to, can be 1024 or 640
         prompt (`str`, *optional*):
-            The prompt to encode
+            The prompt or prompts to guide image generation.
         use_en_prompt (`bool`, *optional*, defaults to False):
             Whether to use English prompt template
         negative_prompt (`str`, *optional*):
@@ -463,7 +355,7 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
             Custom sigmas for the denoising process.
         attention_kwargs (`Dict`, *optional*):
             Additional kwargs for attention processors.
-        denoiser_input_fields (`Tensor`, *optional*):
+        **denoiser_input_fields (`None`, *optional*):
             conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
         output_type (`str`, *optional*, defaults to pil):
             Output format: 'pil', 'np', 'pt'.
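All of the auto blocks in this diff are SequentialPipelineBlocks: each sub-block reads named fields from a shared state and writes its outputs back, which is why these docstrings keep saying "Can be generated in X step". A toy model of that contract, not the diffusers implementation:

    from typing import Any, Callable, Dict, List


    def run_sequential(blocks: List[Callable[[Dict[str, Any]], Dict[str, Any]]], state: Dict[str, Any]) -> Dict[str, Any]:
        # Each block returns only the fields it produced (e.g. {"latents": ...});
        # later blocks see everything written so far, e.g. the denoise step reads
        # the prompt_embeds that the text-encoder step wrote.
        for block in blocks:
            state.update(block(state))
        return state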