1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add Support for Z-Image Series (#12703)

* Add Support for Z-Image.

* Reformatting with make style, black & isort.

* Remove init, Modify import utils, Merge forward in transformers block, Remove once func in pipeline.

* modified main model forward, freqs_cis left

* refactored to add B dim

* fixed stack issue

* fixed modulation bug

* fixed modulation bug

* fix bug

* remove value_from_time_aware_config

* styling

* Fix neg embed and devide / bug; Reuse pad zero tensor; Turn cat -> repeat; Add hint for attn processor.

* Replace padding with pad_sequence; Add gradient checkpointing.

* Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that.

* Fix Docstring and Make Style.

* Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that."

This reverts commit fbf26b7ed1.

* update z-image docstring

* Revert attention dispatcher

* update z-image docstring

* styling

* Recover attention_dispatch.py with its origin impl, later would special commit for fa3 compatibility.

* Fix prev bug, and support for prompt_embeds pass in args after prompt pre-encode as List of torch Tensor.

* Remove einop dependency.

* remove redundant imports & make fix-copies

* fix import

---------

Co-authored-by: liudongyang <liudongyang0114@gmail.com>
This commit is contained in:
Jerry Wu
2025-11-25 23:50:00 +08:00
committed by GitHub
parent d33d9f6715
commit 4088e8a851
11 changed files with 1382 additions and 0 deletions

View File

@@ -271,6 +271,7 @@ else:
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"ZImageTransformer2DModel",
"attention_backend",
]
)
@@ -647,6 +648,7 @@ else:
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"ZImagePipeline",
]
)
@@ -1329,6 +1331,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
ZImagePipeline,
)
try:

View File

