mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Modular]z-image (#12808)
* initiL
* up up
* fix: z_image -> z-image
* style
* copy
* fix more
* some docstring fix
@@ -419,6 +419,8 @@ else:
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanModularPipeline",
+ "ZImageAutoBlocks",
+ "ZImageModularPipeline",
]
)
_import_structure["pipelines"].extend(
@@ -1124,6 +1126,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Wan22AutoBlocks,
WanAutoBlocks,
WanModularPipeline,
+ ZImageAutoBlocks,
+ ZImageModularPipeline,
)
from .pipelines import (
AllegroPipeline,

@@ -60,6 +60,10 @@ else:
"QwenImageEditPlusModularPipeline",
"QwenImageEditPlusAutoBlocks",
]
+ _import_structure["z_image"] = [
+ "ZImageAutoBlocks",
+ "ZImageModularPipeline",
+ ]
_import_structure["components_manager"] = ["ComponentsManager"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -91,6 +95,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
+ from .z_image import ZImageAutoBlocks, ZImageModularPipeline
else:
import sys

@@ -61,6 +61,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
+ ("z-image", "ZImageModularPipeline"),
]
)

@@ -530,6 +530,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):

device = components._execution_device
dtype = torch.float32
+ vae_dtype = components.vae.dtype

height = block_state.height or components.default_height
width = block_state.width or components.default_width
@@ -555,7 +556,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
vae=components.vae,
generator=block_state.generator,
device=device,
- dtype=dtype,
+ dtype=vae_dtype,
latent_channels=components.num_channels_latents,
)

@@ -627,6 +628,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):

device = components._execution_device
dtype = torch.float32
+ vae_dtype = components.vae.dtype

height = block_state.height or components.default_height
width = block_state.width or components.default_width
@@ -659,7 +661,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
vae=components.vae,
generator=block_state.generator,
device=device,
- dtype=dtype,
+ dtype=vae_dtype,
latent_channels=components.num_channels_latents,
)

src/diffusers/modular_pipelines/z_image/__init__.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from typing import TYPE_CHECKING

from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects  # noqa F403

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["decoders"] = ["ZImageVaeDecoderStep"]
_import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"ZImageAutoBlocks",
]
_import_structure["modular_pipeline"] = ["ZImageModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
from .decoders import ZImageVaeDecoderStep
from .encoders import ZImageTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
ZImageAutoBlocks,
)
from .modular_pipeline import ZImageModularPipeline
else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)

for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
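# Illustrative usage sketch (assumption: this PR is installed together with torch and transformers;
# the public re-exports below simply mirror the _import_structure mapping defined above, and the
# classes only materialize on first attribute access because of the lazy-module setup).
from diffusers.modular_pipelines.z_image import ZImageAutoBlocks, ZImageModularPipeline

blocks = ZImageAutoBlocks()  # preset block graph covering the Z-Image text-to-image / image-to-image flow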
src/diffusers/modular_pipelines/z_image/before_denoise.py (new file, 621 lines)
@@ -0,0 +1,621 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import ZImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import ZImageModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
|
||||
# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
|
||||
# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
|
||||
# configuration of guider is.
|
||||
|
||||
|
||||
def repeat_tensor_to_batch_size(
|
||||
input_name: str,
|
||||
input_tensor: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_images_per_prompt: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Repeat tensor elements to match the final batch size.
|
||||
|
||||
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
|
||||
by repeating each element along dimension 0.
|
||||
|
||||
The input tensor must have batch size 1 or batch_size. The function will:
|
||||
- If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
|
||||
- If batch size equals batch_size: repeat each element num_images_per_prompt times
|
||||
|
||||
Args:
|
||||
input_name (str): Name of the input tensor (used for error messages)
|
||||
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
|
||||
batch_size (int): The base batch size (number of prompts)
|
||||
num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
|
||||
|
||||
Raises:
|
||||
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
|
||||
|
||||
Examples:
    tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
    repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
    repeated  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: [4, 3]

    tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
    repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
    repeated  # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) - shape: [4, 3]
"""
|
||||
# make sure input is a tensor
|
||||
if not isinstance(input_tensor, torch.Tensor):
|
||||
raise ValueError(f"`{input_name}` must be a tensor")
|
||||
|
||||
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
|
||||
if input_tensor.shape[0] == 1:
|
||||
repeat_by = batch_size * num_images_per_prompt
|
||||
elif input_tensor.shape[0] == batch_size:
|
||||
repeat_by = num_images_per_prompt
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
|
||||
)
|
||||
|
||||
# expand the tensor to match the batch_size * num_images_per_prompt
|
||||
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
return input_tensor
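# Illustrative, self-contained sketch of repeat_tensor_to_batch_size (assumes only torch; the shapes are made up).
import torch

def _demo_repeat_to_batch_size():
    prompt_embeds = torch.randn(2, 77, 768)  # embeddings for batch_size=2 prompts
    out = repeat_tensor_to_batch_size(
        input_name="prompt_embeds",
        input_tensor=prompt_embeds,
        batch_size=2,
        num_images_per_prompt=3,
    )
    # repeat_interleave keeps prompt order: [p0, p0, p0, p1, p1, p1]
    assert out.shape == (6, 77, 768)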
|
||||
|
||||
|
||||
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_spatial: int) -> Tuple[int, int]:
|
||||
"""Calculate image dimensions from latent tensor dimensions.
|
||||
|
||||
This function converts latent spatial dimensions to image spatial dimensions by multiplying the latent height/width
|
||||
by the VAE scale factor.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor): The latent tensor. Must have 4 dimensions.
|
||||
Expected shapes: [batch, channels, height, width]
|
||||
vae_scale_factor_spatial (int): The scale factor used by the VAE to compress the image spatial dimensions.
By default, it is 16.
|
||||
Returns:
|
||||
Tuple[int, int]: The calculated image dimensions as (height, width)
|
||||
"""
|
||||
latent_height, latent_width = latents.shape[2:]
|
||||
height = latent_height * vae_scale_factor_spatial // 2
|
||||
width = latent_width * vae_scale_factor_spatial // 2
|
||||
|
||||
return height, width
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
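# Worked example for calculate_shift with the default constants above (illustrative only):
#   m  = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4
#   b  = 0.5 - m * 256 ≈ 0.4567
#   image_seq_len = 1024  ->  mu ≈ 1024 * m + b ≈ 0.63
#   image_seq_len = 4096  ->  mu = 1.15 (the configured max_shift)
# Longer token sequences therefore shift the flow-matching schedule toward larger mu.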
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageTextInputStep(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Input processing step that:\n"
|
||||
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
|
||||
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
|
||||
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
|
||||
"have a final batch_size of batch_size * num_images_per_prompt."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", ZImageTransformer2DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
type_hint=List[torch.Tensor],
|
||||
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
InputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=List[torch.Tensor],
|
||||
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"batch_size",
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
|
||||
),
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `transformer.dtype`)",
|
||||
),
|
||||
]
|
||||
|
||||
def check_inputs(self, components, block_state):
|
||||
if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
|
||||
if not isinstance(block_state.prompt_embeds, list):
|
||||
raise ValueError(
|
||||
f"`prompt_embeds` must be a list when passed directly, but got {type(block_state.prompt_embeds)}."
|
||||
)
|
||||
if not isinstance(block_state.negative_prompt_embeds, list):
|
||||
raise ValueError(
|
||||
f"`negative_prompt_embeds` must be a list when passed directly, but got {type(block_state.negative_prompt_embeds)}."
|
||||
)
|
||||
if len(block_state.prompt_embeds) != len(block_state.negative_prompt_embeds):
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same length when passed directly, but"
|
||||
f" got: `prompt_embeds` {len(block_state.prompt_embeds)} != `negative_prompt_embeds`"
|
||||
f" {len(block_state.negative_prompt_embeds)}."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
block_state.batch_size = len(block_state.prompt_embeds)
|
||||
block_state.dtype = block_state.prompt_embeds[0].dtype
|
||||
|
||||
if block_state.num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in block_state.prompt_embeds for _ in range(block_state.num_images_per_prompt)]
|
||||
block_state.prompt_embeds = prompt_embeds
|
||||
|
||||
if block_state.negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = [
|
||||
npe for npe in block_state.negative_prompt_embeds for _ in range(block_state.num_images_per_prompt)
|
||||
]
|
||||
block_state.negative_prompt_embeds = negative_prompt_embeds
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class ZImageAdditionalInputsStep(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
1. For image latent inputs: uses them to update height/width if they are None
|
||||
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||
|
||||
This is a dynamic block that allows you to configure which inputs to process.
|
||||
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
These latents are also used to determine height/width when those are not set. Can be
|
||||
a single string or list of strings. Defaults to ["image_latents"].
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to [].
|
||||
|
||||
Examples:
|
||||
# Configure to process image_latents (default behavior) ZImageAdditionalInputsStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
ZImageAdditionalInputsStep(image_latent_inputs=["image_latents", "control_image_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs ZImageAdditionalInputsStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["image_embeds"]
|
||||
)
|
||||
"""
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# Functionality section
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
# Inputs info
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
# Add image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
# Add additional batch inputs
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation)
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents
|
||||
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_spatial)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ZImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
)
|
||||
]
|
||||
|
||||
def check_inputs(self, components, block_state):
|
||||
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
|
||||
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.prepare_latents with self->comp
|
||||
def prepare_latents(
|
||||
comp,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
return latents
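# Shape sketch for prepare_latents above (illustrative; assumes comp.vae_scale_factor == 8 and
# comp.num_channels_latents == 16, typical AutoencoderKL values that are not pinned down in this file):
#   height = width = 1024  ->  latent grid = 2 * (1024 // (8 * 2)) = 128
#   latents.shape == (batch_size, 16, 128, 128), where batch_size is already batch_size * num_images_per_prompt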
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
device = components._execution_device
|
||||
dtype = torch.float32
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
|
||||
block_state.latents = self.prepare_latents(
|
||||
components,
|
||||
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
|
||||
num_channels_latents=components.num_channels_latents,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=block_state.generator,
|
||||
latents=block_state.latents,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class ZImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the scheduler's timesteps for inference. Need to run after prepare latents step."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True),
|
||||
InputParam("num_inference_steps", default=9),
|
||||
InputParam("sigmas"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
|
||||
latent_height, latent_width = block_state.latents.shape[2], block_state.latents.shape[3]
|
||||
image_seq_len = (latent_height // 2) * (latent_width // 2) # sequence length after patchify
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
|
||||
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
|
||||
base_shift=components.scheduler.config.get("base_shift", 0.5),
|
||||
max_shift=components.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
components.scheduler.sigma_min = 0.0
|
||||
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
components.scheduler,
|
||||
block_state.num_inference_steps,
|
||||
device,
|
||||
sigmas=block_state.sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ZImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the scheduler's timesteps for inference with strength. Need to run after set timesteps step."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("timesteps", required=True),
|
||||
InputParam("num_inference_steps", required=True),
|
||||
InputParam("strength", default=0.6),
|
||||
]
|
||||
|
||||
def check_inputs(self, components, block_state):
|
||||
if block_state.strength < 0.0 or block_state.strength > 1.0:
|
||||
raise ValueError(f"Strength must be between 0.0 and 1.0, but got {block_state.strength}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
init_timestep = min(block_state.num_inference_steps * block_state.strength, block_state.num_inference_steps)
|
||||
|
||||
t_start = int(max(block_state.num_inference_steps - init_timestep, 0))
|
||||
timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :]
|
||||
if hasattr(components.scheduler, "set_begin_index"):
|
||||
components.scheduler.set_begin_index(t_start * components.scheduler.order)
|
||||
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = block_state.num_inference_steps - t_start
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
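# Worked example for the strength logic above (scheduler.order assumed to be 1, as for
# FlowMatchEulerDiscreteScheduler):
#   num_inference_steps = 9, strength = 0.6
#   init_timestep = min(9 * 0.6, 9) = 5.4
#   t_start       = int(max(9 - 5.4, 0)) = 3
#   -> the loop runs scheduler.timesteps[3:], i.e. 6 of the 9 steps, starting from a noisier point;
#      the lower the strength, the fewer denoising steps are applied to the image latents.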
|
||||
|
||||
|
||||
class ZImagePrepareLatentswithImageStep(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the latents with image condition, need to run after set timesteps and prepare latents step."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True),
|
||||
InputParam("image_latents", required=True),
|
||||
InputParam("timesteps", required=True),
|
||||
]
|
||||
|
||||
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
|
||||
block_state.latents = components.scheduler.scale_noise(
|
||||
block_state.image_latents, latent_timestep, block_state.latents
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
src/diffusers/modular_pipelines/z_image/decoders.py (new file, 91 lines)
@@ -0,0 +1,91 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ZImageVaeDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8 * 2}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the denoised latents into images"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
),
|
||||
InputParam(
|
||||
name="output_type",
|
||||
default="pil",
|
||||
type_hint=str,
|
||||
description="The type of the output images, can be 'pil', 'np', 'pt'",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
|
||||
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
vae_dtype = components.vae.dtype
|
||||
|
||||
latents = block_state.latents.to(vae_dtype)
|
||||
latents = latents / components.vae.config.scaling_factor + components.vae.config.shift_factor
|
||||
|
||||
block_state.images = components.vae.decode(latents, return_dict=False)[0]
|
||||
block_state.images = components.image_processor.postprocess(
|
||||
block_state.images, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
src/diffusers/modular_pipelines/z_image/denoise.py (new file, 310 lines)
@@ -0,0 +1,310 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import ZImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam
|
||||
from .modular_pipeline import ZImageModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ZImageLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `ZImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
latents = block_state.latents.unsqueeze(2).to(
|
||||
block_state.dtype
|
||||
) # [batch_size, num_channels, 1, height, width]
|
||||
block_state.latent_model_input = list(latents.unbind(dim=0)) # list of [num_channels, 1, height, width]
|
||||
|
||||
timestep = t.expand(latents.shape[0]).to(block_state.dtype)
|
||||
timestep = (1000 - timestep) / 1000
|
||||
block_state.timestep = timestep
|
||||
return components, block_state
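# Illustration of the timestep remapping above: the scheduler emits t descending from 1000 toward 0
# (noisiest first), while the Z-Image transformer expects a normalized time that grows toward the clean sample:
#   t = 1000 -> (1000 - 1000) / 1000 = 0.00
#   t =  250 -> (1000 -  250) / 1000 = 0.75
#   t =    0 -> (1000 -    0) / 1000 = 1.00
# The latents are also unbound into a list of per-sample tensors, matching the list-of-tensors
# input interface used when calling the transformer below.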
|
||||
|
||||
|
||||
class ZImageLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guider_input_fields: Dict[str, Any] = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")},
|
||||
):
|
||||
"""Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image.
|
||||
|
||||
Args:
|
||||
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
|
||||
(for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either:
|
||||
|
||||
- A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds",
|
||||
"negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and
|
||||
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
|
||||
'encoder_hidden_states'.
|
||||
- A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward
|
||||
`block_state.image_embeds` for both conditional and unconditional batches.
|
||||
"""
|
||||
if not isinstance(guider_input_fields, dict):
|
||||
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
|
||||
self._guider_input_fields = guider_input_fields
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("transformer", ZImageTransformer2DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoise the latents with guidance. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `ZImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
inputs = [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
guider_input_names = []
|
||||
uncond_guider_input_names = []
|
||||
for value in self._guider_input_fields.values():
|
||||
if isinstance(value, tuple):
|
||||
guider_input_names.append(value[0])
|
||||
uncond_guider_input_names.append(value[1])
|
||||
else:
|
||||
guider_input_names.append(value)
|
||||
|
||||
for name in guider_input_names:
|
||||
inputs.append(InputParam(name=name, required=True))
|
||||
for name in uncond_guider_input_names:
|
||||
inputs.append(InputParam(name=name))
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||
# you will get a guider_state with two batches:
|
||||
# guider_state = [
|
||||
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||
# ]
|
||||
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||
guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
|
||||
|
||||
# run the denoiser for each guidance batch
|
||||
for guider_state_batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
|
||||
def _convert_dtype(v, dtype):
|
||||
if isinstance(v, torch.Tensor):
|
||||
return v.to(dtype)
|
||||
elif isinstance(v, list):
|
||||
return [_convert_dtype(t, dtype) for t in v]
|
||||
return v
|
||||
|
||||
cond_kwargs = {
|
||||
k: _convert_dtype(v, block_state.dtype)
|
||||
for k, v in cond_kwargs.items()
|
||||
if k in self._guider_input_fields.keys()
|
||||
}
|
||||
|
||||
# Predict the noise residual
|
||||
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
|
||||
model_out_list = components.transformer(
|
||||
x=block_state.latent_model_input,
|
||||
t=block_state.timestep,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)[0]
|
||||
noise_pred = torch.stack(model_out_list, dim=0).squeeze(2)
|
||||
guider_state_batch.noise_pred = -noise_pred
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
# Perform guidance
|
||||
block_state.noise_pred = components.guider(guider_state)[0]
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class ZImageLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that update the latents. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `ZImageDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
# Perform scheduler step using the predicted output
|
||||
latents_dtype = block_state.latents.dtype
|
||||
block_state.latents = components.scheduler.step(
|
||||
block_state.noise_pred.float(),
|
||||
t,
|
||||
block_state.latents.float(),
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if block_state.latents.dtype != latents_dtype:
|
||||
block_state.latents = block_state.latents.to(latents_dtype)
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Pipeline block that iteratively denoise the latents over `timesteps`. "
|
||||
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
|
||||
)
|
||||
|
||||
@property
|
||||
def loop_expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.num_warmup_steps = max(
|
||||
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
|
||||
)
|
||||
|
||||
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(block_state.timesteps):
|
||||
components, block_state = self.loop_step(components, block_state, i=i, t=t)
|
||||
if i == len(block_state.timesteps) - 1 or (
|
||||
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class ZImageDenoiseStep(ZImageDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
ZImageLoopBeforeDenoiser,
|
||||
ZImageLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"cap_feats": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
}
|
||||
),
|
||||
ZImageLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `ZImageLoopBeforeDenoiser`\n"
|
||||
" - `ZImageLoopDenoiser`\n"
|
||||
" - `ZImageLoopAfterDenoiser`\n"
|
||||
"This block supports text-to-image and image-to-image tasks for Z-Image."
|
||||
)
src/diffusers/modular_pipelines/z_image/encoders.py (new file, 344 lines)
@@ -0,0 +1,344 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import Qwen2Tokenizer, Qwen3Model
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import ZImageModularPipeline
|
||||
|
||||
|
||||
if is_ftfy_available():
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_qwen_prompt_embeds(
|
||||
text_encoder: Qwen3Model,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
prompt: Union[str, List[str]],
|
||||
device: torch.device,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.Tensor]:
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
prompt_embeds_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return prompt_embeds_list
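# Illustrative usage sketch for get_qwen_prompt_embeds; the checkpoint path is a placeholder, not taken from this PR.
import torch
from transformers import Qwen2Tokenizer, Qwen3Model

tokenizer = Qwen2Tokenizer.from_pretrained("<text-encoder-repo>")   # placeholder path
text_encoder = Qwen3Model.from_pretrained("<text-encoder-repo>")    # placeholder path

prompt_embeds_list = get_qwen_prompt_embeds(
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    prompt=["a watercolor fox in the snow"],
    device=torch.device("cpu"),
    max_sequence_length=512,
)
# One variable-length tensor per prompt: padding positions are dropped via the attention mask above.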
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def encode_vae_image(
|
||||
image_tensor: torch.Tensor,
|
||||
vae: AutoencoderKL,
|
||||
generator: torch.Generator,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
latent_channels: int = 16,
|
||||
):
|
||||
if not isinstance(image_tensor, torch.Tensor):
|
||||
raise ValueError(f"Expected image_tensor to be a tensor, got {type(image_tensor)}.")
|
||||
|
||||
if isinstance(generator, list) and len(generator) != image_tensor.shape[0]:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image_tensor.shape[0]}."
|
||||
)
|
||||
|
||||
image_tensor = image_tensor.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image_tensor[i : i + 1]), generator=generator[i])
|
||||
for i in range(image_tensor.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image_tensor), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
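# Semantics sketch for encode_vae_image (illustrative numbers; assumes a standard AutoencoderKL with 8x
# spatial downsampling and scaling_factor / shift_factor entries in its config, as used elsewhere in this file):
#   image_tensor: (B, 3, 1024, 1024) in [-1, 1]
#   retrieve_latents(vae.encode(...)): (B, latent_channels, 128, 128)
#   returned latents = (sample - vae.config.shift_factor) * vae.config.scaling_factor
# ZImageVaeDecoderStep applies the inverse (latents / scaling_factor + shift_factor) before vae.decode.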
|
||||
|
||||
|
||||
class ZImageTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "z-image"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generate text_embeddings to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Qwen3Model),
|
||||
ComponentSpec("tokenizer", Qwen2Tokenizer),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("max_sequence_length", default=512),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=List[torch.Tensor],
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="text embeddings used to guide the image generation",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=List[torch.Tensor],
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="negative text embeddings used to guide the image generation",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
if block_state.prompt is not None and (
|
||||
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
|
||||
):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
|
||||
@staticmethod
|
||||
def encode_prompt(
|
||||
components,
|
||||
prompt: str,
|
||||
device: Optional[torch.device] = None,
|
||||
prepare_unconditional_embeds: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
prepare_unconditional_embeds (`bool`):
|
||||
whether to prepare unconditional embeddings or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of text tokens to be used for the generation process.
|
||||
"""
|
||||
device = device or components._execution_device
|
||||
if not isinstance(prompt, list):
|
||||
prompt = [prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt_embeds = get_qwen_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = None
|
||||
if prepare_unconditional_embeds:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = get_qwen_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=negative_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        # Encode input prompt
        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
        ) = self.encode_prompt(
            components=components,
            prompt=block_state.prompt,
            device=block_state.device,
            prepare_unconditional_embeds=components.requires_unconditional_embeds,
            negative_prompt=block_state.negative_prompt,
            max_sequence_length=block_state.max_sequence_length,
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state


class ZImageVaeImageEncoderStep(ModularPipelineBlocks):
    model_name = "z-image"

    @property
    def description(self) -> str:
        return "Vae Image Encoder step that generates `condition_latents` from the input image to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8 * 2}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", type_hint=PIL.Image.Image, required=True),
            InputParam("height"),
            InputParam("width"),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="latent representation of the input image, used to condition the generation",
            ),
        ]

    @staticmethod
    def check_inputs(components, block_state):
        if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
            block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
            )

    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        image = block_state.image

        device = components._execution_device
        dtype = torch.float32
        vae_dtype = components.vae.dtype

        image_tensor = components.image_processor.preprocess(
            image, height=block_state.height, width=block_state.width
        ).to(device=device, dtype=dtype)

        block_state.image_latents = encode_vae_image(
            image_tensor=image_tensor,
            vae=components.vae,
            generator=block_state.generator,
            device=device,
            dtype=vae_dtype,
            latent_channels=components.num_channels_latents,
        )

        self.set_block_state(state, block_state)
        return components, state
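

# Illustrative sketch (an assumption, not the actual helper): `encode_vae_image` used above comes
# from the shared modular helpers and is not shown in this diff. For an `AutoencoderKL`, such a
# helper typically encodes the image, samples the posterior, and applies the config scaling/shift
# factors; the real signature (e.g. `device`, `latent_channels`) and Z-Image's exact convention may differ.
def _sketch_encode_vae_image(image_tensor, vae, generator=None, dtype=None):
    posterior = vae.encode(image_tensor.to(vae.dtype)).latent_dist  # diagonal Gaussian posterior
    latents = posterior.sample(generator=generator)  # one stochastic sample per image
    shift = getattr(vae.config, "shift_factor", None) or 0.0  # some VAEs also define a shift
    latents = (latents - shift) * vae.config.scaling_factor  # map into the diffusion latent space
    return latents if dtype is None else latents.to(dtype)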
191
src/diffusers/modular_pipelines/z_image/modular_blocks.py
Normal file
@@ -0,0 +1,191 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    ZImageAdditionalInputsStep,
    ZImagePrepareLatentsStep,
    ZImagePrepareLatentswithImageStep,
    ZImageSetTimestepsStep,
    ZImageSetTimestepsWithStrengthStep,
    ZImageTextInputStep,
)
from .decoders import ZImageVaeDecoderStep
from .denoise import (
    ZImageDenoiseStep,
)
from .encoders import (
    ZImageTextEncoderStep,
    ZImageVaeImageEncoderStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# z-image
# text2image
class ZImageCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        ZImageTextInputStep,
        ZImagePrepareLatentsStep,
        ZImageSetTimestepsStep,
        ZImageDenoiseStep,
    ]
    block_names = ["input", "prepare_latents", "set_timesteps", "denoise"]

    @property
    def description(self):
        return (
            "Denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline of blocks:\n"
            + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
            + " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
            + " - `ZImageDenoiseStep` is used to denoise the latents\n"
        )


# z-image: image2image
## denoise
class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        ZImageTextInputStep,
        ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
        ZImagePrepareLatentsStep,
        ZImageSetTimestepsStep,
        ZImageSetTimestepsWithStrengthStep,
        ZImagePrepareLatentswithImageStep,
        ZImageDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "prepare_latents",
        "set_timesteps",
        "set_timesteps_with_strength",
        "prepare_latents_with_image",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "Denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline of blocks:\n"
            + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `ZImageAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
            + " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
            + " - `ZImageSetTimestepsWithStrengthStep` is used to set the timesteps with strength\n"
            + " - `ZImagePrepareLatentswithImageStep` is used to prepare the latents with the image\n"
            + " - `ZImageDenoiseStep` is used to denoise the latents\n"
        )


## auto blocks
class ZImageAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        ZImageImage2ImageCoreDenoiseStep,
        ZImageCoreDenoiseStep,
    ]
    block_names = ["image2image", "text2image"]
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. "
            "This is an auto pipeline block that works for text2image and image2image tasks.\n"
            " - `ZImageCoreDenoiseStep` (text2image) for text2image tasks.\n"
            " - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks.\n"
            + " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n"
            + " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n"
        )
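
    # Note on dispatch (hedged): `AutoPipelineBlocks` matches `block_trigger_inputs` against what is
    # present in the pipeline state, so the "image2image" branch is selected whenever `image_latents`
    # is available; the trailing `None` entry makes "text2image" the default branch.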


class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
    block_classes = [ZImageVaeImageEncoderStep]
    block_names = ["vae_image_encoder"]
    block_trigger_inputs = ["image"]

    @property
    def description(self) -> str:
        return (
            "Vae Image Encoder step that encodes the image to generate the image latents.\n"
            + "This is an auto pipeline block that works for image2image tasks.\n"
            + " - `ZImageVaeImageEncoderStep` is used when `image` is provided.\n"
            + " - if `image` is not provided, step will be skipped."
        )


class ZImageAutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        ZImageTextEncoderStep,
        ZImageAutoVaeImageEncoderStep,
        ZImageAutoDenoiseStep,
        ZImageVaeDecoderStep,
    ]
    block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]

    @property
    def description(self) -> str:
        return (
            "Auto Modular pipeline for text-to-image and image-to-image using ZImage.\n"
            + " - for text-to-image generation, all you need to provide is `prompt`\n"
            + " - for image-to-image generation, you need to provide `image`\n"
            + " - if `image` is not provided, the image encoder step will be skipped."
        )


# presets
TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", ZImageTextEncoderStep),
        ("input", ZImageTextInputStep),
        ("prepare_latents", ZImagePrepareLatentsStep),
        ("set_timesteps", ZImageSetTimestepsStep),
        ("denoise", ZImageDenoiseStep),
        ("decode", ZImageVaeDecoderStep),
    ]
)

IMAGE2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", ZImageTextEncoderStep),
        ("vae_image_encoder", ZImageVaeImageEncoderStep),
        ("input", ZImageTextInputStep),
        ("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
        ("prepare_latents", ZImagePrepareLatentsStep),
        ("set_timesteps", ZImageSetTimestepsStep),
        ("set_timesteps_with_strength", ZImageSetTimestepsWithStrengthStep),
        ("prepare_latents_with_image", ZImagePrepareLatentswithImageStep),
        ("denoise", ZImageDenoiseStep),
        ("decode", ZImageVaeDecoderStep),
    ]
)


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", ZImageTextEncoderStep),
        ("vae_image_encoder", ZImageAutoVaeImageEncoderStep),
        ("denoise", ZImageAutoDenoiseStep),
        ("decode", ZImageVaeDecoderStep),
    ]
)

ALL_BLOCKS = {
    "text2image": TEXT2IMAGE_BLOCKS,
    "image2image": IMAGE2IMAGE_BLOCKS,
    "auto": AUTO_BLOCKS,
}
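

# Hedged usage sketch: the presets above are plain `InsertableDict`s, so a task-specific block
# sequence can be assembled from them with `from_blocks_dict` (assumed here to be available on
# `SequentialPipelineBlocks`, as for the other modular pipelines), e.g.:
#
#     t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
#     i2i_blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
#
# Turning either block set into a runnable pipeline then follows the usual `init_pipeline` /
# component-loading workflow.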
72
src/diffusers/modular_pipelines/z_image/modular_pipeline.py
Normal file
@@ -0,0 +1,72 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...loaders import ZImageLoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ZImageModularPipeline(
    ModularPipeline,
    ZImageLoraLoaderMixin,
):
    """
    A ModularPipeline for Z-Image.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "ZImageAutoBlocks"

    @property
    def default_height(self):
        return 1024

    @property
    def default_width(self):
        return 1024

    @property
    def vae_scale_factor_spatial(self):
        vae_scale_factor_spatial = 16
        if hasattr(self, "image_processor") and self.image_processor is not None:
            vae_scale_factor_spatial = self.image_processor.config.vae_scale_factor
        return vae_scale_factor_spatial

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if hasattr(self, "vae") and self.vae is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 16
        if hasattr(self, "transformer") and self.transformer is not None:
            num_channels_latents = self.transformer.config.in_channels
        return num_channels_latents

    @property
    def requires_unconditional_embeds(self):
        requires_unconditional_embeds = False

        if hasattr(self, "guider") and self.guider is not None:
            requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1

        return requires_unconditional_embeds
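
    # Hedged reading of the fallbacks above: the VAE downsamples by `vae_scale_factor` (8), so the
    # default 1024x1024 request corresponds to latents of shape (batch, num_channels_latents, 128, 128)
    # with num_channels_latents = 16. `vae_scale_factor_spatial` (8 * 2 = 16) additionally folds in
    # what is presumably the transformer patch size, which is why `height` and `width` must be
    # divisible by 16 rather than 8 (see `ZImageVaeImageEncoderStep.check_inputs`).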
@@ -227,6 +227,36 @@ class WanModularPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class ZImageAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class ZImageModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class AllegroPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]