From f7439c30c9ed7c1aa921459edfd535ecc7c5aa61 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 9 Dec 2025 08:08:41 -1000 Subject: [PATCH] [Modular]z-image (#12808) * initiL * up up * fix: z_image -> z-image * style * copy * fix more * some docstring fix --- src/diffusers/__init__.py | 4 + src/diffusers/modular_pipelines/__init__.py | 5 + .../modular_pipelines/modular_pipeline.py | 1 + .../modular_pipelines/wan/encoders.py | 6 +- .../modular_pipelines/z_image/__init__.py | 57 ++ .../z_image/before_denoise.py | 621 ++++++++++++++++++ .../modular_pipelines/z_image/decoders.py | 91 +++ .../modular_pipelines/z_image/denoise.py | 310 +++++++++ .../modular_pipelines/z_image/encoders.py | 344 ++++++++++ .../z_image/modular_blocks.py | 191 ++++++ .../z_image/modular_pipeline.py | 72 ++ .../dummy_torch_and_transformers_objects.py | 30 + 12 files changed, 1730 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/modular_pipelines/z_image/__init__.py create mode 100644 src/diffusers/modular_pipelines/z_image/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/z_image/decoders.py create mode 100644 src/diffusers/modular_pipelines/z_image/denoise.py create mode 100644 src/diffusers/modular_pipelines/z_image/encoders.py create mode 100644 src/diffusers/modular_pipelines/z_image/modular_blocks.py create mode 100644 src/diffusers/modular_pipelines/z_image/modular_pipeline.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6df4ad4894..e69d334fdb 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -419,6 +419,8 @@ else: "Wan22AutoBlocks", "WanAutoBlocks", "WanModularPipeline", + "ZImageAutoBlocks", + "ZImageModularPipeline", ] ) _import_structure["pipelines"].extend( @@ -1124,6 +1126,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline, + ZImageAutoBlocks, + ZImageModularPipeline, ) from .pipelines import ( AllegroPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 252b9f33df..dea9da0269 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -60,6 +60,10 @@ else: "QwenImageEditPlusModularPipeline", "QwenImageEditPlusAutoBlocks", ] + _import_structure["z_image"] = [ + "ZImageAutoBlocks", + "ZImageModularPipeline", + ] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -91,6 +95,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline + from .z_image import ZImageAutoBlocks, ZImageModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index a6336de71a..bba89e6121 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -61,6 +61,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict( ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), + ("z-image", "ZImageModularPipeline"), ] ) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index dc49df8eab..4fd69c6ca6 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ 
b/src/diffusers/modular_pipelines/wan/encoders.py @@ -530,6 +530,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks): device = components._execution_device dtype = torch.float32 + vae_dtype = components.vae.dtype height = block_state.height or components.default_height width = block_state.width or components.default_width @@ -555,7 +556,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks): vae=components.vae, generator=block_state.generator, device=device, - dtype=dtype, + dtype=vae_dtype, latent_channels=components.num_channels_latents, ) @@ -627,6 +628,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks): device = components._execution_device dtype = torch.float32 + vae_dtype = components.vae.dtype height = block_state.height or components.default_height width = block_state.width or components.default_width @@ -659,7 +661,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks): vae=components.vae, generator=block_state.generator, device=device, - dtype=dtype, + dtype=vae_dtype, latent_channels=components.num_channels_latents, ) diff --git a/src/diffusers/modular_pipelines/z_image/__init__.py b/src/diffusers/modular_pipelines/z_image/__init__.py new file mode 100644 index 0000000000..c8a8c14396 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["decoders"] = ["ZImageVaeDecoderStep"] + _import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "ZImageAutoBlocks", + ] + _import_structure["modular_pipeline"] = ["ZImageModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .decoders import ZImageVaeDecoderStep + from .encoders import ZImageTextEncoderStep + from .modular_blocks import ( + ALL_BLOCKS, + ZImageAutoBlocks, + ) + from .modular_pipeline import ZImageModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/z_image/before_denoise.py b/src/diffusers/modular_pipelines/z_image/before_denoise.py new file mode 100644 index 0000000000..35ea768f12 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/before_denoise.py @@ -0,0 +1,621 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch + +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. 
image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_spatial: int) -> Tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent spatial dimensions to image spatial dimensions by multiplying the latent height/width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 dimensions. + Expected shapes: [batch, channels, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress image spatial dimension. + By default, it is 16 + Returns: + Tuple[int, int]: The calculated image dimensions as (height, width) + """ + latent_height, latent_width = latents.shape[2:] + height = latent_height * vae_scale_factor_spatial // 2 + width = latent_width * vae_scale_factor_spatial // 2 + + return height, width + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageTextInputStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=List[torch.Tensor], + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=List[torch.Tensor], + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `transformer.dtype`)", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if not isinstance(block_state.prompt_embeds, list): + raise ValueError( + f"`prompt_embeds` must be a list when passed directly, but got {type(block_state.prompt_embeds)}." + ) + if not isinstance(block_state.negative_prompt_embeds, list): + raise ValueError( + f"`negative_prompt_embeds` must be a list when passed directly, but got {type(block_state.negative_prompt_embeds)}." 
+ ) + if len(block_state.prompt_embeds) != len(block_state.negative_prompt_embeds): + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same length when passed directly, but" + f" got: `prompt_embeds` {len(block_state.prompt_embeds)} != `negative_prompt_embeds`" + f" {len(block_state.negative_prompt_embeds)}." + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = len(block_state.prompt_embeds) + block_state.dtype = block_state.prompt_embeds[0].dtype + + if block_state.num_images_per_prompt > 1: + prompt_embeds = [pe for pe in block_state.prompt_embeds for _ in range(block_state.num_images_per_prompt)] + block_state.prompt_embeds = prompt_embeds + + if block_state.negative_prompt_embeds is not None: + negative_prompt_embeds = [ + npe for npe in block_state.negative_prompt_embeds for _ in range(block_state.num_images_per_prompt) + ] + block_state.negative_prompt_embeds = negative_prompt_embeds + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageAdditionalInputsStep(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + image_latent_inputs: List[str] = ["image_latents"], + additional_batch_inputs: List[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (List[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["image_latents"]. + additional_batch_inputs (List[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. + + Examples: + # Configure to process image_latents (default behavior) ZImageAdditionalInputsStep() + + # Configure to process multiple image latent inputs + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents", "control_image_latents"]) + + # Configure to process image latents and additional batch inputs ZImageAdditionalInputsStep( + image_latent_inputs=["image_latents"], additional_batch_inputs=["image_embeds"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. 
For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate num_frames, height/width from latents + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_spatial) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. 
Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + def check_inputs(self, components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + block_state.latents = self.prepare_latents( + components, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageSetTimestepsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference. Need to run after prepare latents step." 
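+    # Illustrative note (hedged): `__call__` below derives a resolution-dependent shift `mu` with the
+    # Flux-style `calculate_shift` helper. Assuming an 8x VAE (1024x1024 pixels -> 128x128 latents),
+    # a 2x2 patchify, and the default config (base_image_seq_len=256, max_image_seq_len=4096,
+    # base_shift=0.5, max_shift=1.15):
+    #   1024x1024 input: image_seq_len = (128 // 2) * (128 // 2) = 4096  ->  mu = 1.15
+    #   512x512 input:   image_seq_len = (64 // 2) * (64 // 2)   = 1024  ->  mu ~= 0.63
+    # A larger mu shifts the sigma schedule toward higher noise, so larger images spend more steps there.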
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("num_inference_steps", default=9), + InputParam("sigmas"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process" + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + latent_height, latent_width = block_state.latents.shape[2], block_state.latents.shape[3] + image_seq_len = (latent_height // 2) * (latent_width // 2) # sequence length after patchify + + mu = calculate_shift( + image_seq_len, + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + components.scheduler.sigma_min = 0.0 + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=block_state.sigmas, + mu=mu, + ) + + self.set_block_state(state, block_state) + return components, state + + +class ZImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference with strength. Need to run after set timesteps step." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("timesteps", required=True), + InputParam("num_inference_steps", required=True), + InputParam("strength", default=0.6), + ] + + def check_inputs(self, components, block_state): + if block_state.strength < 0.0 or block_state.strength > 1.0: + raise ValueError(f"Strength must be between 0.0 and 1.0, but got {block_state.strength}") + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + init_timestep = min(block_state.num_inference_steps * block_state.strength, block_state.num_inference_steps) + + t_start = int(max(block_state.num_inference_steps - init_timestep, 0)) + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + block_state.timesteps = timesteps + block_state.num_inference_steps = block_state.num_inference_steps - t_start + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentswithImageStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "step that prepares the latents with image condition, need to run after set timesteps and prepare latents step." 
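+    # Illustrative note (hedged): `scheduler.scale_noise(image_latents, t, noise)` used in `__call__` below
+    # applies the flow-matching interpolation, roughly `latents = (1 - sigma_t) * image_latents + sigma_t * noise`,
+    # so image-to-image generation starts from a partially noised copy of the encoded input image at the
+    # first timestep kept by `ZImageSetTimestepsWithStrengthStep` (controlled by `strength`).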
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("image_latents", required=True), + InputParam("timesteps", required=True), + ] + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/decoders.py b/src/diffusers/modular_pipelines/z_image/decoders.py new file mode 100644 index 0000000000..cdb6a2e5ea --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/decoders.py @@ -0,0 +1,91 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageVaeDecoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + ), + InputParam( + name="output_type", + default="pil", + type_hint=str, + description="The type of the output images, can be 'pil', 'np', 'pt'", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + latents = block_state.latents.to(vae_dtype) + latents = latents / components.vae.config.scaling_factor + components.vae.config.shift_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + 
self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/denoise.py b/src/diffusers/modular_pipelines/z_image/denoise.py new file mode 100644 index 0000000000..ec815f77ad --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/denoise.py @@ -0,0 +1,310 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Tuple + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents = block_state.latents.unsqueeze(2).to( + block_state.dtype + ) # [batch_size, num_channels, 1, height, width] + block_state.latent_model_input = list(latents.unbind(dim=0)) # list of [num_channels, 1, height, width] + + timestep = t.expand(latents.shape[0]).to(block_state.dtype) + timestep = (1000 - timestep) / 1000 + block_state.timestep = timestep + return components, block_state + + +class ZImageLoopDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + guider_input_fields: Dict[str, Any] = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. 
For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + inputs = [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + guider_input_names = [] + uncond_guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.append(value[0]) + uncond_guider_input_names.append(value[1]) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True)) + for name in uncond_guider_input_names: + inputs.append(InputParam(name=name)) + return inputs + + @torch.no_grad() + def __call__( + self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). 
+ guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + + def _convert_dtype(v, dtype): + if isinstance(v, torch.Tensor): + return v.to(dtype) + elif isinstance(v, list): + return [_convert_dtype(t, dtype) for t in v] + return v + + cond_kwargs = { + k: _convert_dtype(v, block_state.dtype) + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + model_out_list = components.transformer( + x=block_state.latent_model_input, + t=block_state.timestep, + return_dict=False, + **cond_kwargs, + )[0] + noise_pred = torch.stack(model_out_list, dim=0).squeeze(2) + guider_state_batch.noise_pred = -noise_pred + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class ZImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def loop_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageDenoiseStep(ZImageDenoiseLoopWrapper): + block_classes = [ + ZImageLoopBeforeDenoiser, + ZImageLoopDenoiser( + guider_input_fields={ + "cap_feats": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + ZImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `ZImageLoopBeforeDenoiser`\n" + " - `ZImageLoopDenoiser`\n" + " - `ZImageLoopAfterDenoiser`\n" + "This block supports text-to-image and image-to-image tasks for Z-Image." + ) diff --git a/src/diffusers/modular_pipelines/z_image/encoders.py b/src/diffusers/modular_pipelines/z_image/encoders.py new file mode 100644 index 0000000000..f5769fe2de --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/encoders.py @@ -0,0 +1,344 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Union + +import PIL +import torch +from transformers import Qwen2Tokenizer, Qwen3Model + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import is_ftfy_available, logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +if is_ftfy_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_qwen_prompt_embeds( + text_encoder: Qwen3Model, + tokenizer: Qwen2Tokenizer, + prompt: Union[str, List[str]], + device: torch.device, + max_sequence_length: int = 512, +) -> List[torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + prompt_embeds_list = [] + + for i in range(len(prompt_embeds)): + prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]]) + + return prompt_embeds_list + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + image_tensor: torch.Tensor, + vae: AutoencoderKL, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, +): + if not isinstance(image_tensor, torch.Tensor): + raise ValueError(f"Expected image_tensor to be a tensor, got {type(image_tensor)}.") + + if isinstance(generator, list) and len(generator) != image_tensor.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image_tensor.shape[0]}." 
+ ) + + image_tensor = image_tensor.to(device=device, dtype=dtype) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image_tensor[i : i + 1]), generator=generator[i]) + for i in range(image_tensor.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image_tensor), generator=generator) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + +class ZImageTextEncoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3Model), + ComponentSpec("tokenizer", Qwen2Tokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("max_sequence_length", default=512), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=List[torch.Tensor], + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=List[torch.Tensor], + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: Optional[torch.device] = None, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + max_sequence_length: int = 512, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + max_sequence_length (`int`, defaults to `512`): + The maximum number of text tokens to be used for the generation process. 
+ """ + device = device or components._execution_device + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds = get_qwen_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_prompt_embeds = None + if prepare_unconditional_embeds: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = get_qwen_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + ) = self.encode_prompt( + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class ZImageVaeImageEncoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on image to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="video latent representation with the first frame image condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." 
+ ) + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = block_state.image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + image_tensor = components.image_processor.preprocess( + image, height=block_state.height, width=block_state.width + ).to(device=device, dtype=dtype) + + block_state.image_latents = encode_vae_image( + image_tensor=image_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/modular_blocks.py b/src/diffusers/modular_pipelines/z_image/modular_blocks.py new file mode 100644 index 0000000000..a7c520301a --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/modular_blocks.py @@ -0,0 +1,191 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + ZImageAdditionalInputsStep, + ZImagePrepareLatentsStep, + ZImagePrepareLatentswithImageStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImageTextInputStep, +) +from .decoders import ZImageVaeDecoderStep +from .denoise import ( + ZImageDenoiseStep, +) +from .encoders import ( + ZImageTextEncoderStep, + ZImageVaeImageEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# z-image +# text2image +class ZImageCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + ZImageTextInputStep, + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageDenoiseStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n" + + " - `ZImageSetTimestepsStep` is used to set the timesteps\n" + + " - `ZImageDenoiseStep` is used to denoise the latents\n" + ) + + +# z-image: image2image +## denoise +class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + ZImageTextInputStep, + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]), + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImagePrepareLatentswithImageStep, + ZImageDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "prepare_latents", + "set_timesteps", + 
"set_timesteps_with_strength", + "prepare_latents_with_image", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `ZImageAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n" + + " - `ZImageSetTimestepsStep` is used to set the timesteps\n" + + " - `ZImageSetTimestepsWithStrengthStep` is used to set the timesteps with strength\n" + + " - `ZImagePrepareLatentswithImageStep` is used to prepare the latents with image\n" + + " - `ZImageDenoiseStep` is used to denoise the latents\n" + ) + + +## auto blocks +class ZImageAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + ZImageImage2ImageCoreDenoiseStep, + ZImageCoreDenoiseStep, + ] + block_names = ["image2image", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2image and image2image tasks." + " - `ZImageCoreDenoiseStep` (text2image) for text2image tasks." + " - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks." + + " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n" + + " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n" + ) + + +class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks): + block_classes = [ZImageVaeImageEncoderStep] + block_names = ["vae_image_encoder"] + block_trigger_inputs = ["image"] + + @property + def description(self) -> str: + return "Vae Image Encoder step that encode the image to generate the image latents" + +"This is an auto pipeline block that works for image2image tasks." + +" - `ZImageVaeImageEncoderStep` is used when `image` is provided." + +" - if `image` is not provided, step will be skipped." + + +class ZImageAutoBlocks(SequentialPipelineBlocks): + block_classes = [ + ZImageTextEncoderStep, + ZImageAutoVaeImageEncoderStep, + ZImageAutoDenoiseStep, + ZImageVaeDecoderStep, + ] + block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"] + + @property + def description(self) -> str: + return "Auto Modular pipeline for text-to-image and image-to-image using ZImage.\n" + +" - for text-to-image generation, all you need to provide is `prompt`\n" + +" - for image-to-image generation, you need to provide `image`\n" + +" - if `image` is not provided, step will be skipped." 
+
+
+# presets
+TEXT2IMAGE_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", ZImageTextEncoderStep),
+        ("input", ZImageTextInputStep),
+        ("prepare_latents", ZImagePrepareLatentsStep),
+        ("set_timesteps", ZImageSetTimestepsStep),
+        ("denoise", ZImageDenoiseStep),
+        ("decode", ZImageVaeDecoderStep),
+    ]
+)
+
+IMAGE2IMAGE_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", ZImageTextEncoderStep),
+        ("vae_image_encoder", ZImageVaeImageEncoderStep),
+        ("input", ZImageTextInputStep),
+        ("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
+        ("prepare_latents", ZImagePrepareLatentsStep),
+        ("set_timesteps", ZImageSetTimestepsStep),
+        ("set_timesteps_with_strength", ZImageSetTimestepsWithStrengthStep),
+        ("prepare_latents_with_image", ZImagePrepareLatentswithImageStep),
+        ("denoise", ZImageDenoiseStep),
+        ("decode", ZImageVaeDecoderStep),
+    ]
+)
+
+
+AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", ZImageTextEncoderStep),
+        ("vae_image_encoder", ZImageAutoVaeImageEncoderStep),
+        ("denoise", ZImageAutoDenoiseStep),
+        ("decode", ZImageVaeDecoderStep),
+    ]
+)
+
+ALL_BLOCKS = {
+    "text2image": TEXT2IMAGE_BLOCKS,
+    "image2image": IMAGE2IMAGE_BLOCKS,
+    "auto": AUTO_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/z_image/modular_pipeline.py b/src/diffusers/modular_pipelines/z_image/modular_pipeline.py
new file mode 100644
index 0000000000..f1d8e53a36
--- /dev/null
+++ b/src/diffusers/modular_pipelines/z_image/modular_pipeline.py
@@ -0,0 +1,72 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...loaders import ZImageLoraLoaderMixin
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class ZImageModularPipeline(
+    ModularPipeline,
+    ZImageLoraLoaderMixin,
+):
+    """
+    A ModularPipeline for Z-Image.
+
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
+    """
+
+    default_blocks_name = "ZImageAutoBlocks"
+
+    @property
+    def default_height(self):
+        return 1024
+
+    @property
+    def default_width(self):
+        return 1024
+
+    @property
+    def vae_scale_factor_spatial(self):
+        vae_scale_factor_spatial = 16
+        if hasattr(self, "image_processor") and self.image_processor is not None:
+            vae_scale_factor_spatial = self.image_processor.config.vae_scale_factor
+        return vae_scale_factor_spatial
+
+    @property
+    def vae_scale_factor(self):
+        vae_scale_factor = 8
+        if hasattr(self, "vae") and self.vae is not None:
+            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        return vae_scale_factor
+
+    @property
+    def num_channels_latents(self):
+        num_channels_latents = 16
+        if hasattr(self, "transformer") and self.transformer is not None:
+            num_channels_latents = self.transformer.config.in_channels
+        return num_channels_latents
+
+    @property
+    def requires_unconditional_embeds(self):
+        requires_unconditional_embeds = False
+
+        if hasattr(self, "guider") and self.guider is not None:
+            requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+        return requires_unconditional_embeds
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 79a21d2ac6..da64742518 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -227,6 +227,36 @@ class WanModularPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class ZImageAutoBlocks(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class ZImageModularPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class AllegroPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
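A companion sketch assembling one of the presets from `modular_blocks.py` directly (again illustrative only, not part of the patch): the repository id, step count, and output handling are placeholders, and the `from_blocks_dict`/`init_pipeline` wiring is assumed to behave as it does for the other modular pipelines.

    import torch
    from diffusers.modular_pipelines import SequentialPipelineBlocks
    from diffusers.modular_pipelines.z_image import ALL_BLOCKS

    # NOTE: illustrative sketch -- repo id and call arguments are assumptions.
    t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2image"])
    pipe = t2i_blocks.init_pipeline("Tongyi-MAI/Z-Image-Turbo")  # placeholder repo id
    pipe.load_default_components(torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # height/width fall back to ZImageModularPipeline.default_height/default_width (1024)
    # when not passed explicitly.
    image = pipe(prompt="a watercolor fox", num_inference_steps=9, output="images")[0]
    image.save("z_image_t2i.png")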