@@ -111,6 +111,7 @@ def _register_attention_processors_metadata():
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
# AttnProcessor2_0
AttentionProcessorRegistry.register(
@@ -158,6 +159,14 @@ def _register_attention_processors_metadata():
),
)
# ZSingleStreamAttnProcessor
AttentionProcessorRegistry.register(
model_class=ZSingleStreamAttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -179,6 +188,7 @@ def _register_transformer_blocks_metadata():
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
from ..models.transformers.transformer_z_image import ZImageTransformerBlock
# BasicTransformerBlock
TransformerBlockRegistry.register(
@@ -312,6 +322,15 @@ def _register_transformer_blocks_metadata():
),
)
# ZImage
TransformerBlockRegistry.register(
model_class=ZImageTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
@@ -338,4 +357,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on

View File

@@ -110,6 +110,7 @@ if is_torch_available():
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -218,6 +219,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
ZImageTransformer2DModel,
)
from .unets import (
I2VGenXLUNet,

View File

@@ -44,3 +44,4 @@ if is_torch_available():
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
from .transformer_z_image import ZImageTransformer2DModel

View File

@@ -0,0 +1,647 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention_processor import Attention
from ...models.modeling_utils import ModelMixin
from ...models.normalization import RMSNorm
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn
ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
class TimestepEmbedder(nn.Module):
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
super().__init__()
if mid_size is None:
mid_size = out_size
self.mlp = nn.Sequential(
nn.Linear(
frequency_embedding_size,
mid_size,
bias=True,
),
nn.SiLU(),
nn.Linear(
mid_size,
out_size,
bias=True,
),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
with torch.amp.autocast("cuda", enabled=False):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb
class ZSingleStreamAttnProcessor:
"""
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
original Z-ImageAttention module.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
# Apply Norms
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype
dtype = query.dtype
query, key = query.to(dtype), key.to(dtype)
# Compute joint attention
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
# Reshape back
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(dtype)
output = attn.to_out[0](hidden_states)
if len(attn.to_out) > 1: # dropout
output = attn.to_out[1](output)
return output
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
qk_norm: bool,
modulation=True,
):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
# Refactored to use diffusers Attention with custom processor
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
self.attention = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // n_heads,
heads=n_heads,
qk_norm="rms_norm" if qk_norm else None,
eps=1e-5,
bias=False,
out_bias=False,
processor=ZSingleStreamAttnProcessor(),
)
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
self.layer_id = layer_id
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
):
if self.modulation:
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
# Attention block
attn_out = self.attention(
self.attention_norm1(x) * scale_msa,
attention_mask=attn_mask,
freqs_cis=freqs_cis,
)
x = x + gate_msa * self.attention_norm2(attn_out)
# FFN block
x = x + gate_mlp * self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x) * scale_mlp,
)
)
else:
# Attention block
attn_out = self.attention(
self.attention_norm1(x),
attention_mask=attn_mask,
freqs_cis=freqs_cis,
)
x = x + self.attention_norm2(attn_out)
# FFN block
x = x + self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x),
)
)
return x
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
)
def forward(self, x, c):
scale = 1.0 + self.adaLN_modulation(c)
x = self.norm_final(x) * scale.unsqueeze(1)
x = self.linear(x)
return x
class RopeEmbedder:
def __init__(
self,
theta: float = 256.0,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (64, 128, 128),
):
self.theta = theta
self.axes_dims = axes_dims
self.axes_lens = axes_lens
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
self.freqs_cis = None
@staticmethod
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
with torch.device("cpu"):
freqs_cis = []
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
freqs_cis.append(freqs_cis_i)
return freqs_cis
def __call__(self, ids: torch.Tensor):
assert ids.ndim == 2
assert ids.shape[-1] == len(self.axes_dims)
device = ids.device
if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["ZImageTransformerBlock"]
@register_to_config
def __init__(
self,
all_patch_size=(2,),
all_f_patch_size=(1,),
in_channels=16,
dim=3840,
n_layers=30,
n_refiner_layers=2,
n_heads=30,
n_kv_heads=30,
norm_eps=1e-5,
qk_norm=True,
cap_feat_dim=2560,
rope_theta=256.0,
t_scale=1000.0,
axes_dims=[32, 48, 48],
axes_lens=[1024, 512, 512],
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.all_patch_size = all_patch_size
self.all_f_patch_size = all_f_patch_size
self.dim = dim
self.n_heads = n_heads
self.rope_theta = rope_theta
self.t_scale = t_scale
self.gradient_checkpointing = False
assert len(all_patch_size) == len(all_f_patch_size)
all_x_embedder = {}
all_final_layer = {}
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
self.all_x_embedder = nn.ModuleDict(all_x_embedder)
self.all_final_layer = nn.ModuleDict(all_final_layer)
self.noise_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
1000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=True,
)
for layer_id in range(n_refiner_layers)
]
)
self.context_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(cap_feat_dim, dim, bias=True),
)
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
self.layers = nn.ModuleList(
[
ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
for layer_id in range(n_layers)
]
)
head_dim = dim // n_heads
assert head_dim == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
pH = pW = patch_size
pF = f_patch_size
bsz = len(x)
assert len(size) == bsz
for i in range(bsz):
F, H, W = size[i]
ori_len = (F // pF) * (H // pH) * (W // pW)
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
x[i] = (
x[i][:ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
return x
@staticmethod
def create_coordinate_grid(size, start=None, device=None):
if start is None:
start = (0 for _ in size)
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
grids = torch.meshgrid(axes, indexing="ij")
return torch.stack(grids, dim=-1)
def patchify_and_embed(
self,
all_image: List[torch.Tensor],
all_cap_feats: List[torch.Tensor],
patch_size: int,
f_patch_size: int,
):
pH = pW = patch_size
pF = f_patch_size
device = all_image[0].device
all_image_out = []
all_image_size = []
all_image_pos_ids = []
all_image_pad_mask = []
all_cap_pos_ids = []
all_cap_pad_mask = []
all_cap_feats_out = []
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
### Process Caption
cap_ori_len = len(cap_feat)
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
# padded position ids
cap_padded_pos_ids = self.create_coordinate_grid(
size=(cap_ori_len + cap_padding_len, 1, 1),
start=(1, 0, 0),
device=device,
).flatten(0, 2)
all_cap_pos_ids.append(cap_padded_pos_ids)
# pad mask
all_cap_pad_mask.append(
torch.cat(
[
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
)
# padded feature
cap_padded_feat = torch.cat(
[cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
dim=0,
)
all_cap_feats_out.append(cap_padded_feat)
### Process Image
C, F, H, W = image.size()
all_image_size.append((F, H, W))
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
image_ori_len = len(image)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
image_ori_pos_ids = self.create_coordinate_grid(
size=(F_tokens, H_tokens, W_tokens),
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
device=device,
).flatten(0, 2)
image_padding_pos_ids = (
self.create_coordinate_grid(
size=(1, 1, 1),
start=(0, 0, 0),
device=device,
)
.flatten(0, 2)
.repeat(image_padding_len, 1)
)
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
all_image_pos_ids.append(image_padded_pos_ids)
# pad mask
all_image_pad_mask.append(
torch.cat(
[
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
)
# padded feature
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
all_image_out.append(image_padded_feat)
return (
all_image_out,
all_cap_feats_out,
all_image_size,
all_image_pos_ids,
all_cap_pos_ids,
all_image_pad_mask,
all_cap_pad_mask,
)
def forward(
self,
x: List[torch.Tensor],
t,
cap_feats: List[torch.Tensor],
patch_size=2,
f_patch_size=1,
):
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
bsz = len(x)
device = x[0].device
t = t * self.t_scale
t = self.t_embedder(t)
adaln_input = t
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_inner_pad_mask,
cap_inner_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
# x embed & refine
x_item_seqlens = [len(_) for _ in x]
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)
x = torch.cat(x, dim=0)
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
x = pad_sequence(x, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.noise_refiner:
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
else:
for layer in self.noise_refiner:
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
cap_max_item_seqlen = max(cap_item_seqlens)
cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.context_refiner:
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
else:
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
# unified
unified = []
unified_freqs_cis = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.layers:
unified = self._gradient_checkpointing_func(
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
)
else:
for layer in self.layers:
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
unified = list(unified.unbind(dim=0))
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
return x, {}

View File

@@ -395,6 +395,7 @@ else:
"WanVACEPipeline",
"WanAnimatePipeline",
]
_import_structure["z_image"] = ["ZImagePipeline"]
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
@@ -824,6 +825,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
from .z_image import ZImagePipeline
try:
if not is_onnx_available():

View File

@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa: F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_output"] = ["ZImagePipelineOutput"]
_import_structure["pipeline_z_image"] = ["ZImagePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_output import ZImagePipelineOutput
from .pipeline_z_image import ZImagePipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,35 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from diffusers.utils import BaseOutput
@dataclass
class ZImagePipelineOutput(BaseOutput):
"""
Output class for Z-Image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -0,0 +1,607 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import ZImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import ZImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import ZImagePipeline
>>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
>>> # (1) Use flash attention 2
>>> # pipe.transformer.set_attention_backend("flash")
>>> # (2) Use flash attention 3
>>> # pipe.transformer.set_attention_backend("_flash_3")
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化一辆复古蒸汽小火车化身为巨大的拉链头正拉开厚厚的冬日积雪展露出一个生机盎然的春天。"
>>> image = pipe(
... prompt,
... height=1024,
... width=1024,
... num_inference_steps=9,
... guidance_scale=0.0,
... generator=torch.Generator("cuda").manual_seed(42),
... ).images[0]
>>> image.save("zimage.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: PreTrainedModel,
tokenizer: AutoTokenizer,
transformer: ZImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
if negative_prompt is None:
negative_prompt = ["" for _ in prompt]
else:
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
assert len(prompt) == len(negative_prompt)
negative_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
)
else:
negative_prompt_embeds = []
return prompt_embeds, negative_prompt_embeds
def _encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
assert num_images_per_prompt == 1
device = device or self._execution_device
if prompt_embeds is not None:
return prompt_embeds
if isinstance(prompt, str):
prompt = [prompt]
for i, prompt_item in enumerate(prompt):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
prompt[i] = prompt_item
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
for i in range(len(prompt_embeds)):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
cfg_normalization: bool = False,
cfg_truncation: float = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 1024):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
cfg_normalization (`bool`, *optional*, defaults to False):
Whether to apply configuration normalization.
cfg_truncation (`float`, *optional*, defaults to 1.0):
The truncation value for configuration.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 512):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or 1024
width = width or 1024
vae_scale = self.vae_scale_factor * 2
if height % vae_scale != 0:
raise ValueError(
f"Height must be divisible by {vae_scale} (got {height}). "
f"Please adjust the height to a multiple of {vae_scale}."
)
if width % vae_scale != 0:
raise ValueError(
f"Width must be divisible by {vae_scale} (got {width}). "
f"Please adjust the width to a multiple of {vae_scale}."
)
assert self.dtype == torch.bfloat16
dtype = self.dtype
device = self._execution_device
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
self._cfg_normalization = cfg_normalization
self._cfg_truncation = cfg_truncation
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = len(prompt_embeds)
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is not None and prompt is None:
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
raise ValueError(
"When `prompt_embeds` is provided without `prompt`, "
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
)
else:
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
dtype=dtype,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
torch.float32,
device,
generator,
latents,
)
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
# 5. Prepare timesteps
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
self.scheduler.sigma_min = 0.0
scheduler_kwargs = {"mu": mu}
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
timestep = (1000 - timestep) / 1000
# Normalized time for time-aware config (0 at start, 1 at end)
t_norm = timestep[0].item()
# Handle cfg truncation
current_guidance_scale = self.guidance_scale
if (
self.do_classifier_free_guidance
and self._cfg_truncation is not None
and float(self._cfg_truncation) <= 1
):
if t_norm > self._cfg_truncation:
current_guidance_scale = 0.0
# Run CFG only if configured AND scale is non-zero
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
if apply_cfg:
latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
else:
latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep
latent_model_input = latent_model_input.unsqueeze(2)
latent_model_input_list = list(latent_model_input.unbind(dim=0))
model_out_list = self.transformer(
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
)[0]
if apply_cfg:
# Perform CFG
pos_out = model_out_list[:batch_size]
neg_out = model_out_list[batch_size:]
noise_pred = []
for j in range(batch_size):
pos = pos_out[j].float()
neg = neg_out[j].float()
pred = pos + current_guidance_scale * (pos - neg)
# Renormalization
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(pos)
new_pos_norm = torch.linalg.vector_norm(pred)
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
if new_pos_norm > max_new_norm:
pred = pred * (max_new_norm / new_pos_norm)
noise_pred.append(pred)
noise_pred = torch.stack(noise_pred, dim=0)
else:
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred = noise_pred.squeeze(2)
noise_pred = -noise_pred
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
assert latents.dtype == torch.float32
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
latents = latents.to(dtype)
if output_type == "latent":
image = latents
else:
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ZImagePipelineOutput(images=image)

View File

@@ -3645,3 +3645,18 @@ class WuerstchenPriorPipeline(metaclass=DummyObject):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ZImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

View File