1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Modular]z-image (#12808)

* initiL

* up up

* fix: z_image -> z-image

* style

* copy

* fix more

* some docstring fix
This commit is contained in:
YiYi Xu
2025-12-09 08:08:41 -10:00
committed by GitHub
parent 54fa0745c3
commit 07ea0786e8
12 changed files with 1730 additions and 2 deletions

View File

@@ -419,6 +419,8 @@ else:
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanModularPipeline",
"ZImageAutoBlocks",
"ZImageModularPipeline",
]
)
_import_structure["pipelines"].extend(
@@ -1124,6 +1126,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Wan22AutoBlocks,
WanAutoBlocks,
WanModularPipeline,
ZImageAutoBlocks,
ZImageModularPipeline,
)
from .pipelines import (
AllegroPipeline,

View File

@@ -60,6 +60,10 @@ else:
"QwenImageEditPlusModularPipeline",
"QwenImageEditPlusAutoBlocks",
]
_import_structure["z_image"] = [
"ZImageAutoBlocks",
"ZImageModularPipeline",
]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -91,6 +95,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
else:
import sys

View File

@@ -61,6 +61,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
("z-image", "ZImageModularPipeline"),
]
)

View File

@@ -530,6 +530,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
device = components._execution_device
dtype = torch.float32
vae_dtype = components.vae.dtype
height = block_state.height or components.default_height
width = block_state.width or components.default_width
@@ -555,7 +556,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
dtype=vae_dtype,
latent_channels=components.num_channels_latents,
)
@@ -627,6 +628,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
device = components._execution_device
dtype = torch.float32
vae_dtype = components.vae.dtype
height = block_state.height or components.default_height
width = block_state.width or components.default_width
@@ -659,7 +661,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
dtype=vae_dtype,
latent_channels=components.num_channels_latents,
)

View File

@@ -0,0 +1,57 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["decoders"] = ["ZImageVaeDecoderStep"]
_import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"ZImageAutoBlocks",
]
_import_structure["modular_pipeline"] = ["ZImageModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .decoders import ZImageVaeDecoderStep
from .encoders import ZImageTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
ZImageAutoBlocks,
)
from .modular_pipeline import ZImageModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,621 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
import torch
from ...models import ZImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import ZImageModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
# configuration of guider is.
def repeat_tensor_to_batch_size(
input_name: str,
input_tensor: torch.Tensor,
batch_size: int,
num_images_per_prompt: int = 1,
) -> torch.Tensor:
"""Repeat tensor elements to match the final batch size.
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
by repeating each element along dimension 0.
The input tensor must have batch size 1 or batch_size. The function will:
- If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
- If batch size equals batch_size: repeat each element num_images_per_prompt times
Args:
input_name (str): Name of the input tensor (used for error messages)
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
batch_size (int): The base batch size (number of prompts)
num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
Returns:
torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
Raises:
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
Examples:
tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor,
batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape:
[4, 3]
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image",
tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
- shape: [4, 3]
"""
# make sure input is a tensor
if not isinstance(input_tensor, torch.Tensor):
raise ValueError(f"`{input_name}` must be a tensor")
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
if input_tensor.shape[0] == 1:
repeat_by = batch_size * num_images_per_prompt
elif input_tensor.shape[0] == batch_size:
repeat_by = num_images_per_prompt
else:
raise ValueError(
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
)
# expand the tensor to match the batch_size * num_images_per_prompt
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
return input_tensor
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_spatial: int) -> Tuple[int, int]:
"""Calculate image dimensions from latent tensor dimensions.
This function converts latent spatial dimensions to image spatial dimensions by multiplying the latent height/width
by the VAE scale factor.
Args:
latents (torch.Tensor): The latent tensor. Must have 4 dimensions.
Expected shapes: [batch, channels, height, width]
vae_scale_factor (int): The scale factor used by the VAE to compress image spatial dimension.
By default, it is 16
Returns:
Tuple[int, int]: The calculated image dimensions as (height, width)
"""
latent_height, latent_width = latents.shape[2:]
height = latent_height * vae_scale_factor_spatial // 2
width = latent_width * vae_scale_factor_spatial // 2
return height, width
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class ZImageTextInputStep(ModularPipelineBlocks):
model_name = "z-image"
@property
def description(self) -> str:
return (
"Input processing step that:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
"have a final batch_size of batch_size * num_images_per_prompt."
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("transformer", ZImageTransformer2DModel),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
type_hint=List[torch.Tensor],
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"negative_prompt_embeds",
type_hint=List[torch.Tensor],
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `transformer.dtype`)",
),
]
def check_inputs(self, components, block_state):
if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
if not isinstance(block_state.prompt_embeds, list):
raise ValueError(
f"`prompt_embeds` must be a list when passed directly, but got {type(block_state.prompt_embeds)}."
)
if not isinstance(block_state.negative_prompt_embeds, list):
raise ValueError(
f"`negative_prompt_embeds` must be a list when passed directly, but got {type(block_state.negative_prompt_embeds)}."
)
if len(block_state.prompt_embeds) != len(block_state.negative_prompt_embeds):
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same length when passed directly, but"
f" got: `prompt_embeds` {len(block_state.prompt_embeds)} != `negative_prompt_embeds`"
f" {len(block_state.negative_prompt_embeds)}."
)
@torch.no_grad()
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
block_state.batch_size = len(block_state.prompt_embeds)
block_state.dtype = block_state.prompt_embeds[0].dtype
if block_state.num_images_per_prompt > 1:
prompt_embeds = [pe for pe in block_state.prompt_embeds for _ in range(block_state.num_images_per_prompt)]
block_state.prompt_embeds = prompt_embeds
if block_state.negative_prompt_embeds is not None:
negative_prompt_embeds = [
npe for npe in block_state.negative_prompt_embeds for _ in range(block_state.num_images_per_prompt)
]
block_state.negative_prompt_embeds = negative_prompt_embeds
self.set_block_state(state, block_state)
return components, state
class ZImageAdditionalInputsStep(ModularPipelineBlocks):
model_name = "z-image"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
This step handles multiple common tasks to prepare inputs for the denoising step:
1. For encoded image latents, use it update height/width if None, and expands batch size
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
This is a dynamic block that allows you to configure which inputs to process.
Args:
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
a single string or list of strings. Defaults to ["image_latents"].
additional_batch_inputs (List[str], optional):
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to [].
Examples:
# Configure to process image_latents (default behavior) ZImageAdditionalInputsStep()
# Configure to process multiple image latent inputs
ZImageAdditionalInputsStep(image_latent_inputs=["image_latents", "control_image_latents"])
# Configure to process image latents and additional batch inputs ZImageAdditionalInputsStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["image_embeds"]
)
"""
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
super().__init__()
@property
def description(self) -> str:
# Functionality section
summary_section = (
"Input processing step that:\n"
" 1. For image latent inputs: Updates height/width if None, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
# Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
# Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
]
# Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
# Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
return inputs
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate num_frames, height/width from latents
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_spatial)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, input_name, input_tensor)
self.set_block_state(state, block_state)
return components, state
class ZImagePrepareLatentsStep(ModularPipelineBlocks):
model_name = "z-image"
@property
def description(self) -> str:
return "Prepare latents step that prepares the latents for the text-to-video generation process"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
]
def check_inputs(self, components, block_state):
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
)
@staticmethod
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.prepare_latents with self->comp
def prepare_latents(
comp,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
return latents
@torch.no_grad()
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
device = components._execution_device
dtype = torch.float32
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.latents = self.prepare_latents(
components,
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
num_channels_latents=components.num_channels_latents,
height=block_state.height,
width=block_state.width,
dtype=dtype,
device=device,
generator=block_state.generator,
latents=block_state.latents,
)
self.set_block_state(state, block_state)
return components, state
class ZImageSetTimestepsStep(ModularPipelineBlocks):
model_name = "z-image"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for inference. Need to run after prepare latents step."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True),
InputParam("num_inference_steps", default=9),
InputParam("sigmas"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
),
]
@torch.no_grad()
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
latent_height, latent_width = block_state.latents.shape[2], block_state.latents.shape[3]
image_seq_len = (latent_height // 2) * (latent_width // 2) # sequence length after patchify
mu = calculate_shift(
image_seq_len,
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
base_shift=components.scheduler.config.get("base_shift", 0.5),
max_shift=components.scheduler.config.get("max_shift", 1.15),
)
components.scheduler.sigma_min = 0.0
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
device,
sigmas=block_state.sigmas,
mu=mu,
)
self.set_block_state(state, block_state)
return components, state
class ZImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
model_name = "z-image"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for inference with strength. Need to run after set timesteps step."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("timesteps", required=True),
InputParam("num_inference_steps", required=True),
InputParam("strength", default=0.6),
]
def check_inputs(self, components, block_state):
if block_state.strength < 0.0 or block_state.strength > 1.0:
raise ValueError(f"Strength must be between 0.0 and 1.0, but got {block_state.strength}")
@torch.no_grad()
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
init_timestep = min(block_state.num_inference_steps * block_state.strength, block_state.num_inference_steps)
t_start = int(max(block_state.num_inference_steps - init_timestep, 0))
timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :]
if hasattr(components.scheduler, "set_begin_index"):
components.scheduler.set_begin_index(t_start * components.scheduler.order)
block_state.timesteps = timesteps
block_state.num_inference_steps = block_state.num_inference_steps - t_start
self.set_block_state(state, block_state)
return components, state
class ZImagePrepareLatentswithImageStep(ModularPipelineBlocks):
model_name = "z-image"
@property
def description(self) -> str:
return "step that prepares the latents with image condition, need to run after set timesteps and prepare latents step."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True),
InputParam("image_latents", required=True),
InputParam("timesteps", required=True),
]
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
block_state.latents = components.scheduler.scale_noise(
block_state.image_latents, latent_timestep, block_state.latents
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -0,0 +1,91 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ZImageVaeDecoderStep(ModularPipelineBlocks):
model_name = "z-image"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8 * 2}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam(
"latents",
required=True,
),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae_dtype = components.vae.dtype
latents = block_state.latents.to(vae_dtype)
latents = latents / components.vae.config.scaling_factor + components.vae.config.shift_factor
block_state.images = components.vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -0,0 +1,310 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import ZImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam
from .modular_pipeline import ZImageModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ZImageLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "z-image"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `ZImageDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs. Can be generated in input step.",
),
]
@torch.no_grad()
def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
latents = block_state.latents.unsqueeze(2).to(
block_state.dtype
) # [batch_size, num_channels, 1, height, width]
block_state.latent_model_input = list(latents.unbind(dim=0)) # list of [num_channels, 1, height, width]
timestep = t.expand(latents.shape[0]).to(block_state.dtype)
timestep = (1000 - timestep) / 1000
block_state.timestep = timestep
return components, block_state
class ZImageLoopDenoiser(ModularPipelineBlocks):
model_name = "z-image"
def __init__(
self,
guider_input_fields: Dict[str, Any] = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")},
):
"""Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image.
Args:
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
(for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either:
- A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds",
"negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
'encoder_hidden_states'.
- A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward
`block_state.image_embeds` for both conditional and unconditional batches.
"""
if not isinstance(guider_input_fields, dict):
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
self._guider_input_fields = guider_input_fields
super().__init__()
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
default_creation_method="from_config",
),
ComponentSpec("transformer", ZImageTransformer2DModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `ZImageDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
inputs = [
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
guider_input_names = []
uncond_guider_input_names = []
for value in self._guider_input_fields.values():
if isinstance(value, tuple):
guider_input_names.append(value[0])
uncond_guider_input_names.append(value[1])
else:
guider_input_names.append(value)
for name in guider_input_names:
inputs.append(InputParam(name=name, required=True))
for name in uncond_guider_input_names:
inputs.append(InputParam(name=name))
return inputs
@torch.no_grad()
def __call__(
self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
def _convert_dtype(v, dtype):
if isinstance(v, torch.Tensor):
return v.to(dtype)
elif isinstance(v, list):
return [_convert_dtype(t, dtype) for t in v]
return v
cond_kwargs = {
k: _convert_dtype(v, block_state.dtype)
for k, v in cond_kwargs.items()
if k in self._guider_input_fields.keys()
}
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
model_out_list = components.transformer(
x=block_state.latent_model_input,
t=block_state.timestep,
return_dict=False,
**cond_kwargs,
)[0]
noise_pred = torch.stack(model_out_list, dim=0).squeeze(2)
guider_state_batch.noise_pred = -noise_pred
components.guider.cleanup_models(components.transformer)
# Perform guidance
block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
class ZImageLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "z-image"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def description(self) -> str:
return (
"step within the denoising loop that update the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `ZImageDenoiseLoopWrapper`)"
)
@torch.no_grad()
def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
# Perform scheduler step using the predicted output
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred.float(),
t,
block_state.latents.float(),
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "z-image"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoise the latents over `timesteps`. "
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@torch.no_grad()
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
self.set_block_state(state, block_state)
return components, state
class ZImageDenoiseStep(ZImageDenoiseLoopWrapper):
block_classes = [
ZImageLoopBeforeDenoiser,
ZImageLoopDenoiser(
guider_input_fields={
"cap_feats": ("prompt_embeds", "negative_prompt_embeds"),
}
),
ZImageLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `ZImageLoopBeforeDenoiser`\n"
" - `ZImageLoopDenoiser`\n"
" - `ZImageLoopAfterDenoiser`\n"
"This block supports text-to-image and image-to-image tasks for Z-Image."
)

View File

@@ -0,0 +1,344 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL
import torch
from transformers import Qwen2Tokenizer, Qwen3Model
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...utils import is_ftfy_available, logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import ZImageModularPipeline
if is_ftfy_available():
pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_qwen_prompt_embeds(
text_encoder: Qwen3Model,
tokenizer: Qwen2Tokenizer,
prompt: Union[str, List[str]],
device: torch.device,
max_sequence_length: int = 512,
) -> List[torch.Tensor]:
prompt = [prompt] if isinstance(prompt, str) else prompt
for i, prompt_item in enumerate(prompt):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
prompt[i] = prompt_item
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
prompt_embeds_list = []
for i in range(len(prompt_embeds)):
prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]])
return prompt_embeds_list
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def encode_vae_image(
image_tensor: torch.Tensor,
vae: AutoencoderKL,
generator: torch.Generator,
device: torch.device,
dtype: torch.dtype,
latent_channels: int = 16,
):
if not isinstance(image_tensor, torch.Tensor):
raise ValueError(f"Expected image_tensor to be a tensor, got {type(image_tensor)}.")
if isinstance(generator, list) and len(generator) != image_tensor.shape[0]:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image_tensor.shape[0]}."
)
image_tensor = image_tensor.to(device=device, dtype=dtype)
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image_tensor[i : i + 1]), generator=generator[i])
for i in range(image_tensor.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image_tensor), generator=generator)
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
return image_latents
class ZImageTextEncoderStep(ModularPipelineBlocks):
model_name = "z-image"
@property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3Model),
ComponentSpec("tokenizer", Qwen2Tokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("negative_prompt"),
InputParam("max_sequence_length", default=512),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
type_hint=List[torch.Tensor],
kwargs_type="denoiser_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=List[torch.Tensor],
kwargs_type="denoiser_input_fields",
description="negative text embeddings used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
@staticmethod
def encode_prompt(
components,
prompt: str,
device: Optional[torch.device] = None,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
max_sequence_length: int = 512,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
prepare_unconditional_embeds (`bool`):
whether to use prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
max_sequence_length (`int`, defaults to `512`):
The maximum number of text tokens to be used for the generation process.
"""
device = device or components._execution_device
if not isinstance(prompt, list):
prompt = [prompt]
batch_size = len(prompt)
prompt_embeds = get_qwen_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_prompt_embeds = None
if prepare_unconditional_embeds:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = get_qwen_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=negative_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
return prompt_embeds, negative_prompt_embeds
@torch.no_grad()
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
# Encode input prompt
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
) = self.encode_prompt(
components=components,
prompt=block_state.prompt,
device=block_state.device,
prepare_unconditional_embeds=components.requires_unconditional_embeds,
negative_prompt=block_state.negative_prompt,
max_sequence_length=block_state.max_sequence_length,
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
class ZImageVaeImageEncoderStep(ModularPipelineBlocks):
model_name = "z-image"
@property
def description(self) -> str:
return "Vae Image Encoder step that generate condition_latents based on image to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8 * 2}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="video latent representation with the first frame image condition",
),
]
@staticmethod
def check_inputs(components, block_state):
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
)
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
image = block_state.image
device = components._execution_device
dtype = torch.float32
vae_dtype = components.vae.dtype
image_tensor = components.image_processor.preprocess(
image, height=block_state.height, width=block_state.width
).to(device=device, dtype=dtype)
block_state.image_latents = encode_vae_image(
image_tensor=image_tensor,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=vae_dtype,
latent_channels=components.num_channels_latents,
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -0,0 +1,191 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
ZImageAdditionalInputsStep,
ZImagePrepareLatentsStep,
ZImagePrepareLatentswithImageStep,
ZImageSetTimestepsStep,
ZImageSetTimestepsWithStrengthStep,
ZImageTextInputStep,
)
from .decoders import ZImageVaeDecoderStep
from .denoise import (
ZImageDenoiseStep,
)
from .encoders import (
ZImageTextEncoderStep,
ZImageVaeImageEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# z-image
# text2image
class ZImageCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
ZImageTextInputStep,
ZImagePrepareLatentsStep,
ZImageSetTimestepsStep,
ZImageDenoiseStep,
]
block_names = ["input", "prepare_latents", "set_timesteps", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
+ " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
+ " - `ZImageDenoiseStep` is used to denoise the latents\n"
)
# z-image: image2image
## denoise
class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
ZImageTextInputStep,
ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
ZImagePrepareLatentsStep,
ZImageSetTimestepsStep,
ZImageSetTimestepsWithStrengthStep,
ZImagePrepareLatentswithImageStep,
ZImageDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"prepare_latents",
"set_timesteps",
"set_timesteps_with_strength",
"prepare_latents_with_image",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `ZImageAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
+ " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
+ " - `ZImageSetTimestepsWithStrengthStep` is used to set the timesteps with strength\n"
+ " - `ZImagePrepareLatentswithImageStep` is used to prepare the latents with image\n"
+ " - `ZImageDenoiseStep` is used to denoise the latents\n"
)
## auto blocks
class ZImageAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
ZImageImage2ImageCoreDenoiseStep,
ZImageCoreDenoiseStep,
]
block_names = ["image2image", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image and image2image tasks."
" - `ZImageCoreDenoiseStep` (text2image) for text2image tasks."
" - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks."
+ " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n"
+ " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n"
)
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
block_classes = [ZImageVaeImageEncoderStep]
block_names = ["vae_image_encoder"]
block_trigger_inputs = ["image"]
@property
def description(self) -> str:
return "Vae Image Encoder step that encode the image to generate the image latents"
+"This is an auto pipeline block that works for image2image tasks."
+" - `ZImageVaeImageEncoderStep` is used when `image` is provided."
+" - if `image` is not provided, step will be skipped."
class ZImageAutoBlocks(SequentialPipelineBlocks):
block_classes = [
ZImageTextEncoderStep,
ZImageAutoVaeImageEncoderStep,
ZImageAutoDenoiseStep,
ZImageVaeDecoderStep,
]
block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]
@property
def description(self) -> str:
return "Auto Modular pipeline for text-to-image and image-to-image using ZImage.\n"
+" - for text-to-image generation, all you need to provide is `prompt`\n"
+" - for image-to-image generation, you need to provide `image`\n"
+" - if `image` is not provided, step will be skipped."
# presets
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", ZImageTextEncoderStep),
("input", ZImageTextInputStep),
("prepare_latents", ZImagePrepareLatentsStep),
("set_timesteps", ZImageSetTimestepsStep),
("denoise", ZImageDenoiseStep),
("decode", ZImageVaeDecoderStep),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", ZImageTextEncoderStep),
("vae_image_encoder", ZImageVaeImageEncoderStep),
("input", ZImageTextInputStep),
("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
("prepare_latents", ZImagePrepareLatentsStep),
("set_timesteps", ZImageSetTimestepsStep),
("set_timesteps_with_strength", ZImageSetTimestepsWithStrengthStep),
("prepare_latents_with_image", ZImagePrepareLatentswithImageStep),
("denoise", ZImageDenoiseStep),
("decode", ZImageVaeDecoderStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", ZImageTextEncoderStep),
("vae_image_encoder", ZImageAutoVaeImageEncoderStep),
("denoise", ZImageAutoDenoiseStep),
("decode", ZImageVaeDecoderStep),
]
)
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"image2image": IMAGE2IMAGE_BLOCKS,
"auto": AUTO_BLOCKS,
}

View File

@@ -0,0 +1,72 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...loaders import ZImageLoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ZImageModularPipeline(
ModularPipeline,
ZImageLoraLoaderMixin,
):
"""
A ModularPipeline for Z-Image.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "ZImageAutoBlocks"
@property
def default_height(self):
return 1024
@property
def default_width(self):
return 1024
@property
def vae_scale_factor_spatial(self):
vae_scale_factor_spatial = 16
if hasattr(self, "image_processor") and self.image_processor is not None:
vae_scale_factor_spatial = self.image_processor.config.vae_scale_factor
return vae_scale_factor_spatial
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 16
if hasattr(self, "transformer") and self.transformer is not None:
num_channels_latents = self.transformer.config.in_channels
return num_channels_latents
@property
def requires_unconditional_embeds(self):
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds

View File

@@ -227,6 +227,36 @@ class WanModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ZImageAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ZImageModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AllegroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]